From b4bac7705ed34e2ee535d56dd0a15fe9e13d3060 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 5 May 2024 05:09:57 -0500 Subject: [PATCH] Ensure all synology_dsm coordinators handle expired sessions (#116796) * Ensure all synology_dsm coordinators handle expired sessions * Ensure all synology_dsm coordinators handle expired sessions * Ensure all synology_dsm coordinators handle expired sessions * handle cancellation * add a debug log message --------- Co-authored-by: mib1185 --- .../components/synology_dsm/__init__.py | 6 ++- .../components/synology_dsm/common.py | 28 ++++++++++- .../components/synology_dsm/coordinator.py | 50 +++++++++++++------ 3 files changed, 68 insertions(+), 16 deletions(-) diff --git a/homeassistant/components/synology_dsm/__init__.py b/homeassistant/components/synology_dsm/__init__.py index 2748b27c93d..6598ed304f7 100644 --- a/homeassistant/components/synology_dsm/__init__.py +++ b/homeassistant/components/synology_dsm/__init__.py @@ -7,6 +7,7 @@ import logging from synology_dsm.api.surveillance_station import SynoSurveillanceStation from synology_dsm.api.surveillance_station.camera import SynoCamera +from synology_dsm.exceptions import SynologyDSMNotLoggedInException from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_MAC, CONF_VERIFY_SSL @@ -69,7 +70,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: await api.async_setup() except SYNOLOGY_AUTH_FAILED_EXCEPTIONS as err: raise_config_entry_auth_error(err) - except SYNOLOGY_CONNECTION_EXCEPTIONS as err: + except (*SYNOLOGY_CONNECTION_EXCEPTIONS, SynologyDSMNotLoggedInException) as err: + # SynologyDSMNotLoggedInException may be raised even if the user is + # logged in because the session may have expired, and we need to retry + # the login later. if err.args[0] and isinstance(err.args[0], dict): details = err.args[0].get(EXCEPTION_DETAILS, EXCEPTION_UNKNOWN) else: diff --git a/homeassistant/components/synology_dsm/common.py b/homeassistant/components/synology_dsm/common.py index 04e8ae29ceb..c871dd7b705 100644 --- a/homeassistant/components/synology_dsm/common.py +++ b/homeassistant/components/synology_dsm/common.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio from collections.abc import Callable from contextlib import suppress import logging @@ -82,6 +83,31 @@ class SynoApi: self._with_upgrade = True self._with_utilisation = True + self._login_future: asyncio.Future[None] | None = None + + async def async_login(self) -> None: + """Login to the Synology DSM API. + + This function will only login once if called multiple times + by multiple different callers. + + If a login is already in progress, the function will await the + login to complete before returning. + """ + if self._login_future: + return await self._login_future + + self._login_future = self._hass.loop.create_future() + try: + await self.dsm.login() + self._login_future.set_result(None) + except BaseException as err: + if not self._login_future.done(): + self._login_future.set_exception(err) + raise + finally: + self._login_future = None + async def async_setup(self) -> None: """Start interacting with the NAS.""" session = async_get_clientsession(self._hass, self._entry.data[CONF_VERIFY_SSL]) @@ -95,7 +121,7 @@ class SynoApi: timeout=self._entry.options.get(CONF_TIMEOUT) or 10, device_token=self._entry.data.get(CONF_DEVICE_TOKEN), ) - await self.dsm.login() + await self.async_login() # check if surveillance station is used self._with_surveillance_station = bool( diff --git a/homeassistant/components/synology_dsm/coordinator.py b/homeassistant/components/synology_dsm/coordinator.py index 34886828a58..52a3e1de1eb 100644 --- a/homeassistant/components/synology_dsm/coordinator.py +++ b/homeassistant/components/synology_dsm/coordinator.py @@ -2,9 +2,10 @@ from __future__ import annotations +from collections.abc import Awaitable, Callable, Coroutine from datetime import timedelta import logging -from typing import Any, TypeVar +from typing import Any, Concatenate, ParamSpec, TypeVar from synology_dsm.api.surveillance_station.camera import SynoCamera from synology_dsm.exceptions import ( @@ -30,6 +31,36 @@ _LOGGER = logging.getLogger(__name__) _DataT = TypeVar("_DataT") +_T = TypeVar("_T", bound="SynologyDSMUpdateCoordinator") +_P = ParamSpec("_P") + + +def async_re_login_on_expired( + func: Callable[Concatenate[_T, _P], Awaitable[_DataT]], +) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, _DataT]]: + """Define a wrapper to re-login when expired.""" + + async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> _DataT: + for attempts in range(2): + try: + return await func(self, *args, **kwargs) + except SynologyDSMNotLoggedInException: + # If login is expired, try to login again + _LOGGER.debug("login is expired, try to login again") + try: + await self.api.async_login() + except SYNOLOGY_AUTH_FAILED_EXCEPTIONS as err: + raise_config_entry_auth_error(err) + if attempts == 0: + continue + except SYNOLOGY_CONNECTION_EXCEPTIONS as err: + raise UpdateFailed(f"Error communicating with API: {err}") from err + + raise UpdateFailed("Unknown error when communicating with API") + + return _async_wrap + + class SynologyDSMUpdateCoordinator(DataUpdateCoordinator[_DataT]): """DataUpdateCoordinator base class for synology_dsm.""" @@ -72,6 +103,7 @@ class SynologyDSMSwitchUpdateCoordinator( assert info is not None self.version = info["data"]["CMSMinVersion"] + @async_re_login_on_expired async def _async_update_data(self) -> dict[str, dict[str, Any]]: """Fetch all data from api.""" surveillance_station = self.api.surveillance_station @@ -102,21 +134,10 @@ class SynologyDSMCentralUpdateCoordinator(SynologyDSMUpdateCoordinator[None]): ), ) + @async_re_login_on_expired async def _async_update_data(self) -> None: """Fetch all data from api.""" - for attempts in range(2): - try: - await self.api.async_update() - except SynologyDSMNotLoggedInException: - # If login is expired, try to login again - try: - await self.api.dsm.login() - except SYNOLOGY_AUTH_FAILED_EXCEPTIONS as err: - raise_config_entry_auth_error(err) - if attempts == 0: - continue - except SYNOLOGY_CONNECTION_EXCEPTIONS as err: - raise UpdateFailed(f"Error communicating with API: {err}") from err + await self.api.async_update() class SynologyDSMCameraUpdateCoordinator( @@ -133,6 +154,7 @@ class SynologyDSMCameraUpdateCoordinator( """Initialize DataUpdateCoordinator for cameras.""" super().__init__(hass, entry, api, timedelta(seconds=30)) + @async_re_login_on_expired async def _async_update_data(self) -> dict[str, dict[int, SynoCamera]]: """Fetch all camera data from api.""" surveillance_station = self.api.surveillance_station