diff --git a/.coveragerc b/.coveragerc index c301d4e30d3..78ef0996c53 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1193,6 +1193,7 @@ omit = homeassistant/components/tradfri/__init__.py homeassistant/components/tradfri/base_class.py homeassistant/components/tradfri/config_flow.py + homeassistant/components/tradfri/coordinator.py homeassistant/components/tradfri/cover.py homeassistant/components/tradfri/fan.py homeassistant/components/tradfri/light.py diff --git a/homeassistant/components/tradfri/__init__.py b/homeassistant/components/tradfri/__init__.py index 952ee54a0d8..6dd3236ea97 100644 --- a/homeassistant/components/tradfri/__init__.py +++ b/homeassistant/components/tradfri/__init__.py @@ -7,6 +7,9 @@ from typing import Any from pytradfri import Gateway, PytradfriError, RequestError from pytradfri.api.aiocoap_api import APIFactory +from pytradfri.command import Command +from pytradfri.device import Device +from pytradfri.group import Group import voluptuous as vol from homeassistant import config_entries @@ -15,7 +18,10 @@ from homeassistant.const import CONF_HOST, EVENT_HOMEASSISTANT_STOP from homeassistant.core import Event, HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.dispatcher import async_dispatcher_send +from homeassistant.helpers.dispatcher import ( + async_dispatcher_connect, + async_dispatcher_send, +) from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.typing import ConfigType @@ -28,15 +34,20 @@ from .const import ( CONF_IDENTITY, CONF_IMPORT_GROUPS, CONF_KEY, + COORDINATOR, + COORDINATOR_LIST, DEFAULT_ALLOW_TRADFRI_GROUPS, - DEVICES, DOMAIN, - GROUPS, + GROUPS_LIST, KEY_API, PLATFORMS, SIGNAL_GW, TIMEOUT_API, ) +from .coordinator import ( + TradfriDeviceDataUpdateCoordinator, + TradfriGroupDataUpdateCoordinator, +) _LOGGER = logging.getLogger(__name__) @@ -84,9 +95,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +async def async_setup_entry( + hass: HomeAssistant, + entry: ConfigEntry, +) -> bool: """Create a gateway.""" - # host, identity, key, allow_tradfri_groups tradfri_data: dict[str, Any] = {} hass.data.setdefault(DOMAIN, {})[entry.entry_id] = tradfri_data listeners = tradfri_data[LISTENERS] = [] @@ -96,11 +109,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: psk_id=entry.data[CONF_IDENTITY], psk=entry.data[CONF_KEY], ) + tradfri_data[FACTORY] = factory # Used for async_unload_entry async def on_hass_stop(event: Event) -> None: """Close connection when hass stops.""" await factory.shutdown() + # Setup listeners listeners.append(hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop)) api = factory.request @@ -108,19 +123,17 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: try: gateway_info = await api(gateway.get_gateway_info(), timeout=TIMEOUT_API) - devices_commands = await api(gateway.get_devices(), timeout=TIMEOUT_API) - devices = await api(devices_commands, timeout=TIMEOUT_API) - groups_commands = await api(gateway.get_groups(), timeout=TIMEOUT_API) - groups = await api(groups_commands, timeout=TIMEOUT_API) + devices_commands: Command = await api( + gateway.get_devices(), timeout=TIMEOUT_API + ) + devices: list[Device] = await api(devices_commands, timeout=TIMEOUT_API) + groups_commands: Command = await api(gateway.get_groups(), timeout=TIMEOUT_API) + groups: list[Group] = await api(groups_commands, timeout=TIMEOUT_API) + except PytradfriError as exc: await factory.shutdown() raise ConfigEntryNotReady from exc - tradfri_data[KEY_API] = api - tradfri_data[FACTORY] = factory - tradfri_data[DEVICES] = devices - tradfri_data[GROUPS] = groups - dev_reg = await hass.helpers.device_registry.async_get_registry() dev_reg.async_get_or_create( config_entry_id=entry.entry_id, @@ -133,7 +146,38 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: sw_version=gateway_info.firmware_version, ) - hass.config_entries.async_setup_platforms(entry, PLATFORMS) + # Setup the device coordinators + coordinator_data = { + CONF_GATEWAY_ID: gateway, + KEY_API: api, + COORDINATOR_LIST: [], + GROUPS_LIST: [], + } + + for device in devices: + coordinator = TradfriDeviceDataUpdateCoordinator( + hass=hass, api=api, device=device + ) + await coordinator.async_config_entry_first_refresh() + + entry.async_on_unload( + async_dispatcher_connect(hass, SIGNAL_GW, coordinator.set_hub_available) + ) + coordinator_data[COORDINATOR_LIST].append(coordinator) + + for group in groups: + group_coordinator = TradfriGroupDataUpdateCoordinator( + hass=hass, api=api, group=group + ) + await group_coordinator.async_config_entry_first_refresh() + entry.async_on_unload( + async_dispatcher_connect( + hass, SIGNAL_GW, group_coordinator.set_hub_available + ) + ) + coordinator_data[GROUPS_LIST].append(group_coordinator) + + tradfri_data[COORDINATOR] = coordinator_data async def async_keep_alive(now: datetime) -> None: if hass.is_stopping: @@ -152,6 +196,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async_track_time_interval(hass, async_keep_alive, timedelta(seconds=60)) ) + hass.config_entries.async_setup_platforms(entry, PLATFORMS) + return True diff --git a/homeassistant/components/tradfri/base_class.py b/homeassistant/components/tradfri/base_class.py index 34ad7b792b9..af923538fb2 100644 --- a/homeassistant/components/tradfri/base_class.py +++ b/homeassistant/components/tradfri/base_class.py @@ -1,29 +1,22 @@ """Base class for IKEA TRADFRI.""" from __future__ import annotations +from abc import abstractmethod from collections.abc import Callable from functools import wraps import logging -from typing import Any +from typing import Any, cast from pytradfri.command import Command from pytradfri.device import Device -from pytradfri.device.air_purifier import AirPurifier -from pytradfri.device.air_purifier_control import AirPurifierControl -from pytradfri.device.blind import Blind -from pytradfri.device.blind_control import BlindControl -from pytradfri.device.light import Light -from pytradfri.device.light_control import LightControl -from pytradfri.device.signal_repeater_control import SignalRepeaterControl -from pytradfri.device.socket import Socket -from pytradfri.device.socket_control import SocketControl from pytradfri.error import PytradfriError from homeassistant.core import callback -from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.helpers.entity import DeviceInfo, Entity +from homeassistant.helpers.entity import DeviceInfo +from homeassistant.helpers.update_coordinator import CoordinatorEntity -from .const import DOMAIN, SIGNAL_GW +from .const import DOMAIN +from .coordinator import TradfriDeviceDataUpdateCoordinator _LOGGER = logging.getLogger(__name__) @@ -44,102 +37,44 @@ def handle_error( return wrapper -class TradfriBaseClass(Entity): - """Base class for IKEA TRADFRI. +class TradfriBaseEntity(CoordinatorEntity): + """Base Tradfri device.""" - All devices and groups should ultimately inherit from this class. - """ - - _attr_should_poll = False + coordinator: TradfriDeviceDataUpdateCoordinator def __init__( self, - device: Device, - api: Callable[[Command | list[Command]], Any], + device_coordinator: TradfriDeviceDataUpdateCoordinator, gateway_id: str, + api: Callable[[Command | list[Command]], Any], ) -> None: """Initialize a device.""" - self._api = handle_error(api) - self._attr_name = device.name - self._device: Device = device - self._device_control: BlindControl | LightControl | SocketControl | SignalRepeaterControl | AirPurifierControl | None = ( - None - ) - self._device_data: Socket | Light | Blind | AirPurifier | None = None + super().__init__(device_coordinator) + self._gateway_id = gateway_id - async def _async_run_observe(self, cmd: Command) -> None: - """Run observe in a coroutine.""" - try: - await self._api(cmd) - except PytradfriError as err: - self._attr_available = False - self.async_write_ha_state() - _LOGGER.warning("Observation failed, trying again", exc_info=err) - self._async_start_observe() + self._device: Device = device_coordinator.data + + self._device_id = self._device.id + self._api = handle_error(api) + self._attr_name = self._device.name + + self._attr_unique_id = f"{self._gateway_id}-{self._device.id}" + + @abstractmethod + @callback + def _refresh(self) -> None: + """Refresh device data.""" @callback - def _async_start_observe(self, exc: Exception | None = None) -> None: - """Start observation of device.""" - if exc: - self._attr_available = False - self.async_write_ha_state() - _LOGGER.warning("Observation failed for %s", self._attr_name, exc_info=exc) - cmd = self._device.observe( - callback=self._observe_update, - err_callback=self._async_start_observe, - duration=0, - ) - self.hass.async_create_task(self._async_run_observe(cmd)) + def _handle_coordinator_update(self) -> None: + """ + Handle updated data from the coordinator. - async def async_added_to_hass(self) -> None: - """Start thread when added to hass.""" - self._async_start_observe() - - @callback - def _observe_update(self, device: Device) -> None: - """Receive new state data for this device.""" - self._refresh(device) - - def _refresh(self, device: Device, write_ha: bool = True) -> None: - """Refresh the device data.""" - self._device = device - self._attr_name = device.name - if write_ha: - self.async_write_ha_state() - - -class TradfriBaseDevice(TradfriBaseClass): - """Base class for a TRADFRI device. - - All devices should inherit from this class. - """ - - def __init__( - self, - device: Device, - api: Callable[[Command | list[Command]], Any], - gateway_id: str, - ) -> None: - """Initialize a device.""" - self._attr_available = device.reachable - self._hub_available = True - super().__init__(device, api, gateway_id) - - async def async_added_to_hass(self) -> None: - """Start thread when added to hass.""" - # Only devices shall receive SIGNAL_GW - self.async_on_remove( - async_dispatcher_connect(self.hass, SIGNAL_GW, self.set_hub_available) - ) - await super().async_added_to_hass() - - @callback - def set_hub_available(self, available: bool) -> None: - """Set status of hub.""" - if available != self._hub_available: - self._hub_available = available - self._refresh(self._device) + Tests fails without this method. + """ + self._refresh() + super()._handle_coordinator_update() @property def device_info(self) -> DeviceInfo: @@ -154,10 +89,7 @@ class TradfriBaseDevice(TradfriBaseClass): via_device=(DOMAIN, self._gateway_id), ) - def _refresh(self, device: Device, write_ha: bool = True) -> None: - """Refresh the device data.""" - # The base class _refresh cannot be used, because - # there are devices (group) that do not have .reachable - # so set _attr_available here and let the base class do the rest. - self._attr_available = device.reachable and self._hub_available - super()._refresh(device, write_ha) + @property + def available(self) -> bool: + """Return if entity is available.""" + return cast(bool, self._device.reachable) and super().available diff --git a/homeassistant/components/tradfri/const.py b/homeassistant/components/tradfri/const.py index 487eb0dae13..c87d2097929 100644 --- a/homeassistant/components/tradfri/const.py +++ b/homeassistant/components/tradfri/const.py @@ -37,3 +37,9 @@ PLATFORMS = [ ] TIMEOUT_API = 30 ATTR_MAX_FAN_STEPS = 49 + +SCAN_INTERVAL = 60 # Interval for updating the coordinator + +COORDINATOR = "coordinator" +COORDINATOR_LIST = "coordinator_list" +GROUPS_LIST = "groups_list" diff --git a/homeassistant/components/tradfri/coordinator.py b/homeassistant/components/tradfri/coordinator.py new file mode 100644 index 00000000000..1395478b6e9 --- /dev/null +++ b/homeassistant/components/tradfri/coordinator.py @@ -0,0 +1,145 @@ +"""Tradfri DataUpdateCoordinator.""" +from __future__ import annotations + +from collections.abc import Callable +from datetime import timedelta +import logging +from typing import Any + +from pytradfri.command import Command +from pytradfri.device import Device +from pytradfri.error import RequestError +from pytradfri.group import Group + +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed + +from .const import SCAN_INTERVAL + +_LOGGER = logging.getLogger(__name__) + + +class TradfriDeviceDataUpdateCoordinator(DataUpdateCoordinator[Device]): + """Coordinator to manage data for a specific Tradfri device.""" + + def __init__( + self, + hass: HomeAssistant, + *, + api: Callable[[Command | list[Command]], Any], + device: Device, + ) -> None: + """Initialize device coordinator.""" + self.api = api + self.device = device + self._exception: Exception | None = None + + super().__init__( + hass, + _LOGGER, + name=f"Update coordinator for {device}", + update_interval=timedelta(seconds=SCAN_INTERVAL), + ) + + async def set_hub_available(self, available: bool) -> None: + """Set status of hub.""" + if available != self.last_update_success: + if not available: + self.last_update_success = False + await self.async_request_refresh() + + @callback + def _observe_update(self, device: Device) -> None: + """Update the coordinator for a device when a change is detected.""" + self.update_interval = timedelta(seconds=SCAN_INTERVAL) # Reset update interval + + self.async_set_updated_data(data=device) + + @callback + def _exception_callback(self, device: Device, exc: Exception | None = None) -> None: + """Schedule handling exception..""" + self.hass.async_create_task(self._handle_exception(device=device, exc=exc)) + + async def _handle_exception( + self, device: Device, exc: Exception | None = None + ) -> None: + """Handle observe exceptions in a coroutine.""" + self._exception = ( + exc # Store exception so that it gets raised in _async_update_data + ) + + _LOGGER.debug("Observation failed for %s, trying again", device, exc_info=exc) + self.update_interval = timedelta( + seconds=5 + ) # Change interval so we get a swift refresh + await self.async_request_refresh() + + async def _async_update_data(self) -> Device: + """Fetch data from the gateway for a specific device.""" + try: + if self._exception: + exc = self._exception + self._exception = None # Clear stored exception + raise exc # pylint: disable-msg=raising-bad-type + except RequestError as err: + raise UpdateFailed( + f"Error communicating with API: {err}. Try unplugging and replugging your " + f"IKEA gateway." + ) from err + + if not self.data or not self.last_update_success: # Start subscription + try: + cmd = self.device.observe( + callback=self._observe_update, + err_callback=self._exception_callback, + duration=0, + ) + await self.api(cmd) + except RequestError as exc: + await self._handle_exception(device=self.device, exc=exc) + + return self.device + + +class TradfriGroupDataUpdateCoordinator(DataUpdateCoordinator[Group]): + """Coordinator to manage data for a specific Tradfri group.""" + + def __init__( + self, + hass: HomeAssistant, + *, + api: Callable[[Command | list[Command]], Any], + group: Group, + ) -> None: + """Initialize group coordinator.""" + self.api = api + self.group = group + self._exception: Exception | None = None + + super().__init__( + hass, + _LOGGER, + name=f"Update coordinator for {group}", + update_interval=timedelta(seconds=SCAN_INTERVAL), + ) + + async def set_hub_available(self, available: bool) -> None: + """Set status of hub.""" + if available != self.last_update_success: + if not available: + self.last_update_success = False + await self.async_request_refresh() + + async def _async_update_data(self) -> Group: + """Fetch data from the gateway for a specific group.""" + self.update_interval = timedelta(seconds=SCAN_INTERVAL) # Reset update interval + cmd = self.group.update() + try: + await self.api(cmd) + except RequestError as exc: + self.update_interval = timedelta( + seconds=5 + ) # Change interval so we get a swift refresh + raise UpdateFailed("Unable to update group coordinator") from exc + + return self.group diff --git a/homeassistant/components/tradfri/cover.py b/homeassistant/components/tradfri/cover.py index 554650f9005..d4c3063f35d 100644 --- a/homeassistant/components/tradfri/cover.py +++ b/homeassistant/components/tradfri/cover.py @@ -11,8 +11,16 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback -from .base_class import TradfriBaseDevice -from .const import ATTR_MODEL, CONF_GATEWAY_ID, DEVICES, DOMAIN, KEY_API +from .base_class import TradfriBaseEntity +from .const import ( + ATTR_MODEL, + CONF_GATEWAY_ID, + COORDINATOR, + COORDINATOR_LIST, + DOMAIN, + KEY_API, +) +from .coordinator import TradfriDeviceDataUpdateCoordinator async def async_setup_entry( @@ -22,28 +30,42 @@ async def async_setup_entry( ) -> None: """Load Tradfri covers based on a config entry.""" gateway_id = config_entry.data[CONF_GATEWAY_ID] - tradfri_data = hass.data[DOMAIN][config_entry.entry_id] - api = tradfri_data[KEY_API] - devices = tradfri_data[DEVICES] + coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR] + api = coordinator_data[KEY_API] async_add_entities( - TradfriCover(dev, api, gateway_id) for dev in devices if dev.has_blind_control + TradfriCover( + device_coordinator, + api, + gateway_id, + ) + for device_coordinator in coordinator_data[COORDINATOR_LIST] + if device_coordinator.device.has_blind_control ) -class TradfriCover(TradfriBaseDevice, CoverEntity): +class TradfriCover(TradfriBaseEntity, CoverEntity): """The platform class required by Home Assistant.""" def __init__( self, - device: Command, + device_coordinator: TradfriDeviceDataUpdateCoordinator, api: Callable[[Command | list[Command]], Any], gateway_id: str, ) -> None: - """Initialize a cover.""" - self._attr_unique_id = f"{gateway_id}-{device.id}" - super().__init__(device, api, gateway_id) - self._refresh(device, write_ha=False) + """Initialize a switch.""" + super().__init__( + device_coordinator=device_coordinator, + api=api, + gateway_id=gateway_id, + ) + + self._device_control = self._device.blind_control + self._device_data = self._device_control.blinds[0] + + def _refresh(self) -> None: + """Refresh the device.""" + self._device_data = self.coordinator.data.blind_control.blinds[0] @property def extra_state_attributes(self) -> dict[str, str] | None: @@ -88,11 +110,3 @@ class TradfriCover(TradfriBaseDevice, CoverEntity): def is_closed(self) -> bool: """Return if the cover is closed or not.""" return self.current_cover_position == 0 - - def _refresh(self, device: Command, write_ha: bool = True) -> None: - """Refresh the cover data.""" - # Caching of BlindControl and cover object - self._device = device - self._device_control = device.blind_control - self._device_data = device.blind_control.blinds[0] - super()._refresh(device, write_ha=write_ha) diff --git a/homeassistant/components/tradfri/fan.py b/homeassistant/components/tradfri/fan.py index 7b64e883c44..36e1c8b08ad 100644 --- a/homeassistant/components/tradfri/fan.py +++ b/homeassistant/components/tradfri/fan.py @@ -16,15 +16,17 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback -from .base_class import TradfriBaseDevice +from .base_class import TradfriBaseEntity from .const import ( ATTR_AUTO, ATTR_MAX_FAN_STEPS, CONF_GATEWAY_ID, - DEVICES, + COORDINATOR, + COORDINATOR_LIST, DOMAIN, KEY_API, ) +from .coordinator import TradfriDeviceDataUpdateCoordinator def _from_fan_percentage(percentage: int) -> int: @@ -44,30 +46,42 @@ async def async_setup_entry( ) -> None: """Load Tradfri switches based on a config entry.""" gateway_id = config_entry.data[CONF_GATEWAY_ID] - tradfri_data = hass.data[DOMAIN][config_entry.entry_id] - api = tradfri_data[KEY_API] - devices = tradfri_data[DEVICES] + coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR] + api = coordinator_data[KEY_API] async_add_entities( - TradfriAirPurifierFan(dev, api, gateway_id) - for dev in devices - if dev.has_air_purifier_control + TradfriAirPurifierFan( + device_coordinator, + api, + gateway_id, + ) + for device_coordinator in coordinator_data[COORDINATOR_LIST] + if device_coordinator.device.has_air_purifier_control ) -class TradfriAirPurifierFan(TradfriBaseDevice, FanEntity): +class TradfriAirPurifierFan(TradfriBaseEntity, FanEntity): """The platform class required by Home Assistant.""" def __init__( self, - device: Command, + device_coordinator: TradfriDeviceDataUpdateCoordinator, api: Callable[[Command | list[Command]], Any], gateway_id: str, ) -> None: """Initialize a switch.""" - super().__init__(device, api, gateway_id) - self._attr_unique_id = f"{gateway_id}-{device.id}" - self._refresh(device, write_ha=False) + super().__init__( + device_coordinator=device_coordinator, + api=api, + gateway_id=gateway_id, + ) + + self._device_control = self._device.air_purifier_control + self._device_data = self._device_control.air_purifiers[0] + + def _refresh(self) -> None: + """Refresh the device.""" + self._device_data = self.coordinator.data.air_purifier_control.air_purifiers[0] @property def supported_features(self) -> int: @@ -168,10 +182,3 @@ class TradfriAirPurifierFan(TradfriBaseDevice, FanEntity): if not self._device_control: return await self._api(self._device_control.turn_off()) - - def _refresh(self, device: Command, write_ha: bool = True) -> None: - """Refresh the purifier data.""" - # Caching of air purifier control and purifier object - self._device_control = device.air_purifier_control - self._device_data = device.air_purifier_control.air_purifiers[0] - super()._refresh(device, write_ha=write_ha) diff --git a/homeassistant/components/tradfri/light.py b/homeassistant/components/tradfri/light.py index ca078a37e81..9b6ad3e9f06 100644 --- a/homeassistant/components/tradfri/light.py +++ b/homeassistant/components/tradfri/light.py @@ -5,6 +5,7 @@ from collections.abc import Callable from typing import Any, cast from pytradfri.command import Command +from pytradfri.group import Group from homeassistant.components.light import ( ATTR_BRIGHTNESS, @@ -19,9 +20,10 @@ from homeassistant.components.light import ( from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.helpers.update_coordinator import CoordinatorEntity import homeassistant.util.color as color_util -from .base_class import TradfriBaseClass, TradfriBaseDevice +from .base_class import TradfriBaseEntity from .const import ( ATTR_DIMMER, ATTR_HUE, @@ -29,13 +31,18 @@ from .const import ( ATTR_TRANSITION_TIME, CONF_GATEWAY_ID, CONF_IMPORT_GROUPS, - DEVICES, + COORDINATOR, + COORDINATOR_LIST, DOMAIN, - GROUPS, + GROUPS_LIST, KEY_API, SUPPORTED_GROUP_FEATURES, SUPPORTED_LIGHT_FEATURES, ) +from .coordinator import ( + TradfriDeviceDataUpdateCoordinator, + TradfriGroupDataUpdateCoordinator, +) async def async_setup_entry( @@ -45,56 +52,66 @@ async def async_setup_entry( ) -> None: """Load Tradfri lights based on a config entry.""" gateway_id = config_entry.data[CONF_GATEWAY_ID] - tradfri_data = hass.data[DOMAIN][config_entry.entry_id] - api = tradfri_data[KEY_API] - devices = tradfri_data[DEVICES] + coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR] + api = coordinator_data[KEY_API] - entities: list[TradfriBaseClass] = [ - TradfriLight(dev, api, gateway_id) for dev in devices if dev.has_light_control + entities: list = [ + TradfriLight( + device_coordinator, + api, + gateway_id, + ) + for device_coordinator in coordinator_data[COORDINATOR_LIST] + if device_coordinator.device.has_light_control ] - if config_entry.data[CONF_IMPORT_GROUPS] and (groups := tradfri_data[GROUPS]): - entities.extend([TradfriGroup(group, api, gateway_id) for group in groups]) + + if config_entry.data[CONF_IMPORT_GROUPS] and ( + group_coordinators := coordinator_data[GROUPS_LIST] + ): + entities.extend( + [ + TradfriGroup(group_coordinator, api, gateway_id) + for group_coordinator in group_coordinators + ] + ) + async_add_entities(entities) -class TradfriGroup(TradfriBaseClass, LightEntity): +class TradfriGroup(CoordinatorEntity, LightEntity): """The platform class for light groups required by hass.""" _attr_supported_features = SUPPORTED_GROUP_FEATURES def __init__( self, - device: Command, + group_coordinator: TradfriGroupDataUpdateCoordinator, api: Callable[[Command | list[Command]], Any], gateway_id: str, ) -> None: """Initialize a Group.""" - super().__init__(device, api, gateway_id) + super().__init__(coordinator=group_coordinator) - self._attr_unique_id = f"group-{gateway_id}-{device.id}" - self._attr_should_poll = True - self._refresh(device, write_ha=False) + self._group: Group = self.coordinator.data - async def async_update(self) -> None: - """Fetch new state data for the group. - - This method is required for groups to update properly. - """ - await self._api(self._device.update()) + self._api = api + self._attr_unique_id = f"group-{gateway_id}-{self._group.id}" @property def is_on(self) -> bool: """Return true if group lights are on.""" - return cast(bool, self._device.state) + return cast(bool, self._group.state) @property def brightness(self) -> int | None: """Return the brightness of the group lights.""" - return cast(int, self._device.dimmer) + return cast(int, self._group.dimmer) async def async_turn_off(self, **kwargs: Any) -> None: """Instruct the group lights to turn off.""" - await self._api(self._device.set_state(0)) + await self._api(self._group.set_state(0)) + + await self.coordinator.async_request_refresh() async def async_turn_on(self, **kwargs: Any) -> None: """Instruct the group lights to turn on, or dim.""" @@ -106,39 +123,53 @@ class TradfriGroup(TradfriBaseClass, LightEntity): if kwargs[ATTR_BRIGHTNESS] == 255: kwargs[ATTR_BRIGHTNESS] = 254 - await self._api(self._device.set_dimmer(kwargs[ATTR_BRIGHTNESS], **keys)) + await self._api(self._group.set_dimmer(kwargs[ATTR_BRIGHTNESS], **keys)) else: - await self._api(self._device.set_state(1)) + await self._api(self._group.set_state(1)) + + await self.coordinator.async_request_refresh() -class TradfriLight(TradfriBaseDevice, LightEntity): +class TradfriLight(TradfriBaseEntity, LightEntity): """The platform class required by Home Assistant.""" def __init__( self, - device: Command, + device_coordinator: TradfriDeviceDataUpdateCoordinator, api: Callable[[Command | list[Command]], Any], gateway_id: str, ) -> None: """Initialize a Light.""" - super().__init__(device, api, gateway_id) - self._attr_unique_id = f"light-{gateway_id}-{device.id}" + super().__init__( + device_coordinator=device_coordinator, + api=api, + gateway_id=gateway_id, + ) + + self._device_control = self._device.light_control + self._device_data = self._device_control.lights[0] + + self._attr_unique_id = f"light-{gateway_id}-{self._device_id}" self._hs_color = None # Calculate supported features _features = SUPPORTED_LIGHT_FEATURES - if device.light_control.can_set_dimmer: + if self._device.light_control.can_set_dimmer: _features |= SUPPORT_BRIGHTNESS - if device.light_control.can_set_color: + if self._device.light_control.can_set_color: _features |= SUPPORT_COLOR | SUPPORT_COLOR_TEMP - if device.light_control.can_set_temp: + if self._device.light_control.can_set_temp: _features |= SUPPORT_COLOR_TEMP self._attr_supported_features = _features - self._refresh(device, write_ha=False) + if self._device_control: self._attr_min_mireds = self._device_control.min_mireds self._attr_max_mireds = self._device_control.max_mireds + def _refresh(self) -> None: + """Refresh the device.""" + self._device_data = self.coordinator.data.light_control.lights[0] + @property def is_on(self) -> bool: """Return true if light is on.""" @@ -268,10 +299,3 @@ class TradfriLight(TradfriBaseDevice, LightEntity): await self._api(temp_command) if command is not None: await self._api(command) - - def _refresh(self, device: Command, write_ha: bool = True) -> None: - """Refresh the light data.""" - # Caching of LightControl and light object - self._device_control = device.light_control - self._device_data = device.light_control.lights[0] - super()._refresh(device, write_ha=write_ha) diff --git a/homeassistant/components/tradfri/sensor.py b/homeassistant/components/tradfri/sensor.py index 5da3179c507..3d042fa6417 100644 --- a/homeassistant/components/tradfri/sensor.py +++ b/homeassistant/components/tradfri/sensor.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Callable -from typing import Any, cast +from typing import Any from pytradfri.command import Command @@ -12,8 +12,9 @@ from homeassistant.const import PERCENTAGE from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback -from .base_class import TradfriBaseDevice -from .const import CONF_GATEWAY_ID, DEVICES, DOMAIN, KEY_API +from .base_class import TradfriBaseEntity +from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API +from .coordinator import TradfriDeviceDataUpdateCoordinator async def async_setup_entry( @@ -23,24 +24,27 @@ async def async_setup_entry( ) -> None: """Set up a Tradfri config entry.""" gateway_id = config_entry.data[CONF_GATEWAY_ID] - tradfri_data = hass.data[DOMAIN][config_entry.entry_id] - api = tradfri_data[KEY_API] - devices = tradfri_data[DEVICES] + coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR] + api = coordinator_data[KEY_API] async_add_entities( - TradfriSensor(dev, api, gateway_id) - for dev in devices + TradfriSensor( + device_coordinator, + api, + gateway_id, + ) + for device_coordinator in coordinator_data[COORDINATOR_LIST] if ( - not dev.has_light_control - and not dev.has_socket_control - and not dev.has_blind_control - and not dev.has_signal_repeater_control - and not dev.has_air_purifier_control + not device_coordinator.device.has_light_control + and not device_coordinator.device.has_socket_control + and not device_coordinator.device.has_blind_control + and not device_coordinator.device.has_signal_repeater_control + and not device_coordinator.device.has_air_purifier_control ) ) -class TradfriSensor(TradfriBaseDevice, SensorEntity): +class TradfriSensor(TradfriBaseEntity, SensorEntity): """The platform class required by Home Assistant.""" _attr_device_class = SensorDeviceClass.BATTERY @@ -48,17 +52,19 @@ class TradfriSensor(TradfriBaseDevice, SensorEntity): def __init__( self, - device: Command, + device_coordinator: TradfriDeviceDataUpdateCoordinator, api: Callable[[Command | list[Command]], Any], gateway_id: str, ) -> None: - """Initialize the device.""" - super().__init__(device, api, gateway_id) - self._attr_unique_id = f"{gateway_id}-{device.id}" + """Initialize a switch.""" + super().__init__( + device_coordinator=device_coordinator, + api=api, + gateway_id=gateway_id, + ) - @property - def native_value(self) -> int | None: - """Return the current state of the device.""" - if not self._device: - return None - return cast(int, self._device.device_info.battery_level) + self._refresh() # Set initial state + + def _refresh(self) -> None: + """Refresh the device.""" + self._attr_native_value = self.coordinator.data.device_info.battery_level diff --git a/homeassistant/components/tradfri/switch.py b/homeassistant/components/tradfri/switch.py index f8950d24720..e0e2467ca4b 100644 --- a/homeassistant/components/tradfri/switch.py +++ b/homeassistant/components/tradfri/switch.py @@ -11,8 +11,9 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback -from .base_class import TradfriBaseDevice -from .const import CONF_GATEWAY_ID, DEVICES, DOMAIN, KEY_API +from .base_class import TradfriBaseEntity +from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API +from .coordinator import TradfriDeviceDataUpdateCoordinator async def async_setup_entry( @@ -22,35 +23,42 @@ async def async_setup_entry( ) -> None: """Load Tradfri switches based on a config entry.""" gateway_id = config_entry.data[CONF_GATEWAY_ID] - tradfri_data = hass.data[DOMAIN][config_entry.entry_id] - api = tradfri_data[KEY_API] - devices = tradfri_data[DEVICES] + coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR] + api = coordinator_data[KEY_API] async_add_entities( - TradfriSwitch(dev, api, gateway_id) for dev in devices if dev.has_socket_control + TradfriSwitch( + device_coordinator, + api, + gateway_id, + ) + for device_coordinator in coordinator_data[COORDINATOR_LIST] + if device_coordinator.device.has_socket_control ) -class TradfriSwitch(TradfriBaseDevice, SwitchEntity): +class TradfriSwitch(TradfriBaseEntity, SwitchEntity): """The platform class required by Home Assistant.""" def __init__( self, - device: Command, + device_coordinator: TradfriDeviceDataUpdateCoordinator, api: Callable[[Command | list[Command]], Any], gateway_id: str, ) -> None: """Initialize a switch.""" - super().__init__(device, api, gateway_id) - self._attr_unique_id = f"{gateway_id}-{device.id}" - self._refresh(device, write_ha=False) + super().__init__( + device_coordinator=device_coordinator, + api=api, + gateway_id=gateway_id, + ) - def _refresh(self, device: Command, write_ha: bool = True) -> None: - """Refresh the switch data.""" - # Caching of switch control and switch object - self._device_control = device.socket_control - self._device_data = device.socket_control.sockets[0] - super()._refresh(device, write_ha=write_ha) + self._device_control = self._device.socket_control + self._device_data = self._device_control.sockets[0] + + def _refresh(self) -> None: + """Refresh the device.""" + self._device_data = self.coordinator.data.socket_control.sockets[0] @property def is_on(self) -> bool: diff --git a/tests/components/tradfri/common.py b/tests/components/tradfri/common.py index 5e28bdcd55c..feeb60ab7c9 100644 --- a/tests/components/tradfri/common.py +++ b/tests/components/tradfri/common.py @@ -22,3 +22,5 @@ async def setup_integration(hass): entry.add_to_hass(hass) await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() + + return entry diff --git a/tests/components/tradfri/test_fan.py b/tests/components/tradfri/test_fan.py index 13b7e59e103..4aa99f5778a 100644 --- a/tests/components/tradfri/test_fan.py +++ b/tests/components/tradfri/test_fan.py @@ -121,7 +121,6 @@ async def test_set_percentage( """Test setting speed of a fan.""" # Note pytradfri style, not hass. Values not really important. initial_state = {"percentage": 10, "fan_speed": 3} - # Setup the gateway with a mock fan. fan = mock_fan(test_state=initial_state, device_number=0) mock_gateway.mock_devices.append(fan) diff --git a/tests/components/tradfri/test_light.py b/tests/components/tradfri/test_light.py index 7de2c4dcb37..1ed24d7b080 100644 --- a/tests/components/tradfri/test_light.py +++ b/tests/components/tradfri/test_light.py @@ -317,6 +317,7 @@ def mock_group(test_state=None, group_number=0): _mock_group = Mock(member_ids=[], observe=Mock(), **state) _mock_group.name = f"tradfri_group_{group_number}" + _mock_group.id = group_number return _mock_group @@ -327,11 +328,11 @@ async def test_group(hass, mock_gateway, mock_api_factory): mock_gateway.mock_groups.append(mock_group(state, 1)) await setup_integration(hass) - group = hass.states.get("light.tradfri_group_0") + group = hass.states.get("light.tradfri_group_mock_gateway_id_0") assert group is not None assert group.state == "off" - group = hass.states.get("light.tradfri_group_1") + group = hass.states.get("light.tradfri_group_mock_gateway_id_1") assert group is not None assert group.state == "on" assert group.attributes["brightness"] == 100 @@ -348,19 +349,26 @@ async def test_group_turn_on(hass, mock_gateway, mock_api_factory): await setup_integration(hass) # Use the turn_off service call to change the light state. - await hass.services.async_call( - "light", "turn_on", {"entity_id": "light.tradfri_group_0"}, blocking=True - ) await hass.services.async_call( "light", "turn_on", - {"entity_id": "light.tradfri_group_1", "brightness": 100}, + {"entity_id": "light.tradfri_group_mock_gateway_id_0"}, blocking=True, ) await hass.services.async_call( "light", "turn_on", - {"entity_id": "light.tradfri_group_2", "brightness": 100, "transition": 1}, + {"entity_id": "light.tradfri_group_mock_gateway_id_1", "brightness": 100}, + blocking=True, + ) + await hass.services.async_call( + "light", + "turn_on", + { + "entity_id": "light.tradfri_group_mock_gateway_id_2", + "brightness": 100, + "transition": 1, + }, blocking=True, ) await hass.async_block_till_done() @@ -378,7 +386,10 @@ async def test_group_turn_off(hass, mock_gateway, mock_api_factory): # Use the turn_off service call to change the light state. await hass.services.async_call( - "light", "turn_off", {"entity_id": "light.tradfri_group_0"}, blocking=True + "light", + "turn_off", + {"entity_id": "light.tradfri_group_mock_gateway_id_0"}, + blocking=True, ) await hass.async_block_till_done()