Ensure all synology_dsm coordinators handle expired sessions (#116796)

* Ensure all synology_dsm coordinators handle expired sessions

* Ensure all synology_dsm coordinators handle expired sessions

* Ensure all synology_dsm coordinators handle expired sessions

* handle cancellation

* add a debug log message

---------

Co-authored-by: mib1185 <mail@mib85.de>
This commit is contained in:
J. Nick Koston 2024-05-05 05:09:57 -05:00 committed by GitHub
parent f5394dc3a3
commit b4bac7705e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 68 additions and 16 deletions

View File

@ -7,6 +7,7 @@ import logging
from synology_dsm.api.surveillance_station import SynoSurveillanceStation
from synology_dsm.api.surveillance_station.camera import SynoCamera
from synology_dsm.exceptions import SynologyDSMNotLoggedInException
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_MAC, CONF_VERIFY_SSL
@ -69,7 +70,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
await api.async_setup()
except SYNOLOGY_AUTH_FAILED_EXCEPTIONS as err:
raise_config_entry_auth_error(err)
except SYNOLOGY_CONNECTION_EXCEPTIONS as err:
except (*SYNOLOGY_CONNECTION_EXCEPTIONS, SynologyDSMNotLoggedInException) as err:
# SynologyDSMNotLoggedInException may be raised even if the user is
# logged in because the session may have expired, and we need to retry
# the login later.
if err.args[0] and isinstance(err.args[0], dict):
details = err.args[0].get(EXCEPTION_DETAILS, EXCEPTION_UNKNOWN)
else:

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from contextlib import suppress
import logging
@ -82,6 +83,31 @@ class SynoApi:
self._with_upgrade = True
self._with_utilisation = True
self._login_future: asyncio.Future[None] | None = None
async def async_login(self) -> None:
"""Login to the Synology DSM API.
This function will only login once if called multiple times
by multiple different callers.
If a login is already in progress, the function will await the
login to complete before returning.
"""
if self._login_future:
return await self._login_future
self._login_future = self._hass.loop.create_future()
try:
await self.dsm.login()
self._login_future.set_result(None)
except BaseException as err:
if not self._login_future.done():
self._login_future.set_exception(err)
raise
finally:
self._login_future = None
async def async_setup(self) -> None:
"""Start interacting with the NAS."""
session = async_get_clientsession(self._hass, self._entry.data[CONF_VERIFY_SSL])
@ -95,7 +121,7 @@ class SynoApi:
timeout=self._entry.options.get(CONF_TIMEOUT) or 10,
device_token=self._entry.data.get(CONF_DEVICE_TOKEN),
)
await self.dsm.login()
await self.async_login()
# check if surveillance station is used
self._with_surveillance_station = bool(

View File

@ -2,9 +2,10 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable, Coroutine
from datetime import timedelta
import logging
from typing import Any, TypeVar
from typing import Any, Concatenate, ParamSpec, TypeVar
from synology_dsm.api.surveillance_station.camera import SynoCamera
from synology_dsm.exceptions import (
@ -30,6 +31,36 @@ _LOGGER = logging.getLogger(__name__)
_DataT = TypeVar("_DataT")
_T = TypeVar("_T", bound="SynologyDSMUpdateCoordinator")
_P = ParamSpec("_P")
def async_re_login_on_expired(
func: Callable[Concatenate[_T, _P], Awaitable[_DataT]],
) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, _DataT]]:
"""Define a wrapper to re-login when expired."""
async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> _DataT:
for attempts in range(2):
try:
return await func(self, *args, **kwargs)
except SynologyDSMNotLoggedInException:
# If login is expired, try to login again
_LOGGER.debug("login is expired, try to login again")
try:
await self.api.async_login()
except SYNOLOGY_AUTH_FAILED_EXCEPTIONS as err:
raise_config_entry_auth_error(err)
if attempts == 0:
continue
except SYNOLOGY_CONNECTION_EXCEPTIONS as err:
raise UpdateFailed(f"Error communicating with API: {err}") from err
raise UpdateFailed("Unknown error when communicating with API")
return _async_wrap
class SynologyDSMUpdateCoordinator(DataUpdateCoordinator[_DataT]):
"""DataUpdateCoordinator base class for synology_dsm."""
@ -72,6 +103,7 @@ class SynologyDSMSwitchUpdateCoordinator(
assert info is not None
self.version = info["data"]["CMSMinVersion"]
@async_re_login_on_expired
async def _async_update_data(self) -> dict[str, dict[str, Any]]:
"""Fetch all data from api."""
surveillance_station = self.api.surveillance_station
@ -102,21 +134,10 @@ class SynologyDSMCentralUpdateCoordinator(SynologyDSMUpdateCoordinator[None]):
),
)
@async_re_login_on_expired
async def _async_update_data(self) -> None:
"""Fetch all data from api."""
for attempts in range(2):
try:
await self.api.async_update()
except SynologyDSMNotLoggedInException:
# If login is expired, try to login again
try:
await self.api.dsm.login()
except SYNOLOGY_AUTH_FAILED_EXCEPTIONS as err:
raise_config_entry_auth_error(err)
if attempts == 0:
continue
except SYNOLOGY_CONNECTION_EXCEPTIONS as err:
raise UpdateFailed(f"Error communicating with API: {err}") from err
await self.api.async_update()
class SynologyDSMCameraUpdateCoordinator(
@ -133,6 +154,7 @@ class SynologyDSMCameraUpdateCoordinator(
"""Initialize DataUpdateCoordinator for cameras."""
super().__init__(hass, entry, api, timedelta(seconds=30))
@async_re_login_on_expired
async def _async_update_data(self) -> dict[str, dict[int, SynoCamera]]:
"""Fetch all camera data from api."""
surveillance_station = self.api.surveillance_station