Add dynamic child device handling to tplink integration (#135229)

Add dynamic child device handling to tplink integration. For child devices that could be added/removed to hubs.
This commit is contained in:
Steven B. 2025-01-15 19:45:06 +00:00 committed by GitHub
parent c6cab3259c
commit 51e3bf42f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 403 additions and 112 deletions

View File

@ -18,7 +18,6 @@ from kasa import (
KasaException,
)
from kasa.httpclient import get_cookie_jar
from kasa.iot import IotStrip
from homeassistant import config_entries
from homeassistant.components import network
@ -235,17 +234,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: TPLinkConfigEntry) -> bo
parent_coordinator = TPLinkDataUpdateCoordinator(
hass, device, timedelta(seconds=5), entry
)
child_coordinators: list[TPLinkDataUpdateCoordinator] = []
# The iot HS300 allows a limited number of concurrent requests and fetching the
# emeter information requires separate ones so create child coordinators here.
if isinstance(device, IotStrip):
child_coordinators = [
# The child coordinators only update energy data so we can
# set a longer update interval to avoid flooding the device
TPLinkDataUpdateCoordinator(hass, child, timedelta(seconds=60), entry)
for child in device.children
]
camera_creds: Credentials | None = None
if camera_creds_dict := entry.data.get(CONF_CAMERA_CREDENTIALS):
@ -254,9 +242,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: TPLinkConfigEntry) -> bo
)
live_view = entry.data.get(CONF_LIVE_VIEW)
entry.runtime_data = TPLinkData(
parent_coordinator, child_coordinators, camera_creds, live_view
)
entry.runtime_data = TPLinkData(parent_coordinator, camera_creds, live_view)
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
return True

View File

