Add async_setup method to DataUpdateCoordinator (#116677)

* init

* Update homeassistant/helpers/update_coordinator.py

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>

* fix typo, ruff

* consistency with rest, test

* pylint suppression

* ruff

* ruff

* switch to one test

* add last exc

* add tests for auth & Entry Errors

* move exceptions to correct test

* Update update_coordinator.py

Co-authored-by: G Johansson <goran.johansson@shiftit.se>

* test setup call

* simplify

---------

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
Co-authored-by: G Johansson <goran.johansson@shiftit.se>
This commit is contained in:
Josef Zweck 2024-07-19 14:24:25 +02:00 committed by GitHub
parent de5b5f6d36
commit f006716173
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 98 additions and 9 deletions

View File

@ -71,6 +71,7 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
name: str, name: str,
update_interval: timedelta | None = None, update_interval: timedelta | None = None,
update_method: Callable[[], Awaitable[_DataT]] | None = None, update_method: Callable[[], Awaitable[_DataT]] | None = None,
setup_method: Callable[[], Awaitable[None]] | None = None,
request_refresh_debouncer: Debouncer[Coroutine[Any, Any, None]] | None = None, request_refresh_debouncer: Debouncer[Coroutine[Any, Any, None]] | None = None,
always_update: bool = True, always_update: bool = True,
) -> None: ) -> None:
@ -79,6 +80,7 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
self.logger = logger self.logger = logger
self.name = name self.name = name
self.update_method = update_method self.update_method = update_method
self.setup_method = setup_method
self._update_interval_seconds: float | None = None self._update_interval_seconds: float | None = None
self.update_interval = update_interval self.update_interval = update_interval
self._shutdown_requested = False self._shutdown_requested = False
@ -275,6 +277,7 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
fails. Additionally logging is handled by config entry setup fails. Additionally logging is handled by config entry setup
to ensure that multiple retries do not cause log spam. to ensure that multiple retries do not cause log spam.
""" """
if await self.__wrap_async_setup():
await self._async_refresh( await self._async_refresh(
log_failures=False, raise_on_auth_failed=True, raise_on_entry_error=True log_failures=False, raise_on_auth_failed=True, raise_on_entry_error=True
) )
@ -284,6 +287,44 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
ex.__cause__ = self.last_exception ex.__cause__ = self.last_exception
raise ex raise ex
async def __wrap_async_setup(self) -> bool:
"""Error handling for _async_setup."""
try:
await self._async_setup()
except (
TimeoutError,
requests.exceptions.Timeout,
aiohttp.ClientError,
requests.exceptions.RequestException,
urllib.error.URLError,
UpdateFailed,
) as err:
self.last_exception = err
except (ConfigEntryError, ConfigEntryAuthFailed) as err:
self.last_exception = err
self.last_update_success = False
raise
except Exception as err: # pylint: disable=broad-except
self.last_exception = err
self.logger.exception("Unexpected error fetching %s data", self.name)
else:
return True
self.last_update_success = False
return False
async def _async_setup(self) -> None:
"""Set up the coordinator.
Can be overwritten by integrations to load data or resources
only once during the first refresh.
"""
if self.setup_method is None:
return None
return await self.setup_method()
async def async_refresh(self) -> None: async def async_refresh(self) -> None:
"""Refresh data and log errors.""" """Refresh data and log errors."""
await self._async_refresh(log_failures=True) await self._async_refresh(log_failures=True)
@ -393,7 +434,7 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
self.logger.debug( self.logger.debug(
"Finished fetching %s data in %.3f seconds (success: %s)", "Finished fetching %s data in %.3f seconds (success: %s)",
self.name, self.name,
monotonic() - start, monotonic() - start, # pylint: disable=possibly-used-before-assignment
self.last_update_success, self.last_update_success,
) )
if not auth_failed and self._listeners and not self.hass.is_stopping: if not auth_failed and self._listeners and not self.hass.is_stopping:

View File

@ -13,7 +13,11 @@ import requests
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import CoreState, HomeAssistant, callback from homeassistant.core import CoreState, HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import (
ConfigEntryAuthFailed,
ConfigEntryError,
ConfigEntryNotReady,
)
from homeassistant.helpers import update_coordinator from homeassistant.helpers import update_coordinator
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
@ -525,11 +529,19 @@ async def test_stop_refresh_on_ha_stop(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"err_msg", "err_msg",
KNOWN_ERRORS, [
*KNOWN_ERRORS,
(Exception(), Exception, "Unknown exception"),
],
)
@pytest.mark.parametrize(
"method",
["update_method", "setup_method"],
) )
async def test_async_config_entry_first_refresh_failure( async def test_async_config_entry_first_refresh_failure(
err_msg: tuple[Exception, type[Exception], str], err_msg: tuple[Exception, type[Exception], str],
crd: update_coordinator.DataUpdateCoordinator[int], crd: update_coordinator.DataUpdateCoordinator[int],
method: str,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
"""Test async_config_entry_first_refresh raises ConfigEntryNotReady on failure. """Test async_config_entry_first_refresh raises ConfigEntryNotReady on failure.
@ -538,7 +550,7 @@ async def test_async_config_entry_first_refresh_failure(
will be caught by config_entries.async_setup which will log it with will be caught by config_entries.async_setup which will log it with
a decreasing level of logging once the first message is logged. a decreasing level of logging once the first message is logged.
""" """
crd.update_method = AsyncMock(side_effect=err_msg[0]) setattr(crd, method, AsyncMock(side_effect=err_msg[0]))
with pytest.raises(ConfigEntryNotReady): with pytest.raises(ConfigEntryNotReady):
await crd.async_config_entry_first_refresh() await crd.async_config_entry_first_refresh()
@ -548,13 +560,49 @@ async def test_async_config_entry_first_refresh_failure(
assert err_msg[2] not in caplog.text assert err_msg[2] not in caplog.text
@pytest.mark.parametrize(
"err_msg",
[
(ConfigEntryError(), ConfigEntryError, "Config entry error"),
(ConfigEntryAuthFailed(), ConfigEntryAuthFailed, "Config entry error"),
],
)
@pytest.mark.parametrize(
"method",
["update_method", "setup_method"],
)
async def test_async_config_entry_first_refresh_failure_passed_through(
err_msg: tuple[Exception, type[Exception], str],
crd: update_coordinator.DataUpdateCoordinator[int],
method: str,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test async_config_entry_first_refresh passes through ConfigEntryError & ConfigEntryAuthFailed.
Verify we do not log the exception since it
will be caught by config_entries.async_setup which will log it with
a decreasing level of logging once the first message is logged.
"""
setattr(crd, method, AsyncMock(side_effect=err_msg[0]))
with pytest.raises(err_msg[1]):
await crd.async_config_entry_first_refresh()
assert crd.last_update_success is False
assert isinstance(crd.last_exception, err_msg[1])
assert err_msg[2] not in caplog.text
async def test_async_config_entry_first_refresh_success( async def test_async_config_entry_first_refresh_success(
crd: update_coordinator.DataUpdateCoordinator[int], caplog: pytest.LogCaptureFixture crd: update_coordinator.DataUpdateCoordinator[int], caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Test first refresh successfully.""" """Test first refresh successfully."""
crd.setup_method = AsyncMock()
await crd.async_config_entry_first_refresh() await crd.async_config_entry_first_refresh()
assert crd.last_update_success is True assert crd.last_update_success is True
crd.setup_method.assert_called_once()
async def test_not_schedule_refresh_if_system_option_disable_polling( async def test_not_schedule_refresh_if_system_option_disable_polling(