diff --git a/homeassistant/components/iqvia/__init__.py b/homeassistant/components/iqvia/__init__.py index 3fabb88b041..049c23965b1 100644 --- a/homeassistant/components/iqvia/__init__.py +++ b/homeassistant/components/iqvia/__init__.py @@ -3,25 +3,18 @@ from __future__ import annotations import asyncio -from collections.abc import Callable, Coroutine -from datetime import timedelta -from functools import partial -from typing import Any from pyiqvia import Client -from pyiqvia.errors import IQVIAError from homeassistant.config_entries import ConfigEntry from homeassistant.const import Platform from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.helpers import aiohttp_client -from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from .const import ( CONF_ZIP_CODE, DOMAIN, - LOGGER, TYPE_ALLERGY_FORECAST, TYPE_ALLERGY_INDEX, TYPE_ALLERGY_OUTLOOK, @@ -30,9 +23,9 @@ from .const import ( TYPE_DISEASE_FORECAST, TYPE_DISEASE_INDEX, ) +from .coordinator import IqviaUpdateCoordinator DEFAULT_ATTRIBUTION = "Data provided by IQVIA™" -DEFAULT_SCAN_INTERVAL = timedelta(minutes=30) PLATFORMS = [Platform.SENSOR] @@ -52,15 +45,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: # blocking) startup: client.disable_request_retries() - async def async_get_data_from_api( - api_coro: Callable[[], Coroutine[Any, Any, dict[str, Any]]], - ) -> dict[str, Any]: - """Get data from a particular API coroutine.""" - try: - return await api_coro() - except IQVIAError as err: - raise UpdateFailed from err - coordinators = {} init_data_update_tasks = [] @@ -73,13 +57,11 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: (TYPE_DISEASE_FORECAST, client.disease.extended), (TYPE_DISEASE_INDEX, client.disease.current), ): - coordinator = coordinators[sensor_type] = DataUpdateCoordinator( + coordinator = coordinators[sensor_type] = IqviaUpdateCoordinator( hass, - LOGGER, config_entry=entry, name=f"{entry.data[CONF_ZIP_CODE]} {sensor_type}", - update_interval=DEFAULT_SCAN_INTERVAL, - update_method=partial(async_get_data_from_api, api_coro), + update_method=api_coro, ) init_data_update_tasks.append(coordinator.async_refresh()) diff --git a/homeassistant/components/iqvia/coordinator.py b/homeassistant/components/iqvia/coordinator.py new file mode 100644 index 00000000000..420cbadbefa --- /dev/null +++ b/homeassistant/components/iqvia/coordinator.py @@ -0,0 +1,47 @@ +"""Support for IQVIA.""" + +from __future__ import annotations + +from collections.abc import Callable, Coroutine +from datetime import timedelta +from typing import Any + +from pyiqvia.errors import IQVIAError + +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed + +from .const import LOGGER + +DEFAULT_SCAN_INTERVAL = timedelta(minutes=30) + + +class IqviaUpdateCoordinator(DataUpdateCoordinator[dict[str, Any]]): + """Custom DataUpdateCoordinator for IQVIA.""" + + config_entry: ConfigEntry + + def __init__( + self, + hass: HomeAssistant, + config_entry: ConfigEntry, + name: str, + update_method: Callable[[], Coroutine[Any, Any, dict[str, Any]]], + ) -> None: + """Initialize the coordinator.""" + super().__init__( + hass, + LOGGER, + name=name, + config_entry=config_entry, + update_interval=DEFAULT_SCAN_INTERVAL, + ) + self._update_method = update_method + + async def _async_update_data(self) -> dict[str, Any]: + """Fetch data from the API.""" + try: + return await self._update_method() + except IQVIAError as err: + raise UpdateFailed from err