Add ability to pass the config entry explicitly in data update coordinators (#127980)

* Add ability to pass the config entry explicitely in data update coordinators

* Implement in accuweather

* Raise if config entry not set

* Move accuweather models

* Fix gogogate2

* Fix rainforest_raven
This commit is contained in:
epenet 2024-10-10 10:20:15 +02:00 committed by GitHub
parent 9b3f92e265
commit f504c27972
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 115 additions and 41 deletions

View File

@ -2,13 +2,11 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
import logging import logging
from accuweather import AccuWeather from accuweather import AccuWeather
from homeassistant.components.sensor import DOMAIN as SENSOR_PLATFORM from homeassistant.components.sensor import DOMAIN as SENSOR_PLATFORM
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY, CONF_NAME, Platform from homeassistant.const import CONF_API_KEY, CONF_NAME, Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
@ -16,7 +14,9 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import DOMAIN, UPDATE_INTERVAL_DAILY_FORECAST, UPDATE_INTERVAL_OBSERVATION from .const import DOMAIN, UPDATE_INTERVAL_DAILY_FORECAST, UPDATE_INTERVAL_OBSERVATION
from .coordinator import ( from .coordinator import (
AccuWeatherConfigEntry,
AccuWeatherDailyForecastDataUpdateCoordinator, AccuWeatherDailyForecastDataUpdateCoordinator,
AccuWeatherData,
AccuWeatherObservationDataUpdateCoordinator, AccuWeatherObservationDataUpdateCoordinator,
) )
@ -25,17 +25,6 @@ _LOGGER = logging.getLogger(__name__)
PLATFORMS = [Platform.SENSOR, Platform.WEATHER] PLATFORMS = [Platform.SENSOR, Platform.WEATHER]
@dataclass
class AccuWeatherData:
"""Data for AccuWeather integration."""
coordinator_observation: AccuWeatherObservationDataUpdateCoordinator
coordinator_daily_forecast: AccuWeatherDailyForecastDataUpdateCoordinator
type AccuWeatherConfigEntry = ConfigEntry[AccuWeatherData]
async def async_setup_entry(hass: HomeAssistant, entry: AccuWeatherConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: AccuWeatherConfigEntry) -> bool:
"""Set up AccuWeather as config entry.""" """Set up AccuWeather as config entry."""
api_key: str = entry.data[CONF_API_KEY] api_key: str = entry.data[CONF_API_KEY]
@ -50,6 +39,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: AccuWeatherConfigEntry)
coordinator_observation = AccuWeatherObservationDataUpdateCoordinator( coordinator_observation = AccuWeatherObservationDataUpdateCoordinator(
hass, hass,
entry,
accuweather, accuweather,
name, name,
"observation", "observation",
@ -58,6 +48,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: AccuWeatherConfigEntry)
coordinator_daily_forecast = AccuWeatherDailyForecastDataUpdateCoordinator( coordinator_daily_forecast = AccuWeatherDailyForecastDataUpdateCoordinator(
hass, hass,
entry,
accuweather, accuweather,
name, name,
"daily forecast", "daily forecast",

View File

@ -1,6 +1,9 @@
"""The AccuWeather coordinator.""" """The AccuWeather coordinator."""
from __future__ import annotations
from asyncio import timeout from asyncio import timeout
from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
import logging import logging
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@ -8,6 +11,7 @@ from typing import TYPE_CHECKING, Any
from accuweather import AccuWeather, ApiError, InvalidApiKeyError, RequestsExceededError from accuweather import AccuWeather, ApiError, InvalidApiKeyError, RequestsExceededError
from aiohttp.client_exceptions import ClientConnectorError from aiohttp.client_exceptions import ClientConnectorError
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from homeassistant.helpers.update_coordinator import ( from homeassistant.helpers.update_coordinator import (
@ -23,6 +27,17 @@ EXCEPTIONS = (ApiError, ClientConnectorError, InvalidApiKeyError, RequestsExceed
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@dataclass
class AccuWeatherData:
"""Data for AccuWeather integration."""
coordinator_observation: AccuWeatherObservationDataUpdateCoordinator
coordinator_daily_forecast: AccuWeatherDailyForecastDataUpdateCoordinator
type AccuWeatherConfigEntry = ConfigEntry[AccuWeatherData]
class AccuWeatherObservationDataUpdateCoordinator( class AccuWeatherObservationDataUpdateCoordinator(
DataUpdateCoordinator[dict[str, Any]] DataUpdateCoordinator[dict[str, Any]]
): ):
@ -31,6 +46,7 @@ class AccuWeatherObservationDataUpdateCoordinator(
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
config_entry: AccuWeatherConfigEntry,
accuweather: AccuWeather, accuweather: AccuWeather,
name: str, name: str,
coordinator_type: str, coordinator_type: str,
@ -48,6 +64,7 @@ class AccuWeatherObservationDataUpdateCoordinator(
super().__init__( super().__init__(
hass, hass,
_LOGGER, _LOGGER,
config_entry=config_entry,
name=f"{name} ({coordinator_type})", name=f"{name} ({coordinator_type})",
update_interval=update_interval, update_interval=update_interval,
) )
@ -73,6 +90,7 @@ class AccuWeatherDailyForecastDataUpdateCoordinator(
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
config_entry: AccuWeatherConfigEntry,
accuweather: AccuWeather, accuweather: AccuWeather,
name: str, name: str,
coordinator_type: str, coordinator_type: str,
@ -90,6 +108,7 @@ class AccuWeatherDailyForecastDataUpdateCoordinator(
super().__init__( super().__init__(
hass, hass,
_LOGGER, _LOGGER,
config_entry=config_entry,
name=f"{name} ({coordinator_type})", name=f"{name} ({coordinator_type})",
update_interval=update_interval, update_interval=update_interval,
) )

View File

@ -8,7 +8,7 @@ from homeassistant.components.diagnostics import async_redact_data
from homeassistant.const import CONF_API_KEY, CONF_LATITUDE, CONF_LONGITUDE from homeassistant.const import CONF_API_KEY, CONF_LATITUDE, CONF_LONGITUDE
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from . import AccuWeatherConfigEntry, AccuWeatherData from .coordinator import AccuWeatherConfigEntry, AccuWeatherData
TO_REDACT = {CONF_API_KEY, CONF_LATITUDE, CONF_LONGITUDE} TO_REDACT = {CONF_API_KEY, CONF_LATITUDE, CONF_LONGITUDE}

View File

@ -28,7 +28,6 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import AccuWeatherConfigEntry
from .const import ( from .const import (
API_METRIC, API_METRIC,
ATTR_CATEGORY, ATTR_CATEGORY,
@ -41,6 +40,7 @@ from .const import (
MAX_FORECAST_DAYS, MAX_FORECAST_DAYS,
) )
from .coordinator import ( from .coordinator import (
AccuWeatherConfigEntry,
AccuWeatherDailyForecastDataUpdateCoordinator, AccuWeatherDailyForecastDataUpdateCoordinator,
AccuWeatherObservationDataUpdateCoordinator, AccuWeatherObservationDataUpdateCoordinator,
) )

View File

@ -9,8 +9,8 @@ from accuweather.const import ENDPOINT
from homeassistant.components import system_health from homeassistant.components import system_health
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from . import AccuWeatherConfigEntry
from .const import DOMAIN from .const import DOMAIN
from .coordinator import AccuWeatherConfigEntry
@callback @callback

View File

@ -33,7 +33,6 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util.dt import utc_from_timestamp from homeassistant.util.dt import utc_from_timestamp
from . import AccuWeatherConfigEntry, AccuWeatherData
from .const import ( from .const import (
API_METRIC, API_METRIC,
ATTR_DIRECTION, ATTR_DIRECTION,
@ -43,7 +42,9 @@ from .const import (
CONDITION_MAP, CONDITION_MAP,
) )
from .coordinator import ( from .coordinator import (
AccuWeatherConfigEntry,
AccuWeatherDailyForecastDataUpdateCoordinator, AccuWeatherDailyForecastDataUpdateCoordinator,
AccuWeatherData,
AccuWeatherObservationDataUpdateCoordinator, AccuWeatherObservationDataUpdateCoordinator,
) )

View File

@ -75,6 +75,7 @@ class RAVEnDataCoordinator(DataUpdateCoordinator):
super().__init__( super().__init__(
hass, hass,
_LOGGER, _LOGGER,
config_entry=entry,
name=DOMAIN, name=DOMAIN,
update_interval=timedelta(seconds=30), update_interval=timedelta(seconds=30),
) )

View File

@ -29,6 +29,7 @@ from homeassistant.util.dt import utcnow
from . import entity, event from . import entity, event
from .debounce import Debouncer from .debounce import Debouncer
from .typing import UNDEFINED, UndefinedType
REQUEST_REFRESH_DEFAULT_COOLDOWN = 10 REQUEST_REFRESH_DEFAULT_COOLDOWN = 10
REQUEST_REFRESH_DEFAULT_IMMEDIATE = True REQUEST_REFRESH_DEFAULT_IMMEDIATE = True
@ -68,6 +69,7 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
hass: HomeAssistant, hass: HomeAssistant,
logger: logging.Logger, logger: logging.Logger,
*, *,
config_entry: config_entries.ConfigEntry | None | UndefinedType = UNDEFINED,
name: str, name: str,
update_interval: timedelta | None = None, update_interval: timedelta | None = None,
update_method: Callable[[], Awaitable[_DataT]] | None = None, update_method: Callable[[], Awaitable[_DataT]] | None = None,
@ -84,7 +86,12 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
self._update_interval_seconds: float | None = None self._update_interval_seconds: float | None = None
self.update_interval = update_interval self.update_interval = update_interval
self._shutdown_requested = False self._shutdown_requested = False
self.config_entry = config_entries.current_entry.get() if config_entry is UNDEFINED:
self.config_entry = config_entries.current_entry.get()
# This should be deprecated once all core integrations are updated
# to pass in the config entry explicitly.
else:
self.config_entry = config_entry
self.always_update = always_update self.always_update = always_update
# It's None before the first successful update. # It's None before the first successful update.
@ -277,6 +284,10 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
fails. Additionally logging is handled by config entry setup fails. Additionally logging is handled by config entry setup
to ensure that multiple retries do not cause log spam. to ensure that multiple retries do not cause log spam.
""" """
if self.config_entry is None:
raise ValueError(
"This method is only supported for coordinators with a config entry"
)
if await self.__wrap_async_setup(): if await self.__wrap_async_setup():
await self._async_refresh( await self._async_refresh(
log_failures=False, raise_on_auth_failed=True, raise_on_entry_error=True log_failures=False, raise_on_auth_failed=True, raise_on_entry_error=True

View File

@ -3,11 +3,10 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from ismartgate import GogoGate2Api from ismartgate import GogoGate2Api
import pytest
from homeassistant.components.gogogate2 import DEVICE_TYPE_GOGOGATE2, async_setup_entry from homeassistant.components.gogogate2 import DEVICE_TYPE_GOGOGATE2
from homeassistant.components.gogogate2.const import DEVICE_TYPE_ISMARTGATE, DOMAIN from homeassistant.components.gogogate2.const import DEVICE_TYPE_ISMARTGATE, DOMAIN
from homeassistant.config_entries import SOURCE_USER from homeassistant.config_entries import SOURCE_USER, ConfigEntryState
from homeassistant.const import ( from homeassistant.const import (
CONF_DEVICE, CONF_DEVICE,
CONF_IP_ADDRESS, CONF_IP_ADDRESS,
@ -15,7 +14,6 @@ from homeassistant.const import (
CONF_USERNAME, CONF_USERNAME,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -97,6 +95,8 @@ async def test_api_failure_on_startup(hass: HomeAssistant) -> None:
"homeassistant.components.gogogate2.common.ISmartGateApi.async_info", "homeassistant.components.gogogate2.common.ISmartGateApi.async_info",
side_effect=TimeoutError, side_effect=TimeoutError,
), ),
pytest.raises(ConfigEntryNotReady),
): ):
await async_setup_entry(hass, config_entry) await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
assert config_entry.state is ConfigEntryState.SETUP_RETRY

View File

@ -57,7 +57,9 @@ KNOWN_ERRORS: list[tuple[Exception, type[Exception], str]] = [
def get_crd( def get_crd(
hass: HomeAssistant, update_interval: timedelta | None hass: HomeAssistant,
update_interval: timedelta | None,
config_entry: config_entries.ConfigEntry | None = None,
) -> update_coordinator.DataUpdateCoordinator[int]: ) -> update_coordinator.DataUpdateCoordinator[int]:
"""Make coordinator mocks.""" """Make coordinator mocks."""
calls = 0 calls = 0
@ -70,6 +72,7 @@ def get_crd(
return update_coordinator.DataUpdateCoordinator[int]( return update_coordinator.DataUpdateCoordinator[int](
hass, hass,
_LOGGER, _LOGGER,
config_entry=config_entry,
name="test", name="test",
update_method=refresh, update_method=refresh,
update_interval=update_interval, update_interval=update_interval,
@ -121,8 +124,7 @@ async def test_async_refresh(
async def test_shutdown( async def test_shutdown(
hass: HomeAssistant, hass: HomeAssistant, crd: update_coordinator.DataUpdateCoordinator[int]
crd: update_coordinator.DataUpdateCoordinator[int],
) -> None: ) -> None:
"""Test async_shutdown for update coordinator.""" """Test async_shutdown for update coordinator."""
assert crd.data is None assert crd.data is None
@ -158,8 +160,7 @@ async def test_shutdown(
async def test_shutdown_on_entry_unload( async def test_shutdown_on_entry_unload(
hass: HomeAssistant, hass: HomeAssistant, crd: update_coordinator.DataUpdateCoordinator[int]
crd: update_coordinator.DataUpdateCoordinator[int],
) -> None: ) -> None:
"""Test shutdown is requested on entry unload.""" """Test shutdown is requested on entry unload."""
entry = MockConfigEntry() entry = MockConfigEntry()
@ -191,8 +192,7 @@ async def test_shutdown_on_entry_unload(
async def test_shutdown_on_hass_stop( async def test_shutdown_on_hass_stop(
hass: HomeAssistant, hass: HomeAssistant, crd: update_coordinator.DataUpdateCoordinator[int]
crd: update_coordinator.DataUpdateCoordinator[int],
) -> None: ) -> None:
"""Test shutdown can be shutdown on STOP event.""" """Test shutdown can be shutdown on STOP event."""
calls = 0 calls = 0
@ -539,8 +539,8 @@ async def test_stop_refresh_on_ha_stop(
["update_method", "setup_method"], ["update_method", "setup_method"],
) )
async def test_async_config_entry_first_refresh_failure( async def test_async_config_entry_first_refresh_failure(
hass: HomeAssistant,
err_msg: tuple[Exception, type[Exception], str], err_msg: tuple[Exception, type[Exception], str],
crd: update_coordinator.DataUpdateCoordinator[int],
method: str, method: str,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
@ -550,6 +550,8 @@ async def test_async_config_entry_first_refresh_failure(
will be caught by config_entries.async_setup which will log it with will be caught by config_entries.async_setup which will log it with
a decreasing level of logging once the first message is logged. a decreasing level of logging once the first message is logged.
""" """
entry = MockConfigEntry()
crd = get_crd(hass, DEFAULT_UPDATE_INTERVAL, entry)
setattr(crd, method, AsyncMock(side_effect=err_msg[0])) setattr(crd, method, AsyncMock(side_effect=err_msg[0]))
with pytest.raises(ConfigEntryNotReady): with pytest.raises(ConfigEntryNotReady):
@ -572,8 +574,8 @@ async def test_async_config_entry_first_refresh_failure(
["update_method", "setup_method"], ["update_method", "setup_method"],
) )
async def test_async_config_entry_first_refresh_failure_passed_through( async def test_async_config_entry_first_refresh_failure_passed_through(
hass: HomeAssistant,
err_msg: tuple[Exception, type[Exception], str], err_msg: tuple[Exception, type[Exception], str],
crd: update_coordinator.DataUpdateCoordinator[int],
method: str, method: str,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
@ -583,6 +585,8 @@ async def test_async_config_entry_first_refresh_failure_passed_through(
will be caught by config_entries.async_setup which will log it with will be caught by config_entries.async_setup which will log it with
a decreasing level of logging once the first message is logged. a decreasing level of logging once the first message is logged.
""" """
entry = MockConfigEntry()
crd = get_crd(hass, DEFAULT_UPDATE_INTERVAL, entry)
setattr(crd, method, AsyncMock(side_effect=err_msg[0])) setattr(crd, method, AsyncMock(side_effect=err_msg[0]))
with pytest.raises(err_msg[1]): with pytest.raises(err_msg[1]):
@ -593,11 +597,10 @@ async def test_async_config_entry_first_refresh_failure_passed_through(
assert err_msg[2] not in caplog.text assert err_msg[2] not in caplog.text
async def test_async_config_entry_first_refresh_success( async def test_async_config_entry_first_refresh_success(hass: HomeAssistant) -> None:
crd: update_coordinator.DataUpdateCoordinator[int], caplog: pytest.LogCaptureFixture
) -> None:
"""Test first refresh successfully.""" """Test first refresh successfully."""
entry = MockConfigEntry()
crd = get_crd(hass, DEFAULT_UPDATE_INTERVAL, entry)
crd.setup_method = AsyncMock() crd.setup_method = AsyncMock()
await crd.async_config_entry_first_refresh() await crd.async_config_entry_first_refresh()
@ -605,13 +608,26 @@ async def test_async_config_entry_first_refresh_success(
crd.setup_method.assert_called_once() crd.setup_method.assert_called_once()
async def test_async_config_entry_first_refresh_no_entry(hass: HomeAssistant) -> None:
"""Test first refresh successfully."""
crd = get_crd(hass, DEFAULT_UPDATE_INTERVAL, None)
crd.setup_method = AsyncMock()
with pytest.raises(
ValueError,
match="This method is only supported for coordinators with a config entry",
):
await crd.async_config_entry_first_refresh()
assert crd.last_update_success is True
crd.setup_method.assert_not_called()
async def test_not_schedule_refresh_if_system_option_disable_polling( async def test_not_schedule_refresh_if_system_option_disable_polling(
hass: HomeAssistant, hass: HomeAssistant,
) -> None: ) -> None:
"""Test we do not schedule a refresh if disable polling in config entry.""" """Test we do not schedule a refresh if disable polling in config entry."""
entry = MockConfigEntry(pref_disable_polling=True) entry = MockConfigEntry(pref_disable_polling=True)
config_entries.current_entry.set(entry) crd = get_crd(hass, DEFAULT_UPDATE_INTERVAL, entry)
crd = get_crd(hass, DEFAULT_UPDATE_INTERVAL)
crd.async_add_listener(lambda: None) crd.async_add_listener(lambda: None)
assert crd._unsub_refresh is None assert crd._unsub_refresh is None
@ -651,7 +667,7 @@ async def test_async_set_update_error(
async def test_only_callback_on_change_when_always_update_is_false( async def test_only_callback_on_change_when_always_update_is_false(
crd: update_coordinator.DataUpdateCoordinator[int], caplog: pytest.LogCaptureFixture crd: update_coordinator.DataUpdateCoordinator[int],
) -> None: ) -> None:
"""Test we do not callback listeners unless something has actually changed when always_update is false.""" """Test we do not callback listeners unless something has actually changed when always_update is false."""
update_callback = Mock() update_callback = Mock()
@ -721,7 +737,7 @@ async def test_only_callback_on_change_when_always_update_is_false(
async def test_always_callback_when_always_update_is_true( async def test_always_callback_when_always_update_is_true(
crd: update_coordinator.DataUpdateCoordinator[int], caplog: pytest.LogCaptureFixture crd: update_coordinator.DataUpdateCoordinator[int],
) -> None: ) -> None:
"""Test we callback listeners even though the data is the same when always_update is True.""" """Test we callback listeners even though the data is the same when always_update is True."""
update_callback = Mock() update_callback = Mock()
@ -795,3 +811,38 @@ async def test_timestamp_date_update_coordinator(hass: HomeAssistant) -> None:
unsub() unsub()
await crd.async_refresh() await crd.async_refresh()
assert len(last_update_success_times) == 1 assert len(last_update_success_times) == 1
async def test_config_entry(hass: HomeAssistant) -> None:
"""Test behavior of coordinator.entry."""
entry = MockConfigEntry()
# Default without context should be None
crd = update_coordinator.DataUpdateCoordinator[int](hass, _LOGGER, name="test")
assert crd.config_entry is None
# Explicit None is OK
crd = update_coordinator.DataUpdateCoordinator[int](
hass, _LOGGER, name="test", config_entry=None
)
assert crd.config_entry is None
# Explicit entry is OK
crd = update_coordinator.DataUpdateCoordinator[int](
hass, _LOGGER, name="test", config_entry=entry
)
assert crd.config_entry is entry
# set ContextVar
config_entries.current_entry.set(entry)
# Default with ContextVar should match the ContextVar
crd = update_coordinator.DataUpdateCoordinator[int](hass, _LOGGER, name="test")
assert crd.config_entry is entry
# Explicit entry different from ContextVar not recommended, but should work
another_entry = MockConfigEntry()
crd = update_coordinator.DataUpdateCoordinator[int](
hass, _LOGGER, name="test", config_entry=another_entry
)
assert crd.config_entry is another_entry