Implement coordinator class for Tradfri integration (#64166)

* Initial commit coordinator

* More coordinator implementation

* More coordinator implementation

* Allow integration reload

* Move API calls to try/catch block

* Move back fixture

* Remove coordinator test file

* Ensure unchanged file

* Ensure unchanged conftest.py file

* Remove coordinator key check

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Import RequestError

* Move async_setup_platforms to end of setup_entry

* Remove centralised handling of device data and device controllers

* Remove platform_type argument

* Remove exception

* Remove the correct exception

* Refactor coordinator error handling

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Remove platform type from base class

* Remove timeout context manager

* Refactor exception callback

* Simplify starting device observation

* Update test

* Move observe start into update method

* Remove await self.coordinator.async_request_refresh()

* Refactor cover.py

* Uncomment const.py

* Add back extra_state_attributes

* Update homeassistant/components/tradfri/coordinator.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Refactor switch platform

* Expose switch state

* Refactor sensor platform

* Put back accidentally deleted code

* Add set_hub_available

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Fix tests for fan platform

* Update homeassistant/components/tradfri/base_class.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Update homeassistant/components/tradfri/base_class.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Fix non-working tests

* Refresh sensor state

* Remove commented line

* Add group coordinator

* Add groups during setup

* Refactor light platform

* Fix tests

* Move outside of try...except

* Remove error handler

* Remove unneeded methods

* Update sensor

* Update .coveragerc

* Move signal

* Add signals for groups

* Fix signal

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Patrik Lindgren 2022-01-27 11:12:52 +01:00 committed by GitHub
parent 3daaed1056
commit 9d404b749a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 452 additions and 251 deletions

View File

@ -1193,6 +1193,7 @@ omit =
homeassistant/components/tradfri/__init__.py homeassistant/components/tradfri/__init__.py
homeassistant/components/tradfri/base_class.py homeassistant/components/tradfri/base_class.py
homeassistant/components/tradfri/config_flow.py homeassistant/components/tradfri/config_flow.py
homeassistant/components/tradfri/coordinator.py
homeassistant/components/tradfri/cover.py homeassistant/components/tradfri/cover.py
homeassistant/components/tradfri/fan.py homeassistant/components/tradfri/fan.py
homeassistant/components/tradfri/light.py homeassistant/components/tradfri/light.py

View File

@ -7,6 +7,9 @@ from typing import Any
from pytradfri import Gateway, PytradfriError, RequestError from pytradfri import Gateway, PytradfriError, RequestError
from pytradfri.api.aiocoap_api import APIFactory from pytradfri.api.aiocoap_api import APIFactory
from pytradfri.command import Command
from pytradfri.device import Device
from pytradfri.group import Group
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
@ -15,7 +18,10 @@ from homeassistant.const import CONF_HOST, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant from homeassistant.core import Event, HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@ -28,15 +34,20 @@ from .const import (
CONF_IDENTITY, CONF_IDENTITY,
CONF_IMPORT_GROUPS, CONF_IMPORT_GROUPS,
CONF_KEY, CONF_KEY,
COORDINATOR,
COORDINATOR_LIST,
DEFAULT_ALLOW_TRADFRI_GROUPS, DEFAULT_ALLOW_TRADFRI_GROUPS,
DEVICES,
DOMAIN, DOMAIN,
GROUPS, GROUPS_LIST,
KEY_API, KEY_API,
PLATFORMS, PLATFORMS,
SIGNAL_GW, SIGNAL_GW,
TIMEOUT_API, TIMEOUT_API,
) )
from .coordinator import (
TradfriDeviceDataUpdateCoordinator,
TradfriGroupDataUpdateCoordinator,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -84,9 +95,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(
hass: HomeAssistant,
entry: ConfigEntry,
) -> bool:
"""Create a gateway.""" """Create a gateway."""
# host, identity, key, allow_tradfri_groups
tradfri_data: dict[str, Any] = {} tradfri_data: dict[str, Any] = {}
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = tradfri_data hass.data.setdefault(DOMAIN, {})[entry.entry_id] = tradfri_data
listeners = tradfri_data[LISTENERS] = [] listeners = tradfri_data[LISTENERS] = []
@ -96,11 +109,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
psk_id=entry.data[CONF_IDENTITY], psk_id=entry.data[CONF_IDENTITY],
psk=entry.data[CONF_KEY], psk=entry.data[CONF_KEY],
) )
tradfri_data[FACTORY] = factory # Used for async_unload_entry
async def on_hass_stop(event: Event) -> None: async def on_hass_stop(event: Event) -> None:
"""Close connection when hass stops.""" """Close connection when hass stops."""
await factory.shutdown() await factory.shutdown()
# Setup listeners
listeners.append(hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop)) listeners.append(hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop))
api = factory.request api = factory.request
@ -108,19 +123,17 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
try: try:
gateway_info = await api(gateway.get_gateway_info(), timeout=TIMEOUT_API) gateway_info = await api(gateway.get_gateway_info(), timeout=TIMEOUT_API)
devices_commands = await api(gateway.get_devices(), timeout=TIMEOUT_API) devices_commands: Command = await api(
devices = await api(devices_commands, timeout=TIMEOUT_API) gateway.get_devices(), timeout=TIMEOUT_API
groups_commands = await api(gateway.get_groups(), timeout=TIMEOUT_API) )
groups = await api(groups_commands, timeout=TIMEOUT_API) devices: list[Device] = await api(devices_commands, timeout=TIMEOUT_API)
groups_commands: Command = await api(gateway.get_groups(), timeout=TIMEOUT_API)
groups: list[Group] = await api(groups_commands, timeout=TIMEOUT_API)
except PytradfriError as exc: except PytradfriError as exc:
await factory.shutdown() await factory.shutdown()
raise ConfigEntryNotReady from exc raise ConfigEntryNotReady from exc
tradfri_data[KEY_API] = api
tradfri_data[FACTORY] = factory
tradfri_data[DEVICES] = devices
tradfri_data[GROUPS] = groups
dev_reg = await hass.helpers.device_registry.async_get_registry() dev_reg = await hass.helpers.device_registry.async_get_registry()
dev_reg.async_get_or_create( dev_reg.async_get_or_create(
config_entry_id=entry.entry_id, config_entry_id=entry.entry_id,
@ -133,7 +146,38 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
sw_version=gateway_info.firmware_version, sw_version=gateway_info.firmware_version,
) )
hass.config_entries.async_setup_platforms(entry, PLATFORMS) # Setup the device coordinators
coordinator_data = {
CONF_GATEWAY_ID: gateway,
KEY_API: api,
COORDINATOR_LIST: [],
GROUPS_LIST: [],
}
for device in devices:
coordinator = TradfriDeviceDataUpdateCoordinator(
hass=hass, api=api, device=device
)
await coordinator.async_config_entry_first_refresh()
entry.async_on_unload(
async_dispatcher_connect(hass, SIGNAL_GW, coordinator.set_hub_available)
)
coordinator_data[COORDINATOR_LIST].append(coordinator)
for group in groups:
group_coordinator = TradfriGroupDataUpdateCoordinator(
hass=hass, api=api, group=group
)
await group_coordinator.async_config_entry_first_refresh()
entry.async_on_unload(
async_dispatcher_connect(
hass, SIGNAL_GW, group_coordinator.set_hub_available
)
)
coordinator_data[GROUPS_LIST].append(group_coordinator)
tradfri_data[COORDINATOR] = coordinator_data
async def async_keep_alive(now: datetime) -> None: async def async_keep_alive(now: datetime) -> None:
if hass.is_stopping: if hass.is_stopping:
@ -152,6 +196,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async_track_time_interval(hass, async_keep_alive, timedelta(seconds=60)) async_track_time_interval(hass, async_keep_alive, timedelta(seconds=60))
) )
hass.config_entries.async_setup_platforms(entry, PLATFORMS)
return True return True

