diff --git a/homeassistant/components/shelly/coordinator.py b/homeassistant/components/shelly/coordinator.py index d206c38f5ab..2d321c8df9d 100644 --- a/homeassistant/components/shelly/coordinator.py +++ b/homeassistant/components/shelly/coordinator.py @@ -5,7 +5,7 @@ import asyncio from collections.abc import Callable, Coroutine from dataclasses import dataclass from datetime import timedelta -from typing import Any, cast +from typing import Any, Generic, TypeVar, cast import aioshelly from aioshelly.ble import async_ensure_ble_enabled, async_stop_scanner @@ -49,12 +49,9 @@ from .const import ( UPDATE_PERIOD_MULTIPLIER, BLEScannerMode, ) -from .utils import ( - device_update_info, - get_block_device_name, - get_rpc_device_name, - get_rpc_device_wakeup_period, -) +from .utils import device_update_info, get_device_name, get_rpc_device_wakeup_period + +_DeviceT = TypeVar("_DeviceT", bound="BlockDevice|RpcDevice") @dataclass @@ -73,34 +70,23 @@ def get_entry_data(hass: HomeAssistant) -> dict[str, ShellyEntryData]: return cast(dict[str, ShellyEntryData], hass.data[DOMAIN][DATA_CONFIG_ENTRY]) -class ShellyBlockCoordinator(DataUpdateCoordinator[None]): - """Coordinator for a Shelly block based device.""" +class ShellyCoordinatorBase(DataUpdateCoordinator[None], Generic[_DeviceT]): + """Coordinator for a Shelly device.""" def __init__( - self, hass: HomeAssistant, entry: ConfigEntry, device: BlockDevice + self, + hass: HomeAssistant, + entry: ConfigEntry, + device: _DeviceT, + update_interval: float, ) -> None: - """Initialize the Shelly block device coordinator.""" - self.device_id: str | None = None - - if sleep_period := entry.data[CONF_SLEEP_PERIOD]: - update_interval = SLEEP_PERIOD_MULTIPLIER * sleep_period - else: - update_interval = ( - UPDATE_PERIOD_MULTIPLIER * device.settings["coiot"]["update_period"] - ) - - device_name = ( - get_block_device_name(device) if device.initialized else entry.title - ) - super().__init__( - hass, - LOGGER, - name=device_name, - update_interval=timedelta(seconds=update_interval), - ) - self.hass = hass + """Initialize the Shelly device coordinator.""" self.entry = entry self.device = device + self.device_id: str | None = None + device_name = get_device_name(device) if device.initialized else entry.title + interval_td = timedelta(seconds=update_interval) + super().__init__(hass, LOGGER, name=device_name, update_interval=interval_td) self._debounced_reload: Debouncer[Coroutine[Any, Any, None]] = Debouncer( hass, @@ -110,24 +96,77 @@ class ShellyBlockCoordinator(DataUpdateCoordinator[None]): function=self._async_reload_entry, ) entry.async_on_unload(self._debounced_reload.async_cancel) + + @property + def model(self) -> str: + """Model of the device.""" + return cast(str, self.entry.data["model"]) + + @property + def mac(self) -> str: + """Mac address of the device.""" + return cast(str, self.entry.unique_id) + + @property + def sw_version(self) -> str: + """Firmware version of the device.""" + return self.device.firmware_version if self.device.initialized else "" + + @property + def sleep_period(self) -> int: + """Sleep period of the device.""" + return self.entry.data.get(CONF_SLEEP_PERIOD, 0) + + def async_setup(self) -> None: + """Set up the coordinator.""" + dev_reg = device_registry.async_get(self.hass) + device_entry = dev_reg.async_get_or_create( + config_entry_id=self.entry.entry_id, + name=self.name, + connections={(device_registry.CONNECTION_NETWORK_MAC, self.mac)}, + manufacturer="Shelly", + model=aioshelly.const.MODEL_NAMES.get(self.model, self.model), + sw_version=self.sw_version, + hw_version=f"gen{self.device.gen} ({self.model})", + configuration_url=f"http://{self.entry.data[CONF_HOST]}", + ) + self.device_id = device_entry.id + + async def _async_reload_entry(self) -> None: + """Reload entry.""" + self._debounced_reload.async_cancel() + LOGGER.debug("Reloading entry %s", self.name) + await self.hass.config_entries.async_reload(self.entry.entry_id) + + +class ShellyBlockCoordinator(ShellyCoordinatorBase[BlockDevice]): + """Coordinator for a Shelly block based device.""" + + def __init__( + self, hass: HomeAssistant, entry: ConfigEntry, device: BlockDevice + ) -> None: + """Initialize the Shelly block device coordinator.""" + self.entry = entry + if self.sleep_period: + update_interval = SLEEP_PERIOD_MULTIPLIER * self.sleep_period + else: + update_interval = ( + UPDATE_PERIOD_MULTIPLIER * device.settings["coiot"]["update_period"] + ) + super().__init__(hass, entry, device, update_interval) + self._last_cfg_changed: int | None = None self._last_mode: str | None = None self._last_effect: int | None = None + self._last_input_events_count: dict = {} entry.async_on_unload( self.async_add_listener(self._async_device_updates_handler) ) - self._last_input_events_count: dict = {} - entry.async_on_unload( hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._handle_ha_stop) ) - async def _async_reload_entry(self) -> None: - """Reload entry.""" - LOGGER.debug("Reloading entry %s", self.name) - await self.hass.config_entries.async_reload(self.entry.entry_id) - @callback def _async_device_updates_handler(self) -> None: """Handle device updates.""" @@ -209,10 +248,10 @@ class ShellyBlockCoordinator(DataUpdateCoordinator[None]): async def _async_update_data(self) -> None: """Fetch data.""" - if sleep_period := self.entry.data.get(CONF_SLEEP_PERIOD): + if self.sleep_period: # Sleeping device, no point polling it, just mark it unavailable raise UpdateFailed( - f"Sleeping device did not update within {sleep_period} seconds interval" + f"Sleeping device did not update within {self.sleep_period} seconds interval" ) LOGGER.debug("Polling Shelly Block Device - %s", self.name) @@ -225,35 +264,9 @@ class ShellyBlockCoordinator(DataUpdateCoordinator[None]): else: device_update_info(self.hass, self.device, self.entry) - @property - def model(self) -> str: - """Model of the device.""" - return cast(str, self.entry.data["model"]) - - @property - def mac(self) -> str: - """Mac address of the device.""" - return cast(str, self.entry.unique_id) - - @property - def sw_version(self) -> str: - """Firmware version of the device.""" - return self.device.firmware_version if self.device.initialized else "" - def async_setup(self) -> None: """Set up the coordinator.""" - dev_reg = device_registry.async_get(self.hass) - entry = dev_reg.async_get_or_create( - config_entry_id=self.entry.entry_id, - name=self.name, - connections={(device_registry.CONNECTION_NETWORK_MAC, self.mac)}, - manufacturer="Shelly", - model=aioshelly.const.MODEL_NAMES.get(self.model, self.model), - sw_version=self.sw_version, - hw_version=f"gen{self.device.gen} ({self.model})", - configuration_url=f"http://{self.entry.data[CONF_HOST]}", - ) - self.device_id = entry.id + super().async_setup() self.device.subscribe_updates(self.async_set_updated_data) def shutdown(self) -> None: @@ -267,13 +280,14 @@ class ShellyBlockCoordinator(DataUpdateCoordinator[None]): self.shutdown() -class ShellyRestCoordinator(DataUpdateCoordinator[None]): +class ShellyRestCoordinator(ShellyCoordinatorBase[BlockDevice]): """Coordinator for a Shelly REST device.""" def __init__( self, hass: HomeAssistant, device: BlockDevice, entry: ConfigEntry ) -> None: """Initialize the Shelly REST device coordinator.""" + update_interval = REST_SENSORS_UPDATE_INTERVAL if ( device.settings["device"]["type"] in BATTERY_DEVICES_WITH_PERMANENT_CONNECTION @@ -281,17 +295,7 @@ class ShellyRestCoordinator(DataUpdateCoordinator[None]): update_interval = ( SLEEP_PERIOD_MULTIPLIER * device.settings["coiot"]["update_period"] ) - else: - update_interval = REST_SENSORS_UPDATE_INTERVAL - - super().__init__( - hass, - LOGGER, - name=get_block_device_name(device), - update_interval=timedelta(seconds=update_interval), - ) - self.device = device - self.entry = entry + super().__init__(hass, entry, device, update_interval) async def _async_update_data(self) -> None: """Fetch data.""" @@ -312,64 +316,37 @@ class ShellyRestCoordinator(DataUpdateCoordinator[None]): else: device_update_info(self.hass, self.device, self.entry) - @property - def mac(self) -> str: - """Mac address of the device.""" - return cast(str, self.device.settings["device"]["mac"]) - -class ShellyRpcCoordinator(DataUpdateCoordinator[None]): +class ShellyRpcCoordinator(ShellyCoordinatorBase[RpcDevice]): """Coordinator for a Shelly RPC based device.""" def __init__( self, hass: HomeAssistant, entry: ConfigEntry, device: RpcDevice ) -> None: """Initialize the Shelly RPC device coordinator.""" - self.device_id: str | None = None - - if sleep_period := entry.data[CONF_SLEEP_PERIOD]: - update_interval = SLEEP_PERIOD_MULTIPLIER * sleep_period + self.entry = entry + if self.sleep_period: + update_interval = SLEEP_PERIOD_MULTIPLIER * self.sleep_period else: update_interval = RPC_RECONNECT_INTERVAL - device_name = get_rpc_device_name(device) if device.initialized else entry.title - super().__init__( - hass, - LOGGER, - name=device_name, - update_interval=timedelta(seconds=update_interval), - ) - self.entry = entry - self.device = device - self.connected = False + super().__init__(hass, entry, device, update_interval) + self.connected = False self._disconnected_callbacks: list[CALLBACK_TYPE] = [] self._connection_lock = asyncio.Lock() self._event_listeners: list[Callable[[dict[str, Any]], None]] = [] - self._debounced_reload: Debouncer[Coroutine[Any, Any, None]] = Debouncer( - hass, - LOGGER, - cooldown=ENTRY_RELOAD_COOLDOWN, - immediate=False, - function=self._async_reload_entry, - ) - entry.async_on_unload(self._debounced_reload.async_cancel) + entry.async_on_unload( hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._handle_ha_stop) ) entry.async_on_unload(entry.add_update_listener(self._async_update_listener)) - async def _async_reload_entry(self) -> None: - """Reload entry.""" - self._debounced_reload.async_cancel() - LOGGER.debug("Reloading entry %s", self.name) - await self.hass.config_entries.async_reload(self.entry.entry_id) - def update_sleep_period(self) -> bool: """Check device sleep period & update if changed.""" if ( not self.device.initialized or not (wakeup_period := get_rpc_device_wakeup_period(self.device.status)) - or wakeup_period == self.entry.data.get(CONF_SLEEP_PERIOD) + or wakeup_period == self.sleep_period ): return False @@ -441,10 +418,10 @@ class ShellyRpcCoordinator(DataUpdateCoordinator[None]): if self.update_sleep_period(): return - if sleep_period := self.entry.data.get(CONF_SLEEP_PERIOD): + if self.sleep_period: # Sleeping device, no point polling it, just mark it unavailable raise UpdateFailed( - f"Sleeping device did not update within {sleep_period} seconds interval" + f"Sleeping device did not update within {self.sleep_period} seconds interval" ) if self.device.connected: return @@ -458,26 +435,11 @@ class ShellyRpcCoordinator(DataUpdateCoordinator[None]): except InvalidAuthError: self.entry.async_start_reauth(self.hass) - @property - def model(self) -> str: - """Model of the device.""" - return cast(str, self.entry.data["model"]) - - @property - def mac(self) -> str: - """Mac address of the device.""" - return cast(str, self.entry.unique_id) - - @property - def sw_version(self) -> str: - """Firmware version of the device.""" - return self.device.firmware_version if self.device.initialized else "" - async def _async_disconnected(self) -> None: """Handle device disconnected.""" - # Sleeping devices send data and disconnects + # Sleeping devices send data and disconnect # There are no disconnect events for sleeping devices - if self.entry.data.get(CONF_SLEEP_PERIOD): + if self.sleep_period: return async with self._connection_lock: @@ -514,7 +476,7 @@ class ShellyRpcCoordinator(DataUpdateCoordinator[None]): This will be executed on connect or when the config entry is updated. """ - if not self.entry.data.get(CONF_SLEEP_PERIOD): + if not self.sleep_period: await self._async_connect_ble_scanner() async def _async_connect_ble_scanner(self) -> None: @@ -555,18 +517,7 @@ class ShellyRpcCoordinator(DataUpdateCoordinator[None]): def async_setup(self) -> None: """Set up the coordinator.""" - dev_reg = device_registry.async_get(self.hass) - entry = dev_reg.async_get_or_create( - config_entry_id=self.entry.entry_id, - name=self.name, - connections={(device_registry.CONNECTION_NETWORK_MAC, self.mac)}, - manufacturer="Shelly", - model=aioshelly.const.MODEL_NAMES.get(self.model, self.model), - sw_version=self.sw_version, - hw_version=f"gen{self.device.gen} ({self.model})", - configuration_url=f"http://{self.entry.data[CONF_HOST]}", - ) - self.device_id = entry.id + super().async_setup() self.device.subscribe_updates(self._async_handle_update) if self.device.initialized: # If we are already initialized, we are connected @@ -585,24 +536,14 @@ class ShellyRpcCoordinator(DataUpdateCoordinator[None]): await self.shutdown() -class ShellyRpcPollingCoordinator(DataUpdateCoordinator[None]): +class ShellyRpcPollingCoordinator(ShellyCoordinatorBase[RpcDevice]): """Polling coordinator for a Shelly RPC based device.""" def __init__( self, hass: HomeAssistant, entry: ConfigEntry, device: RpcDevice ) -> None: """Initialize the RPC polling coordinator.""" - self.device_id: str | None = None - - device_name = get_rpc_device_name(device) if device.initialized else entry.title - super().__init__( - hass, - LOGGER, - name=device_name, - update_interval=timedelta(seconds=RPC_SENSORS_POLLING_INTERVAL), - ) - self.entry = entry - self.device = device + super().__init__(hass, entry, device, RPC_SENSORS_POLLING_INTERVAL) async def _async_update_data(self) -> None: """Fetch data.""" @@ -617,11 +558,6 @@ class ShellyRpcPollingCoordinator(DataUpdateCoordinator[None]): except InvalidAuthError: self.entry.async_start_reauth(self.hass) - @property - def mac(self) -> str: - """Mac address of the device.""" - return cast(str, self.entry.unique_id) - def get_block_coordinator_by_device_id( hass: HomeAssistant, device_id: str diff --git a/homeassistant/components/shelly/utils.py b/homeassistant/components/shelly/utils.py index b048b219e6b..edfa1d284ed 100644 --- a/homeassistant/components/shelly/utils.py +++ b/homeassistant/components/shelly/utils.py @@ -44,15 +44,23 @@ def async_remove_shelly_entity( def get_block_device_name(device: BlockDevice) -> str: - """Naming for device.""" + """Get Block device name.""" return cast(str, device.settings["name"] or device.settings["device"]["hostname"]) def get_rpc_device_name(device: RpcDevice) -> str: - """Naming for device.""" + """Get RPC device name.""" return cast(str, device.config["sys"]["device"].get("name") or device.hostname) +def get_device_name(device: BlockDevice | RpcDevice) -> str: + """Get device name.""" + if isinstance(device, BlockDevice): + return get_block_device_name(device) + + return get_rpc_device_name(device) + + def get_number_of_channels(device: BlockDevice, block: Block) -> int: """Get number of channels for block type.""" assert isinstance(device.shelly, dict)