Ensure all synology_dsm coordinators handle expired sessions (#116796)

* Ensure all synology_dsm coordinators handle expired sessions

* Ensure all synology_dsm coordinators handle expired sessions

* Ensure all synology_dsm coordinators handle expired sessions

* handle cancellation

* add a debug log message

---------

Co-authored-by: mib1185 <mail@mib85.de>
This commit is contained in:
J. Nick Koston 2024-05-05 05:09:57 -05:00 committed by GitHub
parent f5394dc3a3
commit b4bac7705e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 68 additions and 16 deletions

View File

@ -7,6 +7,7 @@ import logging
from synology_dsm.api.surveillance_station import SynoSurveillanceStation from synology_dsm.api.surveillance_station import SynoSurveillanceStation
from synology_dsm.api.surveillance_station.camera import SynoCamera from synology_dsm.api.surveillance_station.camera import SynoCamera
from synology_dsm.exceptions import SynologyDSMNotLoggedInException
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_MAC, CONF_VERIFY_SSL from homeassistant.const import CONF_MAC, CONF_VERIFY_SSL
@ -69,7 +70,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
await api.async_setup() await api.async_setup()
except SYNOLOGY_AUTH_FAILED_EXCEPTIONS as err: except SYNOLOGY_AUTH_FAILED_EXCEPTIONS as err:
raise_config_entry_auth_error(err) raise_config_entry_auth_error(err)
except SYNOLOGY_CONNECTION_EXCEPTIONS as err: except (*SYNOLOGY_CONNECTION_EXCEPTIONS, SynologyDSMNotLoggedInException) as err:
# SynologyDSMNotLoggedInException may be raised even if the user is
# logged in because the session may have expired, and we need to retry
# the login later.
if err.args[0] and isinstance(err.args[0], dict): if err.args[0] and isinstance(err.args[0], dict):
details = err.args[0].get(EXCEPTION_DETAILS, EXCEPTION_UNKNOWN) details = err.args[0].get(EXCEPTION_DETAILS, EXCEPTION_UNKNOWN)
else: else:

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from collections.abc import Callable from collections.abc import Callable
from contextlib import suppress from contextlib import suppress
import logging import logging
@ -82,6 +83,31 @@ class SynoApi:
self._with_upgrade = True self._with_upgrade = True
self._with_utilisation = True self._with_utilisation = True
self._login_future: asyncio.Future[None] | None = None
async def async_login(self) -> None:
"""Login to the Synology DSM API.
This function will only login once if called multiple times
by multiple different callers.
If a login is already in progress, the function will await the
login to complete before returning.
"""
if self._login_future:
return await self._login_future
self._login_future = self._hass.loop.create_future()
try:
await self.dsm.login()
self._login_future.set_result(None)
except BaseException as err:
if not self._login_future.done():
self._login_future.set_exception(err)
raise
finally:
self._login_future = None
async def async_setup(self) -> None: async def async_setup(self) -> None:
"""Start interacting with the NAS.""" """Start interacting with the NAS."""
session = async_get_clientsession(self._hass, self._entry.data[CONF_VERIFY_SSL]) session = async_get_clientsession(self._hass, self._entry.data[CONF_VERIFY_SSL])
@ -95,7 +121,7 @@ class SynoApi:
timeout=self._entry.options.get(CONF_TIMEOUT) or 10, timeout=self._entry.options.get(CONF_TIMEOUT) or 10,
device_token=self._entry.data.get(CONF_DEVICE_TOKEN), device_token=self._entry.data.get(CONF_DEVICE_TOKEN),
) )
await self.dsm.login() await self.async_login()
# check if surveillance station is used # check if surveillance station is used
self._with_surveillance_station = bool( self._with_surveillance_station = bool(

View File

@ -2,9 +2,10 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Awaitable, Callable, Coroutine
from datetime import timedelta from datetime import timedelta
import logging import logging
from typing import Any, TypeVar from typing import Any, Concatenate, ParamSpec, TypeVar
from synology_dsm.api.surveillance_station.camera import SynoCamera from synology_dsm.api.surveillance_station.camera import SynoCamera
from synology_dsm.exceptions import ( from synology_dsm.exceptions import (
@ -30,6 +31,36 @@ _LOGGER = logging.getLogger(__name__)
_DataT = TypeVar("_DataT") _DataT = TypeVar("_DataT")
_T = TypeVar("_T", bound="SynologyDSMUpdateCoordinator")
_P = ParamSpec("_P")
def async_re_login_on_expired(
func: Callable[Concatenate[_T, _P], Awaitable[_DataT]],
) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, _DataT]]:
"""Define a wrapper to re-login when expired."""
async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> _DataT:
for attempts in range(2):
try:
return await func(self, *args, **kwargs)
except SynologyDSMNotLoggedInException:
# If login is expired, try to login again
_LOGGER.debug("login is expired, try to login again")
try:
await self.api.async_login()
except SYNOLOGY_AUTH_FAILED_EXCEPTIONS as err:
raise_config_entry_auth_error(err)
if attempts == 0:
continue
except SYNOLOGY_CONNECTION_EXCEPTIONS as err:
raise UpdateFailed(f"Error communicating with API: {err}") from err
raise UpdateFailed("Unknown error when communicating with API")
return _async_wrap
class SynologyDSMUpdateCoordinator(DataUpdateCoordinator[_DataT]): class SynologyDSMUpdateCoordinator(DataUpdateCoordinator[_DataT]):
"""DataUpdateCoordinator base class for synology_dsm.""" """DataUpdateCoordinator base class for synology_dsm."""
@ -72,6 +103,7 @@ class SynologyDSMSwitchUpdateCoordinator(
assert info is not None assert info is not None
self.version = info["data"]["CMSMinVersion"] self.version = info["data"]["CMSMinVersion"]
@async_re_login_on_expired
async def _async_update_data(self) -> dict[str, dict[str, Any]]: async def _async_update_data(self) -> dict[str, dict[str, Any]]:
"""Fetch all data from api.""" """Fetch all data from api."""
surveillance_station = self.api.surveillance_station surveillance_station = self.api.surveillance_station
@ -102,21 +134,10 @@ class SynologyDSMCentralUpdateCoordinator(SynologyDSMUpdateCoordinator[None]):
), ),
) )
@async_re_login_on_expired
async def _async_update_data(self) -> None: async def _async_update_data(self) -> None:
"""Fetch all data from api.""" """Fetch all data from api."""
for attempts in range(2): await self.api.async_update()
try:
await self.api.async_update()
except SynologyDSMNotLoggedInException:
# If login is expired, try to login again
try:
await self.api.dsm.login()
except SYNOLOGY_AUTH_FAILED_EXCEPTIONS as err:
raise_config_entry_auth_error(err)
if attempts == 0:
continue
except SYNOLOGY_CONNECTION_EXCEPTIONS as err:
raise UpdateFailed(f"Error communicating with API: {err}") from err
class SynologyDSMCameraUpdateCoordinator( class SynologyDSMCameraUpdateCoordinator(
@ -133,6 +154,7 @@ class SynologyDSMCameraUpdateCoordinator(
"""Initialize DataUpdateCoordinator for cameras.""" """Initialize DataUpdateCoordinator for cameras."""
super().__init__(hass, entry, api, timedelta(seconds=30)) super().__init__(hass, entry, api, timedelta(seconds=30))
@async_re_login_on_expired
async def _async_update_data(self) -> dict[str, dict[int, SynoCamera]]: async def _async_update_data(self) -> dict[str, dict[int, SynoCamera]]:
"""Fetch all camera data from api.""" """Fetch all camera data from api."""
surveillance_station = self.api.surveillance_station surveillance_station = self.api.surveillance_station