View File

@ -1,29 +1,22 @@
"""Base class for IKEA TRADFRI.""" """Base class for IKEA TRADFRI."""
from __future__ import annotations from __future__ import annotations
from abc import abstractmethod
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
import logging import logging
from typing import Any from typing import Any, cast
from pytradfri.command import Command from pytradfri.command import Command
from pytradfri.device import Device from pytradfri.device import Device
from pytradfri.device.air_purifier import AirPurifier
from pytradfri.device.air_purifier_control import AirPurifierControl
from pytradfri.device.blind import Blind
from pytradfri.device.blind_control import BlindControl
from pytradfri.device.light import Light
from pytradfri.device.light_control import LightControl
from pytradfri.device.signal_repeater_control import SignalRepeaterControl
from pytradfri.device.socket import Socket
from pytradfri.device.socket_control import SocketControl
from pytradfri.error import PytradfriError from pytradfri.error import PytradfriError
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.entity import DeviceInfo, Entity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from .const import DOMAIN, SIGNAL_GW from .const import DOMAIN
from .coordinator import TradfriDeviceDataUpdateCoordinator
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -44,102 +37,44 @@ def handle_error(
return wrapper return wrapper
class TradfriBaseClass(Entity): class TradfriBaseEntity(CoordinatorEntity):
"""Base class for IKEA TRADFRI. """Base Tradfri device."""
All devices and groups should ultimately inherit from this class. coordinator: TradfriDeviceDataUpdateCoordinator
"""
_attr_should_poll = False
def __init__( def __init__(
self, self,
device: Device, device_coordinator: TradfriDeviceDataUpdateCoordinator,
api: Callable[[Command | list[Command]], Any],
gateway_id: str, gateway_id: str,
api: Callable[[Command | list[Command]], Any],
) -> None: ) -> None:
"""Initialize a device.""" """Initialize a device."""
self._api = handle_error(api) super().__init__(device_coordinator)
self._attr_name = device.name
self._device: Device = device
self._device_control: BlindControl | LightControl | SocketControl | SignalRepeaterControl | AirPurifierControl | None = (
None
)
self._device_data: Socket | Light | Blind | AirPurifier | None = None
self._gateway_id = gateway_id self._gateway_id = gateway_id
async def _async_run_observe(self, cmd: Command) -> None: self._device: Device = device_coordinator.data
"""Run observe in a coroutine."""
try: self._device_id = self._device.id
await self._api(cmd) self._api = handle_error(api)
except PytradfriError as err: self._attr_name = self._device.name
self._attr_available = False
self.async_write_ha_state() self._attr_unique_id = f"{self._gateway_id}-{self._device.id}"
_LOGGER.warning("Observation failed, trying again", exc_info=err)
self._async_start_observe() @abstractmethod
@callback
def _refresh(self) -> None:
"""Refresh device data."""
@callback @callback
def _async_start_observe(self, exc: Exception | None = None) -> None: def _handle_coordinator_update(self) -> None:
"""Start observation of device.""" """
if exc: Handle updated data from the coordinator.
self._attr_available = False
self.async_write_ha_state()
_LOGGER.warning("Observation failed for %s", self._attr_name, exc_info=exc)
cmd = self._device.observe(
callback=self._observe_update,
err_callback=self._async_start_observe,
duration=0,
)
self.hass.async_create_task(self._async_run_observe(cmd))
async def async_added_to_hass(self) -> None: Tests fails without this method.
"""Start thread when added to hass.""" """
self._async_start_observe() self._refresh()
super()._handle_coordinator_update()
@callback
def _observe_update(self, device: Device) -> None:
"""Receive new state data for this device."""
self._refresh(device)
def _refresh(self, device: Device, write_ha: bool = True) -> None:
"""Refresh the device data."""
self._device = device
self._attr_name = device.name
if write_ha:
self.async_write_ha_state()
class TradfriBaseDevice(TradfriBaseClass):
"""Base class for a TRADFRI device.
All devices should inherit from this class.
"""
def __init__(
self,
device: Device,
api: Callable[[Command | list[Command]], Any],
gateway_id: str,
) -> None:
"""Initialize a device."""
self._attr_available = device.reachable
self._hub_available = True
super().__init__(device, api, gateway_id)
async def async_added_to_hass(self) -> None:
"""Start thread when added to hass."""
# Only devices shall receive SIGNAL_GW
self.async_on_remove(
async_dispatcher_connect(self.hass, SIGNAL_GW, self.set_hub_available)
)
await super().async_added_to_hass()
@callback
def set_hub_available(self, available: bool) -> None:
"""Set status of hub."""
if available != self._hub_available:
self._hub_available = available
self._refresh(self._device)
@property @property
def device_info(self) -> DeviceInfo: def device_info(self) -> DeviceInfo:
@ -154,10 +89,7 @@ class TradfriBaseDevice(TradfriBaseClass):
via_device=(DOMAIN, self._gateway_id), via_device=(DOMAIN, self._gateway_id),
) )
def _refresh(self, device: Device, write_ha: bool = True) -> None: @property
"""Refresh the device data.""" def available(self) -> bool:
# The base class _refresh cannot be used, because """Return if entity is available."""
# there are devices (group) that do not have .reachable return cast(bool, self._device.reachable) and super().available
# so set _attr_available here and let the base class do the rest.
self._attr_available = device.reachable and self._hub_available
super()._refresh(device, write_ha)

