diff --git a/homeassistant/components/tplink/camera.py b/homeassistant/components/tplink/camera.py index e1db7254428..61a08887f5f 100644 --- a/homeassistant/components/tplink/camera.py +++ b/homeassistant/components/tplink/camera.py @@ -7,8 +7,7 @@ import time from aiohttp import web from haffmpeg.camera import CameraMjpeg -from kasa import Credentials, Device, Module, StreamResolution -from kasa.smartcam.modules import Camera as CameraModule +from kasa import Device, Module, StreamResolution from homeassistant.components import ffmpeg, stream from homeassistant.components.camera import ( @@ -24,10 +23,14 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from . import TPLinkConfigEntry, legacy_device_id from .const import CONF_CAMERA_CREDENTIALS from .coordinator import TPLinkDataUpdateCoordinator -from .entity import CoordinatedTPLinkEntity, TPLinkModuleEntityDescription +from .entity import CoordinatedTPLinkModuleEntity, TPLinkModuleEntityDescription _LOGGER = logging.getLogger(__name__) +# Coordinator is used to centralize the data updates +# For actions the integration handles locking of concurrent device request +PARALLEL_UPDATES = 0 + @dataclass(frozen=True, kw_only=True) class TPLinkCameraEntityDescription( @@ -36,15 +39,18 @@ class TPLinkCameraEntityDescription( """Base class for camera entity description.""" -# Coordinator is used to centralize the data updates -# For actions the integration handles locking of concurrent device request -PARALLEL_UPDATES = 0 - CAMERA_DESCRIPTIONS: tuple[TPLinkCameraEntityDescription, ...] = ( TPLinkCameraEntityDescription( key="live_view", translation_key="live_view", available_fn=lambda dev: dev.is_on, + exists_fn=lambda dev, entry: ( + (rtd := entry.runtime_data) is not None + and rtd.live_view is True + and (cam_creds := rtd.camera_credentials) is not None + and (cm := dev.modules.get(Module.Camera)) is not None + and cm.stream_rtsp_url(cam_creds) is not None + ), ), ) @@ -58,26 +64,28 @@ async def async_setup_entry( data = config_entry.runtime_data parent_coordinator = data.parent_coordinator device = parent_coordinator.device - camera_credentials = data.camera_credentials - live_view = data.live_view - ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass) - async_add_entities( - TPLinkCameraEntity( - device, - parent_coordinator, - description, - camera_module=camera_module, - parent=None, - ffmpeg_manager=ffmpeg_manager, - camera_credentials=camera_credentials, + known_child_device_ids: set[str] = set() + first_check = True + + def _check_device() -> None: + entities = CoordinatedTPLinkModuleEntity.entities_for_device_and_its_children( + hass=hass, + device=device, + coordinator=parent_coordinator, + entity_class=TPLinkCameraEntity, + descriptions=CAMERA_DESCRIPTIONS, + known_child_device_ids=known_child_device_ids, + first_check=first_check, ) - for description in CAMERA_DESCRIPTIONS - if (camera_module := device.modules.get(Module.Camera)) and live_view - ) + async_add_entities(entities) + + _check_device() + first_check = False + config_entry.async_on_unload(parent_coordinator.async_add_listener(_check_device)) -class TPLinkCameraEntity(CoordinatedTPLinkEntity, Camera): +class TPLinkCameraEntity(CoordinatedTPLinkModuleEntity, Camera): """Representation of a TPLink camera.""" IMAGE_INTERVAL = 5 * 60 @@ -86,30 +94,30 @@ class TPLinkCameraEntity(CoordinatedTPLinkEntity, Camera): entity_description: TPLinkCameraEntityDescription + _ffmpeg_manager: ffmpeg.FFmpegManager + def __init__( self, device: Device, coordinator: TPLinkDataUpdateCoordinator, description: TPLinkCameraEntityDescription, *, - camera_module: CameraModule, parent: Device | None = None, - ffmpeg_manager: ffmpeg.FFmpegManager, - camera_credentials: Credentials | None, ) -> None: """Initialize a TPlink camera.""" - self.entity_description = description - self._camera_module = camera_module - self._video_url = camera_module.stream_rtsp_url( - camera_credentials, stream_resolution=StreamResolution.SD + super().__init__(device, coordinator, description=description, parent=parent) + Camera.__init__(self) + + self._camera_module = device.modules[Module.Camera] + self._camera_credentials = ( + coordinator.config_entry.runtime_data.camera_credentials + ) + self._video_url = self._camera_module.stream_rtsp_url( + self._camera_credentials, stream_resolution=StreamResolution.SD ) self._image: bytes | None = None - super().__init__(device, coordinator, parent=parent) - Camera.__init__(self) - self._ffmpeg_manager = ffmpeg_manager self._image_lock = asyncio.Lock() self._last_update: float = 0 - self._camera_credentials = camera_credentials self._can_stream = True self._http_mpeg_stream_running = False @@ -117,6 +125,12 @@ class TPLinkCameraEntity(CoordinatedTPLinkEntity, Camera): """Return unique ID for the entity.""" return f"{legacy_device_id(self._device)}-{self.entity_description.key}" + async def async_added_to_hass(self) -> None: + """Call update attributes after the device is added to the platform.""" + await super().async_added_to_hass() + + self._ffmpeg_manager = ffmpeg.get_ffmpeg_manager(self.hass) + @callback def _async_update_attrs(self) -> bool: """Update the entity's attributes.""" diff --git a/homeassistant/components/tplink/climate.py b/homeassistant/components/tplink/climate.py index e8b7336f391..a7dd865e7bb 100644 --- a/homeassistant/components/tplink/climate.py +++ b/homeassistant/components/tplink/climate.py @@ -2,15 +2,17 @@ from __future__ import annotations +from dataclasses import dataclass import logging from typing import Any, cast -from kasa import Device, DeviceType +from kasa import Device from kasa.smart.modules.temperaturecontrol import ThermostatState from homeassistant.components.climate import ( ATTR_TEMPERATURE, ClimateEntity, + ClimateEntityDescription, ClimateEntityFeature, HVACAction, HVACMode, @@ -23,7 +25,11 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from . import TPLinkConfigEntry from .const import DOMAIN, UNIT_MAPPING from .coordinator import TPLinkDataUpdateCoordinator -from .entity import CoordinatedTPLinkEntity, async_refresh_after +from .entity import ( + CoordinatedTPLinkModuleEntity, + TPLinkModuleEntityDescription, + async_refresh_after, +) # Coordinator is used to centralize the data updates # For actions the integration handles locking of concurrent device request @@ -40,6 +46,21 @@ STATE_TO_ACTION = { _LOGGER = logging.getLogger(__name__) +@dataclass(frozen=True, kw_only=True) +class TPLinkClimateEntityDescription( + ClimateEntityDescription, TPLinkModuleEntityDescription +): + """Base class for climate entity description.""" + + +CLIMATE_DESCRIPTIONS: tuple[TPLinkClimateEntityDescription, ...] = ( + TPLinkClimateEntityDescription( + key="climate", + exists_fn=lambda dev, _: dev.device_type is Device.Type.Thermostat, + ), +) + + async def async_setup_entry( hass: HomeAssistant, config_entry: TPLinkConfigEntry, @@ -50,15 +71,27 @@ async def async_setup_entry( parent_coordinator = data.parent_coordinator device = parent_coordinator.device - # As there are no standalone thermostats, we just iterate over the children. - async_add_entities( - TPLinkClimateEntity(child, parent_coordinator, parent=device) - for child in device.children - if child.device_type is DeviceType.Thermostat - ) + known_child_device_ids: set[str] = set() + first_check = True + + def _check_device() -> None: + entities = CoordinatedTPLinkModuleEntity.entities_for_device_and_its_children( + hass=hass, + device=device, + coordinator=parent_coordinator, + entity_class=TPLinkClimateEntity, + descriptions=CLIMATE_DESCRIPTIONS, + known_child_device_ids=known_child_device_ids, + first_check=first_check, + ) + async_add_entities(entities) + + _check_device() + first_check = False + config_entry.async_on_unload(parent_coordinator.async_add_listener(_check_device)) -class TPLinkClimateEntity(CoordinatedTPLinkEntity, ClimateEntity): +class TPLinkClimateEntity(CoordinatedTPLinkModuleEntity, ClimateEntity): """Representation of a TPLink thermostat.""" _attr_name = None @@ -70,16 +103,20 @@ class TPLinkClimateEntity(CoordinatedTPLinkEntity, ClimateEntity): _attr_hvac_modes = [HVACMode.HEAT, HVACMode.OFF] _attr_precision = PRECISION_TENTHS + entity_description: TPLinkClimateEntityDescription + # This disables the warning for async_turn_{on,off}, can be removed later. def __init__( self, device: Device, coordinator: TPLinkDataUpdateCoordinator, + description: TPLinkClimateEntityDescription, *, parent: Device, ) -> None: """Initialize the climate entity.""" + super().__init__(device, coordinator, description, parent=parent) self._state_feature = device.features["state"] self._mode_feature = device.features["thermostat_mode"] self._temp_feature = device.features["temperature"] @@ -89,8 +126,6 @@ class TPLinkClimateEntity(CoordinatedTPLinkEntity, ClimateEntity): self._attr_max_temp = self._target_feature.maximum_value self._attr_temperature_unit = UNIT_MAPPING[cast(str, self._temp_feature.unit)] - super().__init__(device, coordinator, parent=parent) - @async_refresh_after async def async_set_temperature(self, **kwargs: Any) -> None: """Set target temperature.""" diff --git a/homeassistant/components/tplink/entity.py b/homeassistant/components/tplink/entity.py index 178c8bfdd3d..e7c3600acc2 100644 --- a/homeassistant/components/tplink/entity.py +++ b/homeassistant/components/tplink/entity.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Awaitable, Callable, Coroutine, Mapping +from collections.abc import Awaitable, Callable, Coroutine, Iterable, Mapping from dataclasses import dataclass, replace import logging from typing import Any, Concatenate @@ -35,7 +35,7 @@ from .const import ( DOMAIN, PRIMARY_STATE_ID, ) -from .coordinator import TPLinkDataUpdateCoordinator +from .coordinator import TPLinkConfigEntry, TPLinkDataUpdateCoordinator from .deprecate import DeprecatedInfo, async_check_create_deprecated _LOGGER = logging.getLogger(__name__) @@ -85,7 +85,7 @@ LEGACY_KEY_MAPPING = { @dataclass(frozen=True, kw_only=True) -class TPLinkFeatureEntityDescription(EntityDescription): +class TPLinkEntityDescription(EntityDescription): """Base class for a TPLink feature based entity description.""" deprecated_info: DeprecatedInfo | None = None @@ -93,11 +93,15 @@ class TPLinkFeatureEntityDescription(EntityDescription): @dataclass(frozen=True, kw_only=True) -class TPLinkModuleEntityDescription(EntityDescription): +class TPLinkFeatureEntityDescription(TPLinkEntityDescription): + """Base class for a TPLink feature based entity description.""" + + +@dataclass(frozen=True, kw_only=True) +class TPLinkModuleEntityDescription(TPLinkEntityDescription): """Base class for a TPLink module based entity description.""" - deprecated_info: DeprecatedInfo | None = None - available_fn: Callable[[Device], bool] = lambda _: True + exists_fn: Callable[[Device, TPLinkConfigEntry], bool] def async_refresh_after[_T: CoordinatedTPLinkEntity, **_P]( @@ -151,13 +155,16 @@ class CoordinatedTPLinkEntity(CoordinatorEntity[TPLinkDataUpdateCoordinator], AB self, device: Device, coordinator: TPLinkDataUpdateCoordinator, + description: TPLinkEntityDescription, *, feature: Feature | None = None, parent: Device | None = None, ) -> None: """Initialize the entity.""" super().__init__(coordinator) + self.entity_description = description self._device: Device = device + self._parent = parent self._feature = feature registry_device = device @@ -209,6 +216,10 @@ class CoordinatedTPLinkEntity(CoordinatorEntity[TPLinkDataUpdateCoordinator], AB hw_version=registry_device.hw_info["hw_ver"], ) + # child device entities will link via_device unless they were created + # above on the parent. Otherwise the mac connections is set which or + # for wall switches like the ks240 will mean the child and parent devices + # are treated as one device. if ( parent is not None and parent != registry_device @@ -222,12 +233,16 @@ class CoordinatedTPLinkEntity(CoordinatorEntity[TPLinkDataUpdateCoordinator], AB self._attr_unique_id = self._get_unique_id() - self._async_call_update_attrs() - def _get_unique_id(self) -> str: """Return unique ID for the entity.""" return legacy_device_id(self._device) + async def async_added_to_hass(self) -> None: + """Call update attributes after the device is added to the platform.""" + await super().async_added_to_hass() + + self._async_call_update_attrs() + @abstractmethod @callback def _async_update_attrs(self) -> bool: @@ -276,14 +291,19 @@ class CoordinatedTPLinkFeatureEntity(CoordinatedTPLinkEntity, ABC): self, device: Device, coordinator: TPLinkDataUpdateCoordinator, + description: TPLinkFeatureEntityDescription, *, feature: Feature, - description: TPLinkFeatureEntityDescription, parent: Device | None = None, ) -> None: """Initialize the entity.""" - self.entity_description = description - super().__init__(device, coordinator, parent=parent, feature=feature) + super().__init__( + device, coordinator, description, parent=parent, feature=feature + ) + + # Update the feature attributes so the registered entity contains + # values like unit_of_measurement and suggested_display_precision + self._async_call_update_attrs() def _get_unique_id(self) -> str: """Return unique ID for the entity.""" @@ -456,29 +476,9 @@ class CoordinatedTPLinkFeatureEntity(CoordinatedTPLinkEntity, ABC): ) ) - # Remove any device ids removed via the coordinator so they can be re-added - for removed_child_id in coordinator.removed_child_device_ids: - _LOGGER.debug( - "Removing %s from known %s child ids for device %s" - "as it has been removed by the coordinator", - removed_child_id, - entity_class.__name__, - device.host, - ) - known_child_device_ids.discard(removed_child_id) - - current_child_devices = {child.device_id: child for child in device.children} - current_child_device_ids = set(current_child_devices.keys()) - new_child_device_ids = current_child_device_ids - known_child_device_ids - children = [] - - if new_child_device_ids: - children = [ - child - for child_id, child in current_child_devices.items() - if child_id in new_child_device_ids - ] - known_child_device_ids.update(new_child_device_ids) + children = _get_new_children( + device, coordinator, known_child_device_ids, entity_class.__name__ + ) if children: _LOGGER.debug( @@ -487,6 +487,7 @@ class CoordinatedTPLinkFeatureEntity(CoordinatedTPLinkEntity, ABC): len(children), device.host, ) + for child in children: child_coordinator = coordinator.get_child_coordinator(child) @@ -509,3 +510,170 @@ class CoordinatedTPLinkFeatureEntity(CoordinatedTPLinkEntity, ABC): entities.extend(child_entities) return entities + + +class CoordinatedTPLinkModuleEntity(CoordinatedTPLinkEntity, ABC): + """Common base class for all coordinated tplink module based entities.""" + + entity_description: TPLinkModuleEntityDescription + + def __init__( + self, + device: Device, + coordinator: TPLinkDataUpdateCoordinator, + description: TPLinkModuleEntityDescription, + *, + parent: Device | None = None, + ) -> None: + """Initialize the entity.""" + super().__init__(device, coordinator, description, parent=parent) + + # Module based entities will usually be 1 per device so they will use + # the device name. If there are multiple module entities based entities + # the description should have a translation key. + # HA logic is to name entities based on the following logic: + # _attr_name > translation.name > description.name + if not description.translation_key: + if parent is None or parent.device_type is Device.Type.Hub: + self._attr_name = None + else: + self._attr_name = get_device_name(device) + + @classmethod + def _entities_for_device[ + _E: CoordinatedTPLinkModuleEntity, + _D: TPLinkModuleEntityDescription, + ]( + cls, + hass: HomeAssistant, + device: Device, + coordinator: TPLinkDataUpdateCoordinator, + *, + entity_class: type[_E], + descriptions: Iterable[_D], + parent: Device | None = None, + ) -> list[_E]: + """Return a list of entities to add.""" + entities: list[_E] = [ + entity_class( + device, + coordinator, + description=description, + parent=parent, + ) + for description in descriptions + if description.exists_fn(device, coordinator.config_entry) + ] + return entities + + @classmethod + def entities_for_device_and_its_children[ + _E: CoordinatedTPLinkModuleEntity, + _D: TPLinkModuleEntityDescription, + ]( + cls, + hass: HomeAssistant, + device: Device, + coordinator: TPLinkDataUpdateCoordinator, + *, + entity_class: type[_E], + descriptions: Iterable[_D], + known_child_device_ids: set[str], + first_check: bool, + ) -> list[_E]: + """Create entities for device and its children. + + This is a helper that calls *_entities_for_device* for the device and its children. + """ + entities: list[_E] = [] + + # Add parent entities before children so via_device id works. + # Only add the parent entities the first time + if first_check: + entities.extend( + cls._entities_for_device( + hass, + device, + coordinator=coordinator, + entity_class=entity_class, + descriptions=descriptions, + ) + ) + has_parent_entities = bool(entities) + + children = _get_new_children( + device, coordinator, known_child_device_ids, entity_class.__name__ + ) + + if children: + _LOGGER.debug( + "Getting %s entities for %s child devices on device %s", + entity_class.__name__, + len(children), + device.host, + ) + for child in children: + child_coordinator = coordinator.get_child_coordinator(child) + + child_entities: list[_E] = cls._entities_for_device( + hass, + child, + coordinator=child_coordinator, + entity_class=entity_class, + descriptions=descriptions, + parent=device, + ) + _LOGGER.debug( + "Device %s, found %s child %s entities for child id %s", + device.host, + len(entities), + entity_class.__name__, + child.device_id, + ) + entities.extend(child_entities) + + if first_check and entities and not has_parent_entities: + # Get or create the parent device for via_device. + # This is a timing factor in case this platform is loaded before + # other platforms that will have entities on the parent. Eventually + # those other platforms will update the parent with full DeviceInfo + device_registry = dr.async_get(hass) + device_registry.async_get_or_create( + config_entry_id=coordinator.config_entry.entry_id, + identifiers={(DOMAIN, device.device_id)}, + ) + return entities + + +def _get_new_children( + device: Device, + coordinator: TPLinkDataUpdateCoordinator, + known_child_device_ids: set[str], + entity_class_name: str, +) -> list[Device]: + """Get a list of children to check for entity creation.""" + # Remove any device ids removed via the coordinator so they can be re-added + for removed_child_id in coordinator.removed_child_device_ids: + _LOGGER.debug( + "Removing %s from known %s child ids for device %s" + "as it has been removed by the coordinator", + removed_child_id, + entity_class_name, + device.host, + ) + known_child_device_ids.discard(removed_child_id) + + current_child_devices = {child.device_id: child for child in device.children} + current_child_device_ids = set(current_child_devices.keys()) + new_child_device_ids = current_child_device_ids - known_child_device_ids + children = [] + + if new_child_device_ids: + children = [ + child + for child_id, child in current_child_devices.items() + if child_id in new_child_device_ids + ] + known_child_device_ids.update(new_child_device_ids) + return children + return [] diff --git a/homeassistant/components/tplink/fan.py b/homeassistant/components/tplink/fan.py index 92cf049c11a..cb17955fbcb 100644 --- a/homeassistant/components/tplink/fan.py +++ b/homeassistant/components/tplink/fan.py @@ -1,13 +1,17 @@ """Support for TPLink Fan devices.""" +from dataclasses import dataclass import logging import math from typing import Any from kasa import Device, Module -from kasa.interfaces import Fan as FanInterface -from homeassistant.components.fan import FanEntity, FanEntityFeature +from homeassistant.components.fan import ( + FanEntity, + FanEntityDescription, + FanEntityFeature, +) from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.util.percentage import ( @@ -18,7 +22,11 @@ from homeassistant.util.scaling import int_states_in_range from . import TPLinkConfigEntry from .coordinator import TPLinkDataUpdateCoordinator -from .entity import CoordinatedTPLinkEntity, async_refresh_after +from .entity import ( + CoordinatedTPLinkModuleEntity, + TPLinkModuleEntityDescription, + async_refresh_after, +) # Coordinator is used to centralize the data updates # For actions the integration handles locking of concurrent device request @@ -27,6 +35,19 @@ PARALLEL_UPDATES = 0 _LOGGER = logging.getLogger(__name__) +@dataclass(frozen=True, kw_only=True) +class TPLinkFanEntityDescription(FanEntityDescription, TPLinkModuleEntityDescription): + """Base class for fan entity description.""" + + +FAN_DESCRIPTIONS: tuple[TPLinkFanEntityDescription, ...] = ( + TPLinkFanEntityDescription( + key="fan", + exists_fn=lambda dev, _: Module.Fan in dev.modules, + ), +) + + async def async_setup_entry( hass: HomeAssistant, config_entry: TPLinkConfigEntry, @@ -36,30 +57,31 @@ async def async_setup_entry( data = config_entry.runtime_data parent_coordinator = data.parent_coordinator device = parent_coordinator.device - entities: list[CoordinatedTPLinkEntity] = [] - if Module.Fan in device.modules: - entities.append( - TPLinkFanEntity( - device, parent_coordinator, fan_module=device.modules[Module.Fan] - ) + + known_child_device_ids: set[str] = set() + first_check = True + + def _check_device() -> None: + entities = CoordinatedTPLinkModuleEntity.entities_for_device_and_its_children( + hass=hass, + device=device, + coordinator=parent_coordinator, + entity_class=TPLinkFanEntity, + descriptions=FAN_DESCRIPTIONS, + known_child_device_ids=known_child_device_ids, + first_check=first_check, ) - entities.extend( - TPLinkFanEntity( - child, - parent_coordinator, - fan_module=child.modules[Module.Fan], - parent=device, - ) - for child in device.children - if Module.Fan in child.modules - ) - async_add_entities(entities) + async_add_entities(entities) + + _check_device() + first_check = False + config_entry.async_on_unload(parent_coordinator.async_add_listener(_check_device)) SPEED_RANGE = (1, 4) # off is not included -class TPLinkFanEntity(CoordinatedTPLinkEntity, FanEntity): +class TPLinkFanEntity(CoordinatedTPLinkModuleEntity, FanEntity): """Representation of a fan for a TPLink Fan device.""" _attr_speed_count = int_states_in_range(SPEED_RANGE) @@ -69,19 +91,19 @@ class TPLinkFanEntity(CoordinatedTPLinkEntity, FanEntity): | FanEntityFeature.TURN_ON ) + entity_description: TPLinkFanEntityDescription + def __init__( self, device: Device, coordinator: TPLinkDataUpdateCoordinator, - fan_module: FanInterface, + description: TPLinkFanEntityDescription, + *, parent: Device | None = None, ) -> None: """Initialize the fan.""" - self.fan_module = fan_module - # If _attr_name is None the entity name will be the device name - self._attr_name = None if parent is None else device.alias - - super().__init__(device, coordinator, parent=parent) + super().__init__(device, coordinator, description, parent=parent) + self.fan_module = device.modules[Module.Fan] @async_refresh_after async def async_turn_on( diff --git a/homeassistant/components/tplink/light.py b/homeassistant/components/tplink/light.py index 731ee919c98..bc4d792b3f8 100644 --- a/homeassistant/components/tplink/light.py +++ b/homeassistant/components/tplink/light.py @@ -3,11 +3,12 @@ from __future__ import annotations from collections.abc import Sequence +from dataclasses import dataclass import logging from typing import Any from kasa import Device, DeviceType, KasaException, LightState, Module -from kasa.interfaces import Light, LightEffect +from kasa.interfaces import LightEffect from kasa.iot import IotDevice import voluptuous as vol @@ -20,12 +21,12 @@ from homeassistant.components.light import ( EFFECT_OFF, ColorMode, LightEntity, + LightEntityDescription, LightEntityFeature, filter_supported_color_modes, ) from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import entity_platform import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import VolDictType @@ -33,7 +34,11 @@ from homeassistant.helpers.typing import VolDictType from . import TPLinkConfigEntry, legacy_device_id from .const import DOMAIN from .coordinator import TPLinkDataUpdateCoordinator -from .entity import CoordinatedTPLinkEntity, async_refresh_after +from .entity import ( + CoordinatedTPLinkModuleEntity, + TPLinkModuleEntityDescription, + async_refresh_after, +) # Coordinator is used to centralize the data updates # For actions the integration handles locking of concurrent device request @@ -136,75 +141,93 @@ def _async_build_base_effect( } +@dataclass(frozen=True, kw_only=True) +class TPLinkLightEntityDescription( + LightEntityDescription, TPLinkModuleEntityDescription +): + """Base class for tplink light entity description.""" + + +LIGHT_DESCRIPTIONS: tuple[TPLinkLightEntityDescription, ...] = ( + TPLinkLightEntityDescription( + key="light", + exists_fn=lambda dev, _: Module.Light in dev.modules + and Module.LightEffect not in dev.modules, + ), +) + +LIGHT_EFFECT_DESCRIPTIONS: tuple[TPLinkLightEntityDescription, ...] = ( + TPLinkLightEntityDescription( + key="light_effect", + exists_fn=lambda dev, _: Module.Light in dev.modules + and Module.LightEffect in dev.modules, + ), +) + + async def async_setup_entry( hass: HomeAssistant, config_entry: TPLinkConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: - """Set up switches.""" + """Set up lights.""" data = config_entry.runtime_data parent_coordinator = data.parent_coordinator device = parent_coordinator.device - entities: list[TPLinkLightEntity | TPLinkLightEffectEntity] = [] - if effect_module := device.modules.get(Module.LightEffect): - entities.append( - TPLinkLightEffectEntity( - device, - parent_coordinator, - light_module=device.modules[Module.Light], - effect_module=effect_module, + + known_child_device_ids_light: set[str] = set() + known_child_device_ids_light_effect: set[str] = set() + first_check = True + + def _check_device() -> None: + entities = CoordinatedTPLinkModuleEntity.entities_for_device_and_its_children( + hass=hass, + device=device, + coordinator=parent_coordinator, + entity_class=TPLinkLightEntity, + descriptions=LIGHT_DESCRIPTIONS, + known_child_device_ids=known_child_device_ids_light, + first_check=first_check, + ) + entities.extend( + CoordinatedTPLinkModuleEntity.entities_for_device_and_its_children( + hass=hass, + device=device, + coordinator=parent_coordinator, + entity_class=TPLinkLightEffectEntity, + descriptions=LIGHT_EFFECT_DESCRIPTIONS, + known_child_device_ids=known_child_device_ids_light_effect, + first_check=first_check, ) ) - if effect_module.has_custom_effects: - platform = entity_platform.async_get_current_platform() - platform.async_register_entity_service( - SERVICE_RANDOM_EFFECT, - RANDOM_EFFECT_DICT, - "async_set_random_effect", - ) - platform.async_register_entity_service( - SERVICE_SEQUENCE_EFFECT, - SEQUENCE_EFFECT_DICT, - "async_set_sequence_effect", - ) - elif Module.Light in device.modules: - entities.append( - TPLinkLightEntity( - device, parent_coordinator, light_module=device.modules[Module.Light] - ) - ) - entities.extend( - TPLinkLightEntity( - child, - parent_coordinator, - light_module=child.modules[Module.Light], - parent=device, - ) - for child in device.children - if Module.Light in child.modules - ) - async_add_entities(entities) + async_add_entities(entities) + + _check_device() + first_check = False + config_entry.async_on_unload(parent_coordinator.async_add_listener(_check_device)) -class TPLinkLightEntity(CoordinatedTPLinkEntity, LightEntity): +class TPLinkLightEntity(CoordinatedTPLinkModuleEntity, LightEntity): """Representation of a TPLink Smart Bulb.""" _attr_supported_features = LightEntityFeature.TRANSITION _fixed_color_mode: ColorMode | None = None + entity_description: TPLinkLightEntityDescription + def __init__( self, device: Device, coordinator: TPLinkDataUpdateCoordinator, + description: TPLinkLightEntityDescription, *, - light_module: Light, parent: Device | None = None, ) -> None: """Initialize the light.""" - self._parent = parent + super().__init__(device, coordinator, description, parent=parent) + + light_module = device.modules[Module.Light] self._light_module = light_module - # If _attr_name is None the entity name will be the device name - self._attr_name = None if parent is None else device.alias modes: set[ColorMode] = {ColorMode.ONOFF} if color_temp_feat := light_module.get_feature("color_temp"): modes.add(ColorMode.COLOR_TEMP) @@ -219,8 +242,6 @@ class TPLinkLightEntity(CoordinatedTPLinkEntity, LightEntity): # If the light supports only a single color mode, set it now self._fixed_color_mode = next(iter(self._attr_supported_color_modes)) - super().__init__(device, coordinator, parent=parent) - def _get_unique_id(self) -> str: """Return unique ID for the entity.""" # For historical reasons the light platform uses the mac address as @@ -367,13 +388,33 @@ class TPLinkLightEffectEntity(TPLinkLightEntity): self, device: Device, coordinator: TPLinkDataUpdateCoordinator, + description: TPLinkLightEntityDescription, *, - light_module: Light, - effect_module: LightEffect, + parent: Device | None = None, ) -> None: """Initialize the light strip.""" - self._effect_module = effect_module - super().__init__(device, coordinator, light_module=light_module) + super().__init__(device, coordinator, description, parent=parent) + + self._effect_module = device.modules[Module.LightEffect] + + async def async_added_to_hass(self) -> None: + """Call update attributes after the device is added to the platform.""" + await super().async_added_to_hass() + + self._register_effects_services() + + def _register_effects_services(self) -> None: + if self._effect_module.has_custom_effects: + self.platform.async_register_entity_service( + SERVICE_RANDOM_EFFECT, + RANDOM_EFFECT_DICT, + "async_set_random_effect", + ) + self.platform.async_register_entity_service( + SERVICE_SEQUENCE_EFFECT, + SEQUENCE_EFFECT_DICT, + "async_set_sequence_effect", + ) @callback def _async_update_attrs(self) -> bool: diff --git a/homeassistant/components/tplink/siren.py b/homeassistant/components/tplink/siren.py index bd1bfcead6d..0c15477ee78 100644 --- a/homeassistant/components/tplink/siren.py +++ b/homeassistant/components/tplink/siren.py @@ -2,24 +2,48 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Any from kasa import Device, Module from kasa.smart.modules.alarm import Alarm -from homeassistant.components.siren import SirenEntity, SirenEntityFeature +from homeassistant.components.siren import ( + SirenEntity, + SirenEntityDescription, + SirenEntityFeature, +) from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_platform import AddEntitiesCallback from . import TPLinkConfigEntry from .coordinator import TPLinkDataUpdateCoordinator -from .entity import CoordinatedTPLinkEntity, async_refresh_after +from .entity import ( + CoordinatedTPLinkModuleEntity, + TPLinkModuleEntityDescription, + async_refresh_after, +) # Coordinator is used to centralize the data updates # For actions the integration handles locking of concurrent device request PARALLEL_UPDATES = 0 +@dataclass(frozen=True, kw_only=True) +class TPLinkSirenEntityDescription( + SirenEntityDescription, TPLinkModuleEntityDescription +): + """Base class for siren entity description.""" + + +SIREN_DESCRIPTIONS: tuple[TPLinkSirenEntityDescription, ...] = ( + TPLinkSirenEntityDescription( + key="siren", + exists_fn=lambda dev, _: Module.Alarm in dev.modules, + ), +) + + async def async_setup_entry( hass: HomeAssistant, config_entry: TPLinkConfigEntry, @@ -30,24 +54,45 @@ async def async_setup_entry( parent_coordinator = data.parent_coordinator device = parent_coordinator.device - if Module.Alarm in device.modules: - async_add_entities([TPLinkSirenEntity(device, parent_coordinator)]) + known_child_device_ids: set[str] = set() + first_check = True + + def _check_device() -> None: + entities = CoordinatedTPLinkModuleEntity.entities_for_device_and_its_children( + hass=hass, + device=device, + coordinator=parent_coordinator, + entity_class=TPLinkSirenEntity, + descriptions=SIREN_DESCRIPTIONS, + known_child_device_ids=known_child_device_ids, + first_check=first_check, + ) + async_add_entities(entities) + + _check_device() + first_check = False + config_entry.async_on_unload(parent_coordinator.async_add_listener(_check_device)) -class TPLinkSirenEntity(CoordinatedTPLinkEntity, SirenEntity): +class TPLinkSirenEntity(CoordinatedTPLinkModuleEntity, SirenEntity): """Representation of a tplink siren entity.""" _attr_name = None _attr_supported_features = SirenEntityFeature.TURN_OFF | SirenEntityFeature.TURN_ON + entity_description: TPLinkSirenEntityDescription + def __init__( self, device: Device, coordinator: TPLinkDataUpdateCoordinator, + description: TPLinkSirenEntityDescription, + *, + parent: Device | None = None, ) -> None: """Initialize the siren entity.""" + super().__init__(device, coordinator, description, parent=parent) self._alarm_module: Alarm = device.modules[Module.Alarm] - super().__init__(device, coordinator) @async_refresh_after async def async_turn_on(self, **kwargs: Any) -> None: diff --git a/tests/components/tplink/snapshots/test_climate.ambr b/tests/components/tplink/snapshots/test_climate.ambr index 6823c373b68..e0173e8f59e 100644 --- a/tests/components/tplink/snapshots/test_climate.ambr +++ b/tests/components/tplink/snapshots/test_climate.ambr @@ -91,6 +91,6 @@ 'serial_number': None, 'suggested_area': None, 'sw_version': '1.0.0', - 'via_device_id': None, + 'via_device_id': , }) # --- diff --git a/tests/components/tplink/test_camera.py b/tests/components/tplink/test_camera.py index ceb74e3a61a..4b062c4d0b2 100644 --- a/tests/components/tplink/test_camera.py +++ b/tests/components/tplink/test_camera.py @@ -123,7 +123,7 @@ async def test_handle_mjpeg_stream_not_supported( hass: HomeAssistant, mock_camera_config_entry: MockConfigEntry, ) -> None: - """Test handle_async_mjpeg_stream.""" + """Test no stream if stream_rtsp_url is None after creation.""" mock_device = _mocked_device( modules=[Module.Camera], alias="my_camera", @@ -132,17 +132,17 @@ async def test_handle_mjpeg_stream_not_supported( ) mock_camera = mock_device.modules[Module.Camera] - mock_camera.stream_rtsp_url.return_value = None + mock_camera.stream_rtsp_url.side_effect = ("foo", None) await setup_platform_for_device( hass, mock_camera_config_entry, Platform.CAMERA, mock_device ) mock_request = make_mocked_request("GET", "/", headers={"token": "x"}) - stream = await async_get_mjpeg_stream( + mjpeg_stream = await async_get_mjpeg_stream( hass, mock_request, "camera.my_camera_live_view" ) - assert stream is None + assert mjpeg_stream is None async def test_camera_image( diff --git a/tests/components/tplink/test_init.py b/tests/components/tplink/test_init.py index 1fbd79c16c2..ef0ae3b6827 100644 --- a/tests/components/tplink/test_init.py +++ b/tests/components/tplink/test_init.py @@ -20,7 +20,6 @@ from kasa import ( from kasa.iot import IotStrip import pytest -from homeassistant import setup from homeassistant.components import tplink from homeassistant.components.tplink.const import ( CONF_AES_KEYS, @@ -68,7 +67,9 @@ from .const import ( DEVICE_ID, DEVICE_ID_MAC, IP_ADDRESS, + IP_ADDRESS3, MAC_ADDRESS, + MAC_ADDRESS3, MODEL, ) @@ -162,7 +163,7 @@ async def test_dimmer_switch_unique_id_fix_original_entity_still_exists( _patch_single_discovery(device=dimmer), _patch_connect(device=dimmer), ): - await setup.async_setup_component(hass, DOMAIN, {}) + await hass.config_entries.async_setup(config_entry.entry_id) await hass.async_block_till_done(wait_background_tasks=True) migrated_dimmer_entity_reg = entity_registry.async_get_or_create( @@ -374,7 +375,7 @@ async def test_update_attrs_fails_in_init( assert entity state = hass.states.get(entity_id) assert state.state == STATE_UNAVAILABLE - assert "Unable to read data for MockLight None:" in caplog.text + assert f"Unable to read data for MockLight {entity_id}:" in caplog.text async def test_update_attrs_fails_on_update( @@ -839,7 +840,7 @@ async def test_migrate_remove_device_config( @pytest.mark.parametrize( - ("device_type"), + ("parent_device_type"), [ (Device), (IotStrip), @@ -859,7 +860,7 @@ async def test_migrate_remove_device_config( ], ) @pytest.mark.usefixtures("entity_registry_enabled_by_default") -async def test_automatic_device_addition_and_removal( +async def test_automatic_feature_device_addition_and_removal( hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_connect: AsyncMock, @@ -870,9 +871,9 @@ async def test_automatic_device_addition_and_removal( platform: str, feature_id: str, translated_name: str, - device_type: type, + parent_device_type: type, ) -> None: - """Test for automatic device addition and removal.""" + """Test for automatic device with features addition and removal.""" children = { f"child{index}": _mocked_device( @@ -889,7 +890,7 @@ async def test_automatic_device_addition_and_removal( children=[children["child1"], children["child2"]], features=[feature_id], device_type=DeviceType.Hub, - spec=device_type, + spec=parent_device_type, device_id="hub_parent", ) @@ -985,3 +986,167 @@ async def test_automatic_device_addition_and_removal( ) assert device_entry assert device_entry.via_device_id == parent_device.id + + +@pytest.mark.parametrize( + ("platform", "modules", "features", "translated_name", "child_device_type"), + [ + pytest.param( + "camera", [Module.Camera], [], "live_view", DeviceType.Camera, id="camera" + ), + pytest.param("fan", [Module.Fan], [], None, DeviceType.Fan, id="fan"), + pytest.param("siren", [Module.Alarm], [], None, DeviceType.Camera, id="siren"), + pytest.param("light", [Module.Light], [], None, DeviceType.Camera, id="light"), + pytest.param( + "light", + [Module.Light, Module.LightEffect], + [], + None, + DeviceType.Camera, + id="light_effect", + ), + pytest.param( + "climate", + [], + ["state", "thermostat_mode", "temperature", "target_temperature"], + None, + DeviceType.Thermostat, + id="climate", + ), + ], +) +@pytest.mark.usefixtures("entity_registry_enabled_by_default") +async def test_automatic_module_device_addition_and_removal( + hass: HomeAssistant, + mock_camera_config_entry: MockConfigEntry, + mock_connect: AsyncMock, + mock_discovery: AsyncMock, + entity_registry: er.EntityRegistry, + device_registry: dr.DeviceRegistry, + freezer: FrozenDateTimeFactory, + platform: str, + modules: list[str], + features: list[str], + translated_name: str | None, + child_device_type: DeviceType, +) -> None: + """Test for automatic device with modules addition and removal.""" + + children = { + f"child{index}": _mocked_device( + alias=f"child {index}", + modules=modules, + features=features, + device_type=child_device_type, + device_id=f"child{index}", + ) + for index in range(1, 5) + } + + mock_device = _mocked_device( + alias="hub", + children=[children["child1"], children["child2"]], + features=["ssid"], + device_type=DeviceType.Hub, + device_id="hub_parent", + ip_address=IP_ADDRESS3, + mac=MAC_ADDRESS3, + ) + + with override_side_effect(mock_connect["connect"], lambda *_, **__: mock_device): + mock_camera_config_entry.add_to_hass(hass) + await hass.config_entries.async_setup(mock_camera_config_entry.entry_id) + await hass.async_block_till_done() + + for child_id in (1, 2): + sub_id = f"_{translated_name}" if translated_name else "" + entity_id = f"{platform}.child_{child_id}{sub_id}" + state = hass.states.get(entity_id) + assert state + assert entity_registry.async_get(entity_id) + + parent_device = device_registry.async_get_device( + identifiers={(DOMAIN, "hub_parent")} + ) + assert parent_device + + for device_id in ("child1", "child2"): + device_entry = device_registry.async_get_device( + identifiers={(DOMAIN, device_id)} + ) + assert device_entry + assert device_entry.via_device_id == parent_device.id + + # Remove one of the devices + mock_device.children = [children["child1"]] + freezer.tick(5) + async_fire_time_changed(hass) + + sub_id = f"_{translated_name}" if translated_name else "" + entity_id = f"{platform}.child_2{sub_id}" + state = hass.states.get(entity_id) + assert state is None + assert entity_registry.async_get(entity_id) is None + + assert device_registry.async_get_device(identifiers={(DOMAIN, "child2")}) is None + + # Re-dd the previously removed child device + mock_device.children = [ + children["child1"], + children["child2"], + ] + freezer.tick(5) + async_fire_time_changed(hass) + + for child_id in (1, 2): + sub_id = f"_{translated_name}" if translated_name else "" + entity_id = f"{platform}.child_{child_id}{sub_id}" + state = hass.states.get(entity_id) + assert state + assert entity_registry.async_get(entity_id) + + for device_id in ("child1", "child2"): + device_entry = device_registry.async_get_device( + identifiers={(DOMAIN, device_id)} + ) + assert device_entry + assert device_entry.via_device_id == parent_device.id + + # Add child devices + mock_device.children = [children["child1"], children["child3"], children["child4"]] + freezer.tick(5) + async_fire_time_changed(hass) + + for child_id in (1, 3, 4): + sub_id = f"_{translated_name}" if translated_name else "" + entity_id = f"{platform}.child_{child_id}{sub_id}" + state = hass.states.get(entity_id) + assert state + assert entity_registry.async_get(entity_id) + + for device_id in ("child1", "child3", "child4"): + assert device_registry.async_get_device(identifiers={(DOMAIN, device_id)}) + + # Add the previously removed child device + mock_device.children = [ + children["child1"], + children["child2"], + children["child3"], + children["child4"], + ] + freezer.tick(5) + async_fire_time_changed(hass) + + for child_id in (1, 2, 3, 4): + sub_id = f"_{translated_name}" if translated_name else "" + entity_id = f"{platform}.child_{child_id}{sub_id}" + state = hass.states.get(entity_id) + assert state + assert entity_registry.async_get(entity_id) + + for device_id in ("child1", "child2", "child3", "child4"): + device_entry = device_registry.async_get_device( + identifiers={(DOMAIN, device_id)} + ) + assert device_entry + assert device_entry.via_device_id == parent_device.id