@ -8,6 +8,7 @@ from typing import Final, cast
from kasa import Feature
from homeassistant.components.binary_sensor import (
DOMAIN as BINARY_SENSOR_DOMAIN,
BinarySensorDeviceClass,
BinarySensorEntity,
BinarySensorEntityDescription,
@ -16,6 +17,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from . import TPLinkConfigEntry
from .deprecate import async_cleanup_deprecated
from .entity import CoordinatedTPLinkFeatureEntity, TPLinkFeatureEntityDescription
@ -73,9 +75,12 @@ async def async_setup_entry(
"""Set up sensors."""
data = config_entry.runtime_data
parent_coordinator = data.parent_coordinator
children_coordinators = data.children_coordinators
device = parent_coordinator.device
known_child_device_ids: set[str] = set()
first_check = True
def _check_device() -> None:
entities = CoordinatedTPLinkFeatureEntity.entities_for_device_and_its_children(
hass=hass,
device=device,
@ -83,10 +88,18 @@ async def async_setup_entry(
feature_type=Feature.Type.BinarySensor,
entity_class=TPLinkBinarySensorEntity,
descriptions=BINARYSENSOR_DESCRIPTIONS_MAP,
child_coordinators=children_coordinators,
known_child_device_ids=known_child_device_ids,
first_check=first_check,
)
async_cleanup_deprecated(
hass, BINARY_SENSOR_DOMAIN, config_entry.entry_id, entities
)
async_add_entities(entities)
_check_device()
first_check = False
config_entry.async_on_unload(parent_coordinator.async_add_listener(_check_device))
class TPLinkBinarySensorEntity(CoordinatedTPLinkFeatureEntity, BinarySensorEntity):
"""Representation of a TPLink binary sensor."""

View File

@ -83,9 +83,11 @@ async def async_setup_entry(
"""Set up buttons."""
data = config_entry.runtime_data
parent_coordinator = data.parent_coordinator
children_coordinators = data.children_coordinators
device = parent_coordinator.device
known_child_device_ids: set[str] = set()
first_check = True
def _check_device() -> None:
entities = CoordinatedTPLinkFeatureEntity.entities_for_device_and_its_children(
hass=hass,
device=device,
@ -93,11 +95,16 @@ async def async_setup_entry(
feature_type=Feature.Type.Action,
entity_class=TPLinkButtonEntity,
descriptions=BUTTON_DESCRIPTIONS_MAP,
child_coordinators=children_coordinators,
known_child_device_ids=known_child_device_ids,
first_check=first_check,
)
async_cleanup_deprecated(hass, BUTTON_DOMAIN, config_entry.entry_id, entities)
async_add_entities(entities)
_check_device()
first_check = False
config_entry.async_on_unload(parent_coordinator.async_add_listener(_check_device))
class TPLinkButtonEntity(CoordinatedTPLinkFeatureEntity, ButtonEntity):
"""Representation of a TPLink button entity."""

View File

@ -7,10 +7,12 @@ from datetime import timedelta
import logging
from kasa import AuthenticationError, Credentials, Device, KasaException
from kasa.iot import IotStrip
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.debounce import Debouncer
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
@ -24,7 +26,6 @@ class TPLinkData:
"""Data for the tplink integration."""
parent_coordinator: TPLinkDataUpdateCoordinator
children_coordinators: list[TPLinkDataUpdateCoordinator]
camera_credentials: Credentials | None
live_view: bool | None
@ -60,6 +61,9 @@ class TPLinkDataUpdateCoordinator(DataUpdateCoordinator[None]):
hass, _LOGGER, cooldown=REQUEST_REFRESH_DELAY, immediate=False
),
)
self._previous_child_device_ids = {child.device_id for child in device.children}
self.removed_child_device_ids: set[str] = set()
self._child_coordinators: dict[str, TPLinkDataUpdateCoordinator] = {}
async def _async_update_data(self) -> None:
"""Fetch all device and sensor data from api."""
@ -83,3 +87,48 @@ class TPLinkDataUpdateCoordinator(DataUpdateCoordinator[None]):
"exc": str(ex),
},
) from ex
await self._process_child_devices()
async def _process_child_devices(self) -> None:
"""Process child devices and remove stale devices."""
current_child_device_ids = {child.device_id for child in self.device.children}
if (
stale_device_ids := self._previous_child_device_ids
- current_child_device_ids
):
device_registry = dr.async_get(self.hass)
for device_id in stale_device_ids:
device = device_registry.async_get_device(
identifiers={(DOMAIN, device_id)}
)
if device:
device_registry.async_update_device(
device_id=device.id,
remove_config_entry_id=self.config_entry.entry_id,
)
child_coordinator = self._child_coordinators.pop(device_id, None)
if child_coordinator:
await child_coordinator.async_shutdown()
self._previous_child_device_ids = current_child_device_ids
self.removed_child_device_ids = stale_device_ids
def get_child_coordinator(
self,
child: Device,
) -> TPLinkDataUpdateCoordinator:
"""Get separate child coordinator for a device or self if not needed."""
# The iot HS300 allows a limited number of concurrent requests and fetching the
# emeter information requires separate ones so create child coordinators here.
if isinstance(self.device, IotStrip):
if not (child_coordinator := self._child_coordinators.get(child.device_id)):
# The child coordinators only update energy data so we can
# set a longer update interval to avoid flooding the device
child_coordinator = TPLinkDataUpdateCoordinator(
self.hass, child, timedelta(seconds=60), self.config_entry
)
self._child_coordinators[child.device_id] = child_coordinator
return child_coordinator
return self

View File

@ -434,7 +434,8 @@ class CoordinatedTPLinkFeatureEntity(CoordinatedTPLinkEntity, ABC):
feature_type: Feature.Type,
entity_class: type[_E],
descriptions: Mapping[str, _D],
child_coordinators: list[TPLinkDataUpdateCoordinator] | None = None,
known_child_device_ids: set[str],
first_check: bool,
) -> list[_E]:
"""Create entities for device and its children.
@ -442,6 +443,8 @@ class CoordinatedTPLinkFeatureEntity(CoordinatedTPLinkEntity, ABC):
"""
entities: list[_E] = []
# Add parent entities before children so via_device id works.
# Only add the parent entities the first time
if first_check:
entities.extend(
cls._entities_for_device(
hass,
@ -452,18 +455,42 @@ class CoordinatedTPLinkFeatureEntity(CoordinatedTPLinkEntity, ABC):
descriptions=descriptions,
)
)
if device.children:
_LOGGER.debug("Initializing device with %s children", len(device.children))
for idx, child in enumerate(device.children):
# HS300 does not like too many concurrent requests and its
# emeter data requires a request for each socket, so we receive
# separate coordinators.
if child_coordinators:
child_coordinator = child_coordinators[idx]
else:
child_coordinator = coordinator
entities.extend(
cls._entities_for_device(
# Remove any device ids removed via the coordinator so they can be re-added
for removed_child_id in coordinator.removed_child_device_ids:
_LOGGER.debug(
"Removing %s from known %s child ids for device %s"
"as it has been removed by the coordinator",
removed_child_id,
entity_class.__name__,
device.host,
)
known_child_device_ids.discard(removed_child_id)
current_child_devices = {child.device_id: child for child in device.children}
current_child_device_ids = set(current_child_devices.keys())
new_child_device_ids = current_child_device_ids - known_child_device_ids
children = []
if new_child_device_ids:
children = [
child
for child_id, child in current_child_devices.items()
if child_id in new_child_device_ids
]
known_child_device_ids.update(new_child_device_ids)
if children:
_LOGGER.debug(
"Getting %s entities for %s child devices on device %s",
entity_class.__name__,
len(children),
device.host,
)
for child in children:
child_coordinator = coordinator.get_child_coordinator(child)
child_entities = cls._entities_for_device(
hass,
child,
coordinator=child_coordinator,
@ -472,6 +499,13 @@ class CoordinatedTPLinkFeatureEntity(CoordinatedTPLinkEntity, ABC):
descriptions=descriptions,
parent=device,
)
_LOGGER.debug(
"Device %s, found %s child %s entities for child id %s",
device.host,
len(entities),
entity_class.__name__,
child.device_id,
)
entities.extend(child_entities)
return entities

View File

@ -9,6 +9,7 @@ from typing import Final, cast
from kasa import Device, Feature
from homeassistant.components.number import (
DOMAIN as NUMBER_DOMAIN,
NumberEntity,
NumberEntityDescription,
NumberMode,
@ -17,6 +18,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from . import TPLinkConfigEntry
from .deprecate import async_cleanup_deprecated
from .entity import (
CoordinatedTPLinkFeatureEntity,
TPLinkDataUpdateCoordinator,
@ -77,8 +79,11 @@ async def async_setup_entry(
"""Set up number entities."""
data = config_entry.runtime_data
parent_coordinator = data.parent_coordinator
children_coordinators = data.children_coordinators
device = parent_coordinator.device
known_child_device_ids: set[str] = set()
first_check = True
def _check_device() -> None:
entities = CoordinatedTPLinkFeatureEntity.entities_for_device_and_its_children(
hass=hass,
device=device,
@ -86,11 +91,16 @@ async def async_setup_entry(
feature_type=Feature.Type.Number,
entity_class=TPLinkNumberEntity,
descriptions=NUMBER_DESCRIPTIONS_MAP,
child_coordinators=children_coordinators,
known_child_device_ids=known_child_device_ids,
first_check=first_check,
)
async_cleanup_deprecated(hass, NUMBER_DOMAIN, config_entry.entry_id, entities)
async_add_entities(entities)
_check_device()
first_check = False
config_entry.async_on_unload(parent_coordinator.async_add_listener(_check_device))
class TPLinkNumberEntity(CoordinatedTPLinkFeatureEntity, NumberEntity):
"""Representation of a feature-based TPLink number entity."""

View File

@ -7,11 +7,16 @@ from typing import Final, cast
from kasa import Device, Feature
from homeassistant.components.select import SelectEntity, SelectEntityDescription
from homeassistant.components.select import (
DOMAIN as SELECT_DOMAIN,
SelectEntity,
SelectEntityDescription,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from . import TPLinkConfigEntry
from .deprecate import async_cleanup_deprecated
from .entity import (
CoordinatedTPLinkFeatureEntity,
TPLinkDataUpdateCoordinator,
@ -54,9 +59,11 @@ async def async_setup_entry(
"""Set up select entities."""
data = config_entry.runtime_data
parent_coordinator = data.parent_coordinator
children_coordinators = data.children_coordinators
device = parent_coordinator.device
known_child_device_ids: set[str] = set()
first_check = True
def _check_device() -> None:
entities = CoordinatedTPLinkFeatureEntity.entities_for_device_and_its_children(
hass=hass,
device=device,
@ -64,10 +71,16 @@ async def async_setup_entry(
feature_type=Feature.Type.Choice,
entity_class=TPLinkSelectEntity,
descriptions=SELECT_DESCRIPTIONS_MAP,
child_coordinators=children_coordinators,
known_child_device_ids=known_child_device_ids,
first_check=first_check,
)
async_cleanup_deprecated(hass, SELECT_DOMAIN, config_entry.entry_id, entities)
async_add_entities(entities)
_check_device()
first_check = False
config_entry.async_on_unload(parent_coordinator.async_add_listener(_check_device))
class TPLinkSelectEntity(CoordinatedTPLinkFeatureEntity, SelectEntity):
"""Representation of a tplink select entity."""

View File

@ -129,9 +129,11 @@ async def async_setup_entry(
"""Set up sensors."""
data = config_entry.runtime_data
parent_coordinator = data.parent_coordinator
children_coordinators = data.children_coordinators
device = parent_coordinator.device
known_child_device_ids: set[str] = set()
first_check = True
def _check_device() -> None:
entities = CoordinatedTPLinkFeatureEntity.entities_for_device_and_its_children(
hass=hass,
device=device,
@ -139,11 +141,16 @@ async def async_setup_entry(
feature_type=Feature.Type.Sensor,
entity_class=TPLinkSensorEntity,
descriptions=SENSOR_DESCRIPTIONS_MAP,
child_coordinators=children_coordinators,
known_child_device_ids=known_child_device_ids,
first_check=first_check,
)
async_cleanup_deprecated(hass, SENSOR_DOMAIN, config_entry.entry_id, entities)
async_add_entities(entities)
_check_device()
first_check = False
config_entry.async_on_unload(parent_coordinator.async_add_listener(_check_device))
class TPLinkSensorEntity(CoordinatedTPLinkFeatureEntity, SensorEntity):
"""Representation of a feature-based TPLink sensor."""

View File

@ -8,11 +8,16 @@ from typing import Any, cast
from kasa import Feature
from homeassistant.components.switch import SwitchEntity, SwitchEntityDescription
from homeassistant.components.switch import (
DOMAIN as SWITCH_DOMAIN,
SwitchEntity,
SwitchEntityDescription,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from . import TPLinkConfigEntry
from .deprecate import async_cleanup_deprecated
from .entity import (
CoordinatedTPLinkFeatureEntity,
TPLinkFeatureEntityDescription,
@ -84,7 +89,10 @@ async def async_setup_entry(
data = config_entry.runtime_data
parent_coordinator = data.parent_coordinator
device = parent_coordinator.device
known_child_device_ids: set[str] = set()
first_check = True
def _check_device() -> None:
entities = CoordinatedTPLinkFeatureEntity.entities_for_device_and_its_children(
hass=hass,
device=device,
@ -92,10 +100,16 @@ async def async_setup_entry(
feature_type=Feature.Switch,
entity_class=TPLinkSwitch,
descriptions=SWITCH_DESCRIPTIONS_MAP,
known_child_device_ids=known_child_device_ids,
first_check=first_check,
)
async_cleanup_deprecated(hass, SWITCH_DOMAIN, config_entry.entry_id, entities)
async_add_entities(entities)
_check_device()
first_check = False
config_entry.async_on_unload(parent_coordinator.async_add_listener(_check_device))
class TPLinkSwitch(CoordinatedTPLinkFeatureEntity, SwitchEntity):
"""Representation of a feature-based TPLink switch."""

View File

@ -8,7 +8,16 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
from freezegun.api import FrozenDateTimeFactory
from kasa import AuthenticationError, DeviceConfig, Feature, KasaException, Module
from kasa import (
AuthenticationError,
Device,
DeviceConfig,
DeviceType,
Feature,
KasaException,
Module,
)
from kasa.iot import IotStrip
import pytest
from homeassistant import setup
@ -827,3 +836,152 @@ async def test_migrate_remove_device_config(
assert entry.data == expected_entry_data
assert "Migration to version 1.5 complete" in caplog.text
@pytest.mark.parametrize(
("device_type"),
[
(Device),
(IotStrip),
],
)
@pytest.mark.parametrize(
("platform", "feature_id", "translated_name"),
[
pytest.param("switch", "led", "led", id="switch"),
pytest.param(
"sensor", "current_consumption", "current_consumption", id="sensor"
),
pytest.param("binary_sensor", "overheated", "overheated", id="binary_sensor"),
pytest.param("number", "smooth_transition_on", "smooth_on", id="number"),
pytest.param("select", "light_preset", "light_preset", id="select"),
pytest.param("button", "reboot", "restart", id="button"),
],
)
@pytest.mark.usefixtures("entity_registry_enabled_by_default")
async def test_automatic_device_addition_and_removal(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_connect: AsyncMock,
mock_discovery: AsyncMock,
entity_registry: er.EntityRegistry,
device_registry: dr.DeviceRegistry,
freezer: FrozenDateTimeFactory,
platform: str,
feature_id: str,
translated_name: str,
device_type: type,
) -> None:
"""Test for automatic device addition and removal."""
children = {
f"child{index}": _mocked_device(
alias=f"child {index}",
features=[feature_id],
device_type=DeviceType.StripSocket,
device_id=f"child{index}",
)
for index in range(1, 5)
}
mock_device = _mocked_device(
alias="hub",
children=[children["child1"], children["child2"]],
features=[feature_id],
device_type=DeviceType.Hub,
spec=device_type,
device_id="hub_parent",
)
with override_side_effect(mock_connect["connect"], lambda *_, **__: mock_device):
mock_config_entry.add_to_hass(hass)
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
for child_id in (1, 2):
entity_id = f"{platform}.child_{child_id}_{translated_name}"
state = hass.states.get(entity_id)
assert state
assert entity_registry.async_get(entity_id)
parent_device = device_registry.async_get_device(
identifiers={(DOMAIN, "hub_parent")}
)
assert parent_device
for device_id in ("child1", "child2"):
device_entry = device_registry.async_get_device(
identifiers={(DOMAIN, device_id)}
)
assert device_entry
assert device_entry.via_device_id == parent_device.id
# Remove one of the devices
mock_device.children = [children["child1"]]
freezer.tick(5)
async_fire_time_changed(hass)
entity_id = f"{platform}.child_2_{translated_name}"
state = hass.states.get(entity_id)
assert state is None
assert entity_registry.async_get(entity_id) is None
assert device_registry.async_get_device(identifiers={(DOMAIN, "child2")}) is None
# Re-dd the previously removed child device
mock_device.children = [
children["child1"],
children["child2"],
]
freezer.tick(5)
async_fire_time_changed(hass)
for child_id in (1, 2):
entity_id = f"{platform}.child_{child_id}_{translated_name}"
state = hass.states.get(entity_id)
assert state
assert entity_registry.async_get(entity_id)
for device_id in ("child1", "child2"):
device_entry = device_registry.async_get_device(
identifiers={(DOMAIN, device_id)}
)
assert device_entry
assert device_entry.via_device_id == parent_device.id
# Add child devices
mock_device.children = [children["child1"], children["child3"], children["child4"]]
freezer.tick(5)
async_fire_time_changed(hass)
for child_id in (1, 3, 4):
entity_id = f"{platform}.child_{child_id}_{translated_name}"
state = hass.states.get(entity_id)
assert state
assert entity_registry.async_get(entity_id)
for device_id in ("child1", "child3", "child4"):
assert device_registry.async_get_device(identifiers={(DOMAIN, device_id)})
# Add the previously removed child device
mock_device.children = [
children["child1"],
children["child2"],
children["child3"],
children["child4"],
]
freezer.tick(5)
async_fire_time_changed(hass)
for child_id in (1, 2, 3, 4):
entity_id = f"{platform}.child_{child_id}_{translated_name}"
state = hass.states.get(entity_id)
assert state
assert entity_registry.async_get(entity_id)
for device_id in ("child1", "child2", "child3", "child4"):
device_entry = device_registry.async_get_device(
identifiers={(DOMAIN, device_id)}
)
assert device_entry
assert device_entry.via_device_id == parent_device.id