View File

@ -37,3 +37,9 @@ PLATFORMS = [
] ]
TIMEOUT_API = 30 TIMEOUT_API = 30
ATTR_MAX_FAN_STEPS = 49 ATTR_MAX_FAN_STEPS = 49
SCAN_INTERVAL = 60 # Interval for updating the coordinator
COORDINATOR = "coordinator"
COORDINATOR_LIST = "coordinator_list"
GROUPS_LIST = "groups_list"

View File

@ -0,0 +1,145 @@
"""Tradfri DataUpdateCoordinator."""
from __future__ import annotations
from collections.abc import Callable
from datetime import timedelta
import logging
from typing import Any
from pytradfri.command import Command
from pytradfri.device import Device
from pytradfri.error import RequestError
from pytradfri.group import Group
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from .const import SCAN_INTERVAL
_LOGGER = logging.getLogger(__name__)
class TradfriDeviceDataUpdateCoordinator(DataUpdateCoordinator[Device]):
"""Coordinator to manage data for a specific Tradfri device."""
def __init__(
self,
hass: HomeAssistant,
*,
api: Callable[[Command | list[Command]], Any],
device: Device,
) -> None:
"""Initialize device coordinator."""
self.api = api
self.device = device
self._exception: Exception | None = None
super().__init__(
hass,
_LOGGER,
name=f"Update coordinator for {device}",
update_interval=timedelta(seconds=SCAN_INTERVAL),
)
async def set_hub_available(self, available: bool) -> None:
"""Set status of hub."""
if available != self.last_update_success:
if not available:
self.last_update_success = False
await self.async_request_refresh()
@callback
def _observe_update(self, device: Device) -> None:
"""Update the coordinator for a device when a change is detected."""
self.update_interval = timedelta(seconds=SCAN_INTERVAL) # Reset update interval
self.async_set_updated_data(data=device)
@callback
def _exception_callback(self, device: Device, exc: Exception | None = None) -> None:
"""Schedule handling exception.."""
self.hass.async_create_task(self._handle_exception(device=device, exc=exc))
async def _handle_exception(
self, device: Device, exc: Exception | None = None
) -> None:
"""Handle observe exceptions in a coroutine."""
self._exception = (
exc # Store exception so that it gets raised in _async_update_data
)
_LOGGER.debug("Observation failed for %s, trying again", device, exc_info=exc)
self.update_interval = timedelta(
seconds=5
) # Change interval so we get a swift refresh
await self.async_request_refresh()
async def _async_update_data(self) -> Device:
"""Fetch data from the gateway for a specific device."""
try:
if self._exception:
exc = self._exception
self._exception = None # Clear stored exception
raise exc # pylint: disable-msg=raising-bad-type
except RequestError as err:
raise UpdateFailed(
f"Error communicating with API: {err}. Try unplugging and replugging your "
f"IKEA gateway."
) from err
if not self.data or not self.last_update_success: # Start subscription
try:
cmd = self.device.observe(
callback=self._observe_update,
err_callback=self._exception_callback,
duration=0,
)
await self.api(cmd)
except RequestError as exc:
await self._handle_exception(device=self.device, exc=exc)
return self.device
class TradfriGroupDataUpdateCoordinator(DataUpdateCoordinator[Group]):
"""Coordinator to manage data for a specific Tradfri group."""
def __init__(
self,
hass: HomeAssistant,
*,
api: Callable[[Command | list[Command]], Any],
group: Group,
) -> None:
"""Initialize group coordinator."""
self.api = api
self.group = group
self._exception: Exception | None = None
super().__init__(
hass,
_LOGGER,
name=f"Update coordinator for {group}",
update_interval=timedelta(seconds=SCAN_INTERVAL),
)
async def set_hub_available(self, available: bool) -> None:
"""Set status of hub."""
if available != self.last_update_success:
if not available:
self.last_update_success = False
await self.async_request_refresh()
async def _async_update_data(self) -> Group:
"""Fetch data from the gateway for a specific group."""
self.update_interval = timedelta(seconds=SCAN_INTERVAL) # Reset update interval
cmd = self.group.update()
try:
await self.api(cmd)
except RequestError as exc:
self.update_interval = timedelta(
seconds=5
) # Change interval so we get a swift refresh
raise UpdateFailed("Unable to update group coordinator") from exc
return self.group

View File

@ -11,8 +11,16 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .base_class import TradfriBaseDevice from .base_class import TradfriBaseEntity
from .const import ATTR_MODEL, CONF_GATEWAY_ID, DEVICES, DOMAIN, KEY_API from .const import (
ATTR_MODEL,
CONF_GATEWAY_ID,
COORDINATOR,
COORDINATOR_LIST,
DOMAIN,
KEY_API,
)
from .coordinator import TradfriDeviceDataUpdateCoordinator
async def async_setup_entry( async def async_setup_entry(
@ -22,28 +30,42 @@ async def async_setup_entry(
) -> None: ) -> None:
"""Load Tradfri covers based on a config entry.""" """Load Tradfri covers based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID] gateway_id = config_entry.data[CONF_GATEWAY_ID]
tradfri_data = hass.data[DOMAIN][config_entry.entry_id] coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR]
api = tradfri_data[KEY_API] api = coordinator_data[KEY_API]
devices = tradfri_data[DEVICES]
async_add_entities( async_add_entities(
TradfriCover(dev, api, gateway_id) for dev in devices if dev.has_blind_control TradfriCover(
device_coordinator,
api,
gateway_id,
)
for device_coordinator in coordinator_data[COORDINATOR_LIST]
if device_coordinator.device.has_blind_control
) )
class TradfriCover(TradfriBaseDevice, CoverEntity): class TradfriCover(TradfriBaseEntity, CoverEntity):
"""The platform class required by Home Assistant.""" """The platform class required by Home Assistant."""
def __init__( def __init__(
self, self,
device: Command, device_coordinator: TradfriDeviceDataUpdateCoordinator,
api: Callable[[Command | list[Command]], Any], api: Callable[[Command | list[Command]], Any],
gateway_id: str, gateway_id: str,
) -> None: ) -> None:
"""Initialize a cover.""" """Initialize a switch."""
self._attr_unique_id = f"{gateway_id}-{device.id}" super().__init__(
super().__init__(device, api, gateway_id) device_coordinator=device_coordinator,
self._refresh(device, write_ha=False) api=api,
gateway_id=gateway_id,
)
self._device_control = self._device.blind_control
self._device_data = self._device_control.blinds[0]
def _refresh(self) -> None:
"""Refresh the device."""
self._device_data = self.coordinator.data.blind_control.blinds[0]
@property @property
def extra_state_attributes(self) -> dict[str, str] | None: def extra_state_attributes(self) -> dict[str, str] | None:
@ -88,11 +110,3 @@ class TradfriCover(TradfriBaseDevice, CoverEntity):
def is_closed(self) -> bool: def is_closed(self) -> bool:
"""Return if the cover is closed or not.""" """Return if the cover is closed or not."""
return self.current_cover_position == 0 return self.current_cover_position == 0
def _refresh(self, device: Command, write_ha: bool = True) -> None:
"""Refresh the cover data."""
# Caching of BlindControl and cover object
self._device = device
self._device_control = device.blind_control
self._device_data = device.blind_control.blinds[0]
super()._refresh(device, write_ha=write_ha)

View File

@ -16,15 +16,17 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .base_class import TradfriBaseDevice from .base_class import TradfriBaseEntity
from .const import ( from .const import (
ATTR_AUTO, ATTR_AUTO,
ATTR_MAX_FAN_STEPS, ATTR_MAX_FAN_STEPS,
CONF_GATEWAY_ID, CONF_GATEWAY_ID,
DEVICES, COORDINATOR,
COORDINATOR_LIST,
DOMAIN, DOMAIN,
KEY_API, KEY_API,
) )
from .coordinator import TradfriDeviceDataUpdateCoordinator
def _from_fan_percentage(percentage: int) -> int: def _from_fan_percentage(percentage: int) -> int:
@ -44,30 +46,42 @@ async def async_setup_entry(
) -> None: ) -> None:
"""Load Tradfri switches based on a config entry.""" """Load Tradfri switches based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID] gateway_id = config_entry.data[CONF_GATEWAY_ID]
tradfri_data = hass.data[DOMAIN][config_entry.entry_id] coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR]
api = tradfri_data[KEY_API] api = coordinator_data[KEY_API]
devices = tradfri_data[DEVICES]
async_add_entities( async_add_entities(
TradfriAirPurifierFan(dev, api, gateway_id) TradfriAirPurifierFan(
for dev in devices device_coordinator,
if dev.has_air_purifier_control api,
gateway_id,
)
for device_coordinator in coordinator_data[COORDINATOR_LIST]
if device_coordinator.device.has_air_purifier_control
) )
class TradfriAirPurifierFan(TradfriBaseDevice, FanEntity): class TradfriAirPurifierFan(TradfriBaseEntity, FanEntity):
"""The platform class required by Home Assistant.""" """The platform class required by Home Assistant."""
def __init__( def __init__(
self, self,
device: Command, device_coordinator: TradfriDeviceDataUpdateCoordinator,
api: Callable[[Command | list[Command]], Any], api: Callable[[Command | list[Command]], Any],
gateway_id: str, gateway_id: str,
) -> None: ) -> None:
"""Initialize a switch.""" """Initialize a switch."""
super().__init__(device, api, gateway_id) super().__init__(
self._attr_unique_id = f"{gateway_id}-{device.id}" device_coordinator=device_coordinator,
self._refresh(device, write_ha=False) api=api,
gateway_id=gateway_id,
)
self._device_control = self._device.air_purifier_control
self._device_data = self._device_control.air_purifiers[0]
def _refresh(self) -> None:
"""Refresh the device."""
self._device_data = self.coordinator.data.air_purifier_control.air_purifiers[0]
@property @property
def supported_features(self) -> int: def supported_features(self) -> int:
@ -168,10 +182,3 @@ class TradfriAirPurifierFan(TradfriBaseDevice, FanEntity):
if not self._device_control: if not self._device_control:
return return
await self._api(self._device_control.turn_off()) await self._api(self._device_control.turn_off())
def _refresh(self, device: Command, write_ha: bool = True) -> None:
"""Refresh the purifier data."""
# Caching of air purifier control and purifier object
self._device_control = device.air_purifier_control
self._device_data = device.air_purifier_control.air_purifiers[0]
super()._refresh(device, write_ha=write_ha)

View File

@ -5,6 +5,7 @@ from collections.abc import Callable
from typing import Any, cast from typing import Any, cast
from pytradfri.command import Command from pytradfri.command import Command
from pytradfri.group import Group
from homeassistant.components.light import ( from homeassistant.components.light import (
ATTR_BRIGHTNESS, ATTR_BRIGHTNESS,
@ -19,9 +20,10 @@ from homeassistant.components.light import (
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity
import homeassistant.util.color as color_util import homeassistant.util.color as color_util
from .base_class import TradfriBaseClass, TradfriBaseDevice from .base_class import TradfriBaseEntity
from .const import ( from .const import (
ATTR_DIMMER, ATTR_DIMMER,
ATTR_HUE, ATTR_HUE,
@ -29,13 +31,18 @@ from .const import (
ATTR_TRANSITION_TIME, ATTR_TRANSITION_TIME,
CONF_GATEWAY_ID, CONF_GATEWAY_ID,
CONF_IMPORT_GROUPS, CONF_IMPORT_GROUPS,
DEVICES, COORDINATOR,
COORDINATOR_LIST,
DOMAIN, DOMAIN,
GROUPS, GROUPS_LIST,
KEY_API, KEY_API,
SUPPORTED_GROUP_FEATURES, SUPPORTED_GROUP_FEATURES,
SUPPORTED_LIGHT_FEATURES, SUPPORTED_LIGHT_FEATURES,
) )
from .coordinator import (
TradfriDeviceDataUpdateCoordinator,
TradfriGroupDataUpdateCoordinator,
)
async def async_setup_entry( async def async_setup_entry(
@ -45,56 +52,66 @@ async def async_setup_entry(
) -> None: ) -> None:
"""Load Tradfri lights based on a config entry.""" """Load Tradfri lights based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID] gateway_id = config_entry.data[CONF_GATEWAY_ID]
tradfri_data = hass.data[DOMAIN][config_entry.entry_id] coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR]
api = tradfri_data[KEY_API] api = coordinator_data[KEY_API]
devices = tradfri_data[DEVICES]
entities: list[TradfriBaseClass] = [ entities: list = [
TradfriLight(dev, api, gateway_id) for dev in devices if dev.has_light_control TradfriLight(
device_coordinator,
api,
gateway_id,
)
for device_coordinator in coordinator_data[COORDINATOR_LIST]
if device_coordinator.device.has_light_control
] ]
if config_entry.data[CONF_IMPORT_GROUPS] and (groups := tradfri_data[GROUPS]):
entities.extend([TradfriGroup(group, api, gateway_id) for group in groups]) if config_entry.data[CONF_IMPORT_GROUPS] and (
group_coordinators := coordinator_data[GROUPS_LIST]
):
entities.extend(
[
TradfriGroup(group_coordinator, api, gateway_id)
for group_coordinator in group_coordinators
]
)
async_add_entities(entities) async_add_entities(entities)
class TradfriGroup(TradfriBaseClass, LightEntity): class TradfriGroup(CoordinatorEntity, LightEntity):
"""The platform class for light groups required by hass.""" """The platform class for light groups required by hass."""
_attr_supported_features = SUPPORTED_GROUP_FEATURES _attr_supported_features = SUPPORTED_GROUP_FEATURES
def __init__( def __init__(
self, self,
device: Command, group_coordinator: TradfriGroupDataUpdateCoordinator,
api: Callable[[Command | list[Command]], Any], api: Callable[[Command | list[Command]], Any],
gateway_id: str, gateway_id: str,
) -> None: ) -> None:
"""Initialize a Group.""" """Initialize a Group."""
super().__init__(device, api, gateway_id) super().__init__(coordinator=group_coordinator)
self._attr_unique_id = f"group-{gateway_id}-{device.id}" self._group: Group = self.coordinator.data
self._attr_should_poll = True
self._refresh(device, write_ha=False)
async def async_update(self) -> None: self._api = api
"""Fetch new state data for the group. self._attr_unique_id = f"group-{gateway_id}-{self._group.id}"
This method is required for groups to update properly.
"""
await self._api(self._device.update())
@property @property
def is_on(self) -> bool: def is_on(self) -> bool:
"""Return true if group lights are on.""" """Return true if group lights are on."""
return cast(bool, self._device.state) return cast(bool, self._group.state)
@property @property
def brightness(self) -> int | None: def brightness(self) -> int | None:
"""Return the brightness of the group lights.""" """Return the brightness of the group lights."""
return cast(int, self._device.dimmer) return cast(int, self._group.dimmer)
async def async_turn_off(self, **kwargs: Any) -> None: async def async_turn_off(self, **kwargs: Any) -> None:
"""Instruct the group lights to turn off.""" """Instruct the group lights to turn off."""
await self._api(self._device.set_state(0)) await self._api(self._group.set_state(0))
await self.coordinator.async_request_refresh()
async def async_turn_on(self, **kwargs: Any) -> None: async def async_turn_on(self, **kwargs: Any) -> None:
"""Instruct the group lights to turn on, or dim.""" """Instruct the group lights to turn on, or dim."""
@ -106,39 +123,53 @@ class TradfriGroup(TradfriBaseClass, LightEntity):
if kwargs[ATTR_BRIGHTNESS] == 255: if kwargs[ATTR_BRIGHTNESS] == 255:
kwargs[ATTR_BRIGHTNESS] = 254 kwargs[ATTR_BRIGHTNESS] = 254
await self._api(self._device.set_dimmer(kwargs[ATTR_BRIGHTNESS], **keys)) await self._api(self._group.set_dimmer(kwargs[ATTR_BRIGHTNESS], **keys))
else: else:
await self._api(self._device.set_state(1)) await self._api(self._group.set_state(1))
await self.coordinator.async_request_refresh()
class TradfriLight(TradfriBaseDevice, LightEntity): class TradfriLight(TradfriBaseEntity, LightEntity):
"""The platform class required by Home Assistant.""" """The platform class required by Home Assistant."""
def __init__( def __init__(
self, self,
device: Command, device_coordinator: TradfriDeviceDataUpdateCoordinator,
api: Callable[[Command | list[Command]], Any], api: Callable[[Command | list[Command]], Any],
gateway_id: str, gateway_id: str,
) -> None: ) -> None:
"""Initialize a Light.""" """Initialize a Light."""
super().__init__(device, api, gateway_id) super().__init__(
self._attr_unique_id = f"light-{gateway_id}-{device.id}" device_coordinator=device_coordinator,
api=api,
gateway_id=gateway_id,
)
self._device_control = self._device.light_control
self._device_data = self._device_control.lights[0]
self._attr_unique_id = f"light-{gateway_id}-{self._device_id}"
self._hs_color = None self._hs_color = None
# Calculate supported features # Calculate supported features
_features = SUPPORTED_LIGHT_FEATURES _features = SUPPORTED_LIGHT_FEATURES
if device.light_control.can_set_dimmer: if self._device.light_control.can_set_dimmer:
_features |= SUPPORT_BRIGHTNESS _features |= SUPPORT_BRIGHTNESS
if device.light_control.can_set_color: if self._device.light_control.can_set_color:
_features |= SUPPORT_COLOR | SUPPORT_COLOR_TEMP _features |= SUPPORT_COLOR | SUPPORT_COLOR_TEMP
if device.light_control.can_set_temp: if self._device.light_control.can_set_temp:
_features |= SUPPORT_COLOR_TEMP _features |= SUPPORT_COLOR_TEMP
self._attr_supported_features = _features self._attr_supported_features = _features
self._refresh(device, write_ha=False)
if self._device_control: if self._device_control:
self._attr_min_mireds = self._device_control.min_mireds self._attr_min_mireds = self._device_control.min_mireds
self._attr_max_mireds = self._device_control.max_mireds self._attr_max_mireds = self._device_control.max_mireds
def _refresh(self) -> None:
"""Refresh the device."""
self._device_data = self.coordinator.data.light_control.lights[0]
@property @property
def is_on(self) -> bool: def is_on(self) -> bool:
"""Return true if light is on.""" """Return true if light is on."""
@ -268,10 +299,3 @@ class TradfriLight(TradfriBaseDevice, LightEntity):
await self._api(temp_command) await self._api(temp_command)
if command is not None: if command is not None:
await self._api(command) await self._api(command)
def _refresh(self, device: Command, write_ha: bool = True) -> None:
"""Refresh the light data."""
# Caching of LightControl and light object
self._device_control = device.light_control
self._device_data = device.light_control.lights[0]
super()._refresh(device, write_ha=write_ha)

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from typing import Any, cast from typing import Any
from pytradfri.command import Command from pytradfri.command import Command
@ -12,8 +12,9 @@ from homeassistant.const import PERCENTAGE
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .base_class import TradfriBaseDevice from .base_class import TradfriBaseEntity
from .const import CONF_GATEWAY_ID, DEVICES, DOMAIN, KEY_API from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API
from .coordinator import TradfriDeviceDataUpdateCoordinator
async def async_setup_entry( async def async_setup_entry(
@ -23,24 +24,27 @@ async def async_setup_entry(
) -> None: ) -> None:
"""Set up a Tradfri config entry.""" """Set up a Tradfri config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID] gateway_id = config_entry.data[CONF_GATEWAY_ID]
tradfri_data = hass.data[DOMAIN][config_entry.entry_id] coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR]
api = tradfri_data[KEY_API] api = coordinator_data[KEY_API]
devices = tradfri_data[DEVICES]
async_add_entities( async_add_entities(
TradfriSensor(dev, api, gateway_id) TradfriSensor(
for dev in devices device_coordinator,
api,
gateway_id,
)
for device_coordinator in coordinator_data[COORDINATOR_LIST]
if ( if (
not dev.has_light_control not device_coordinator.device.has_light_control
and not dev.has_socket_control and not device_coordinator.device.has_socket_control
and not dev.has_blind_control and not device_coordinator.device.has_blind_control
and not dev.has_signal_repeater_control and not device_coordinator.device.has_signal_repeater_control
and not dev.has_air_purifier_control and not device_coordinator.device.has_air_purifier_control
) )
) )
class TradfriSensor(TradfriBaseDevice, SensorEntity): class TradfriSensor(TradfriBaseEntity, SensorEntity):
"""The platform class required by Home Assistant.""" """The platform class required by Home Assistant."""
_attr_device_class = SensorDeviceClass.BATTERY _attr_device_class = SensorDeviceClass.BATTERY
@ -48,17 +52,19 @@ class TradfriSensor(TradfriBaseDevice, SensorEntity):
def __init__( def __init__(
self, self,
device: Command, device_coordinator: TradfriDeviceDataUpdateCoordinator,
api: Callable[[Command | list[Command]], Any], api: Callable[[Command | list[Command]], Any],
gateway_id: str, gateway_id: str,
) -> None: ) -> None:
"""Initialize the device.""" """Initialize a switch."""
super().__init__(device, api, gateway_id) super().__init__(
self._attr_unique_id = f"{gateway_id}-{device.id}" device_coordinator=device_coordinator,
api=api,
gateway_id=gateway_id,
)
@property self._refresh() # Set initial state
def native_value(self) -> int | None:
"""Return the current state of the device.""" def _refresh(self) -> None:
if not self._device: """Refresh the device."""
return None self._attr_native_value = self.coordinator.data.device_info.battery_level
return cast(int, self._device.device_info.battery_level)

View File

@ -11,8 +11,9 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .base_class import TradfriBaseDevice from .base_class import TradfriBaseEntity
from .const import CONF_GATEWAY_ID, DEVICES, DOMAIN, KEY_API from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API
from .coordinator import TradfriDeviceDataUpdateCoordinator
async def async_setup_entry( async def async_setup_entry(
@ -22,35 +23,42 @@ async def async_setup_entry(
) -> None: ) -> None:
"""Load Tradfri switches based on a config entry.""" """Load Tradfri switches based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID] gateway_id = config_entry.data[CONF_GATEWAY_ID]
tradfri_data = hass.data[DOMAIN][config_entry.entry_id] coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR]
api = tradfri_data[KEY_API] api = coordinator_data[KEY_API]
devices = tradfri_data[DEVICES]
async_add_entities( async_add_entities(
TradfriSwitch(dev, api, gateway_id) for dev in devices if dev.has_socket_control TradfriSwitch(
device_coordinator,
api,
gateway_id,
)
for device_coordinator in coordinator_data[COORDINATOR_LIST]
if device_coordinator.device.has_socket_control
) )
class TradfriSwitch(TradfriBaseDevice, SwitchEntity): class TradfriSwitch(TradfriBaseEntity, SwitchEntity):
"""The platform class required by Home Assistant.""" """The platform class required by Home Assistant."""
def __init__( def __init__(
self, self,
device: Command, device_coordinator: TradfriDeviceDataUpdateCoordinator,
api: Callable[[Command | list[Command]], Any], api: Callable[[Command | list[Command]], Any],
gateway_id: str, gateway_id: str,
) -> None: ) -> None:
"""Initialize a switch.""" """Initialize a switch."""
super().__init__(device, api, gateway_id) super().__init__(
self._attr_unique_id = f"{gateway_id}-{device.id}" device_coordinator=device_coordinator,
self._refresh(device, write_ha=False) api=api,
gateway_id=gateway_id,
)
def _refresh(self, device: Command, write_ha: bool = True) -> None: self._device_control = self._device.socket_control
"""Refresh the switch data.""" self._device_data = self._device_control.sockets[0]
# Caching of switch control and switch object
self._device_control = device.socket_control def _refresh(self) -> None:
self._device_data = device.socket_control.sockets[0] """Refresh the device."""
super()._refresh(device, write_ha=write_ha) self._device_data = self.coordinator.data.socket_control.sockets[0]
@property @property
def is_on(self) -> bool: def is_on(self) -> bool:

View File

@ -22,3 +22,5 @@ async def setup_integration(hass):
entry.add_to_hass(hass) entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
return entry

View File

@ -121,7 +121,6 @@ async def test_set_percentage(
"""Test setting speed of a fan.""" """Test setting speed of a fan."""
# Note pytradfri style, not hass. Values not really important. # Note pytradfri style, not hass. Values not really important.
initial_state = {"percentage": 10, "fan_speed": 3} initial_state = {"percentage": 10, "fan_speed": 3}
# Setup the gateway with a mock fan. # Setup the gateway with a mock fan.
fan = mock_fan(test_state=initial_state, device_number=0) fan = mock_fan(test_state=initial_state, device_number=0)
mock_gateway.mock_devices.append(fan) mock_gateway.mock_devices.append(fan)

View File

@ -317,6 +317,7 @@ def mock_group(test_state=None, group_number=0):
_mock_group = Mock(member_ids=[], observe=Mock(), **state) _mock_group = Mock(member_ids=[], observe=Mock(), **state)
_mock_group.name = f"tradfri_group_{group_number}" _mock_group.name = f"tradfri_group_{group_number}"
_mock_group.id = group_number
return _mock_group return _mock_group
@ -327,11 +328,11 @@ async def test_group(hass, mock_gateway, mock_api_factory):
mock_gateway.mock_groups.append(mock_group(state, 1)) mock_gateway.mock_groups.append(mock_group(state, 1))
await setup_integration(hass) await setup_integration(hass)
group = hass.states.get("light.tradfri_group_0") group = hass.states.get("light.tradfri_group_mock_gateway_id_0")
assert group is not None assert group is not None
assert group.state == "off" assert group.state == "off"
group = hass.states.get("light.tradfri_group_1") group = hass.states.get("light.tradfri_group_mock_gateway_id_1")
assert group is not None assert group is not None
assert group.state == "on" assert group.state == "on"
assert group.attributes["brightness"] == 100 assert group.attributes["brightness"] == 100
@ -348,19 +349,26 @@ async def test_group_turn_on(hass, mock_gateway, mock_api_factory):
await setup_integration(hass) await setup_integration(hass)
# Use the turn_off service call to change the light state. # Use the turn_off service call to change the light state.
await hass.services.async_call(
"light", "turn_on", {"entity_id": "light.tradfri_group_0"}, blocking=True
)
await hass.services.async_call( await hass.services.async_call(
"light", "light",
"turn_on", "turn_on",
{"entity_id": "light.tradfri_group_1", "brightness": 100}, {"entity_id": "light.tradfri_group_mock_gateway_id_0"},
blocking=True, blocking=True,
) )
await hass.services.async_call( await hass.services.async_call(
"light", "light",
"turn_on", "turn_on",
{"entity_id": "light.tradfri_group_2", "brightness": 100, "transition": 1}, {"entity_id": "light.tradfri_group_mock_gateway_id_1", "brightness": 100},
blocking=True,
)
await hass.services.async_call(
"light",
"turn_on",
{
"entity_id": "light.tradfri_group_mock_gateway_id_2",
"brightness": 100,
"transition": 1,
},
blocking=True, blocking=True,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -378,7 +386,10 @@ async def test_group_turn_off(hass, mock_gateway, mock_api_factory):
# Use the turn_off service call to change the light state. # Use the turn_off service call to change the light state.
await hass.services.async_call( await hass.services.async_call(
"light", "turn_off", {"entity_id": "light.tradfri_group_0"}, blocking=True "light",
"turn_off",
{"entity_id": "light.tradfri_group_mock_gateway_id_0"},
blocking=True,
) )
await hass.async_block_till_done() await hass.async_block_till_done()