From 2a36ec3e21a34b63d8b2e0b66a78ce5e13bf41fc Mon Sep 17 00:00:00 2001 From: MarkGodwin <10632972+MarkGodwin@users.noreply.github.com> Date: Sun, 22 Sep 2024 16:05:29 +0100 Subject: [PATCH] Automatically remove unregistered TP-Link Omada devices at start up (#124153) * Adding coordinator for omada device list * Remove dead omada devices at startup * Tidy up tests * Address PR feedback * Returned to use of read-only properties for coordinators. Tidied up parameters some more * Update homeassistant/components/tplink_omada/controller.py * Update homeassistant/components/tplink_omada/controller.py * Update homeassistant/components/tplink_omada/controller.py --------- Co-authored-by: Joost Lekkerkerker --- .../components/tplink_omada/__init__.py | 26 +++++++-- .../components/tplink_omada/binary_sensor.py | 2 +- .../components/tplink_omada/controller.py | 54 ++++++++++-------- .../components/tplink_omada/coordinator.py | 23 +++++++- .../components/tplink_omada/device_tracker.py | 3 +- .../components/tplink_omada/entity.py | 3 +- .../components/tplink_omada/switch.py | 2 +- .../components/tplink_omada/update.py | 55 +++++++++++++------ tests/components/tplink_omada/test_init.py | 47 ++++++++++++++++ 9 files changed, 164 insertions(+), 51 deletions(-) create mode 100644 tests/components/tplink_omada/test_init.py diff --git a/homeassistant/components/tplink_omada/__init__.py b/homeassistant/components/tplink_omada/__init__.py index 19b3d58dbd4..9945df2bbae 100644 --- a/homeassistant/components/tplink_omada/__init__.py +++ b/homeassistant/components/tplink_omada/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations from tplink_omada_client import OmadaSite +from tplink_omada_client.devices import OmadaListDevice from tplink_omada_client.exceptions import ( ConnectionFailed, LoginFailed, @@ -14,6 +15,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import Platform from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady +from homeassistant.helpers import device_registry as dr from .config_flow import CONF_SITE, create_omada_client from .const import DOMAIN @@ -52,13 +54,12 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: site_client = await client.get_site_client(OmadaSite("", entry.data[CONF_SITE])) controller = OmadaSiteController(hass, site_client) - gateway_coordinator = await controller.get_gateway_coordinator() - if gateway_coordinator: - await gateway_coordinator.async_config_entry_first_refresh() - await controller.get_clients_coordinator().async_config_entry_first_refresh() + await controller.initialize_first_refresh() hass.data[DOMAIN][entry.entry_id] = controller + _remove_old_devices(hass, entry, controller.devices_coordinator.data) + await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) return True @@ -70,3 +71,20 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: hass.data[DOMAIN].pop(entry.entry_id) return unload_ok + + +def _remove_old_devices( + hass: HomeAssistant, entry: ConfigEntry, omada_devices: dict[str, OmadaListDevice] +) -> None: + device_registry = dr.async_get(hass) + + for registered_device in device_registry.devices.get_devices_for_config_entry_id( + entry.entry_id + ): + mac = next( + (i[1] for i in registered_device.identifiers if i[0] == DOMAIN), None + ) + if mac and mac not in omada_devices: + device_registry.async_update_device( + registered_device.id, remove_config_entry_id=entry.entry_id + ) diff --git a/homeassistant/components/tplink_omada/binary_sensor.py b/homeassistant/components/tplink_omada/binary_sensor.py index c0304c4d1b2..c3941ff7595 100644 --- a/homeassistant/components/tplink_omada/binary_sensor.py +++ b/homeassistant/components/tplink_omada/binary_sensor.py @@ -34,7 +34,7 @@ async def async_setup_entry( """Set up binary sensors.""" controller: OmadaSiteController = hass.data[DOMAIN][config_entry.entry_id] - gateway_coordinator = await controller.get_gateway_coordinator() + gateway_coordinator = controller.gateway_coordinator if not gateway_coordinator: return diff --git a/homeassistant/components/tplink_omada/controller.py b/homeassistant/components/tplink_omada/controller.py index d92a6f37e24..658286981f9 100644 --- a/homeassistant/components/tplink_omada/controller.py +++ b/homeassistant/components/tplink_omada/controller.py @@ -7,6 +7,7 @@ from homeassistant.core import HomeAssistant from .coordinator import ( OmadaClientsCoordinator, + OmadaDevicesCoordinator, OmadaGatewayCoordinator, OmadaSwitchPortCoordinator, ) @@ -16,15 +17,33 @@ class OmadaSiteController: """Controller for the Omada SDN site.""" _gateway_coordinator: OmadaGatewayCoordinator | None = None - _initialized_gateway_coordinator = False - _clients_coordinator: OmadaClientsCoordinator | None = None - def __init__(self, hass: HomeAssistant, omada_client: OmadaSiteClient) -> None: + def __init__( + self, + hass: HomeAssistant, + omada_client: OmadaSiteClient, + ) -> None: """Create the controller.""" self._hass = hass self._omada_client = omada_client self._switch_port_coordinators: dict[str, OmadaSwitchPortCoordinator] = {} + self._devices_coordinator = OmadaDevicesCoordinator(hass, omada_client) + self._clients_coordinator = OmadaClientsCoordinator(hass, omada_client) + + async def initialize_first_refresh(self) -> None: + """Initialize the all coordinators, and perform first refresh.""" + await self._devices_coordinator.async_config_entry_first_refresh() + + devices = self._devices_coordinator.data.values() + gateway = next((d for d in devices if d.type == "gateway"), None) + if gateway: + self._gateway_coordinator = OmadaGatewayCoordinator( + self._hass, self._omada_client, gateway.mac + ) + await self._gateway_coordinator.async_config_entry_first_refresh() + + await self.clients_coordinator.async_config_entry_first_refresh() @property def omada_client(self) -> OmadaSiteClient: @@ -42,26 +61,17 @@ class OmadaSiteController: return self._switch_port_coordinators[switch.mac] - async def get_gateway_coordinator(self) -> OmadaGatewayCoordinator | None: - """Get coordinator for site's gateway, or None if there is no gateway.""" - if not self._initialized_gateway_coordinator: - self._initialized_gateway_coordinator = True - devices = await self._omada_client.get_devices() - gateway = next((d for d in devices if d.type == "gateway"), None) - if not gateway: - return None - - self._gateway_coordinator = OmadaGatewayCoordinator( - self._hass, self._omada_client, gateway.mac - ) - + @property + def gateway_coordinator(self) -> OmadaGatewayCoordinator | None: + """Gets the coordinator for site's gateway, or None if there is no gateway.""" return self._gateway_coordinator - def get_clients_coordinator(self) -> OmadaClientsCoordinator: - """Get coordinator for site's clients.""" - if not self._clients_coordinator: - self._clients_coordinator = OmadaClientsCoordinator( - self._hass, self._omada_client - ) + @property + def devices_coordinator(self) -> OmadaDevicesCoordinator: + """Gets the coordinator for site's devices.""" + return self._devices_coordinator + @property + def clients_coordinator(self) -> OmadaClientsCoordinator: + """Gets the coordinator for site's clients.""" return self._clients_coordinator diff --git a/homeassistant/components/tplink_omada/coordinator.py b/homeassistant/components/tplink_omada/coordinator.py index da0a79ef991..e4f15e6567c 100644 --- a/homeassistant/components/tplink_omada/coordinator.py +++ b/homeassistant/components/tplink_omada/coordinator.py @@ -6,7 +6,7 @@ import logging from tplink_omada_client import OmadaSiteClient, OmadaSwitchPortDetails from tplink_omada_client.clients import OmadaWirelessClient -from tplink_omada_client.devices import OmadaGateway, OmadaSwitch +from tplink_omada_client.devices import OmadaGateway, OmadaListDevice, OmadaSwitch from tplink_omada_client.exceptions import OmadaClientException from homeassistant.core import HomeAssistant @@ -17,6 +17,7 @@ _LOGGER = logging.getLogger(__name__) POLL_SWITCH_PORT = 300 POLL_GATEWAY = 300 POLL_CLIENTS = 300 +POLL_DEVICES = 900 class OmadaCoordinator[_T](DataUpdateCoordinator[dict[str, _T]]): @@ -27,14 +28,14 @@ class OmadaCoordinator[_T](DataUpdateCoordinator[dict[str, _T]]): hass: HomeAssistant, omada_client: OmadaSiteClient, name: str, - poll_delay: int = 300, + poll_delay: int | None = 300, ) -> None: """Initialize my coordinator.""" super().__init__( hass, _LOGGER, name=f"Omada API Data - {name}", - update_interval=timedelta(seconds=poll_delay), + update_interval=timedelta(seconds=poll_delay) if poll_delay else None, ) self.omada_client = omada_client @@ -91,6 +92,22 @@ class OmadaGatewayCoordinator(OmadaCoordinator[OmadaGateway]): return {self.mac: gateway} +class OmadaDevicesCoordinator(OmadaCoordinator[OmadaListDevice]): + """Coordinator for generic device lists from the controller.""" + + def __init__( + self, + hass: HomeAssistant, + omada_client: OmadaSiteClient, + ) -> None: + """Initialize my coordinator.""" + super().__init__(hass, omada_client, "DeviceList", POLL_CLIENTS) + + async def poll_update(self) -> dict[str, OmadaListDevice]: + """Poll the site's current registered Omada devices.""" + return {d.mac: d for d in await self.omada_client.get_devices()} + + class OmadaClientsCoordinator(OmadaCoordinator[OmadaWirelessClient]): """Coordinator for getting details about the site's connected clients.""" diff --git a/homeassistant/components/tplink_omada/device_tracker.py b/homeassistant/components/tplink_omada/device_tracker.py index be734592d11..12c519b883f 100644 --- a/homeassistant/components/tplink_omada/device_tracker.py +++ b/homeassistant/components/tplink_omada/device_tracker.py @@ -26,7 +26,6 @@ async def async_setup_entry( controller: OmadaSiteController = hass.data[DOMAIN][config_entry.entry_id] - clients_coordinator = controller.get_clients_coordinator() site_id = config_entry.data[CONF_SITE] # Add all known WiFi devices as potentially tracked devices. They will only be @@ -34,7 +33,7 @@ async def async_setup_entry( async_add_entities( [ OmadaClientScannerEntity( - site_id, client.mac, client.name, clients_coordinator + site_id, client.mac, client.name, controller.clients_coordinator ) async for client in controller.omada_client.get_known_clients() if isinstance(client, OmadaWirelessClient) diff --git a/homeassistant/components/tplink_omada/entity.py b/homeassistant/components/tplink_omada/entity.py index 13ec7b3c6cb..213764aaa12 100644 --- a/homeassistant/components/tplink_omada/entity.py +++ b/homeassistant/components/tplink_omada/entity.py @@ -5,7 +5,6 @@ from typing import Any from tplink_omada_client.devices import OmadaDevice from homeassistant.helpers import device_registry as dr -from homeassistant.helpers.device_registry import DeviceInfo from homeassistant.helpers.update_coordinator import CoordinatorEntity from .const import DOMAIN @@ -19,7 +18,7 @@ class OmadaDeviceEntity[_T: OmadaCoordinator[Any]](CoordinatorEntity[_T]): """Initialize the device.""" super().__init__(coordinator) self.device = device - self._attr_device_info = DeviceInfo( + self._attr_device_info = dr.DeviceInfo( connections={(dr.CONNECTION_NETWORK_MAC, device.mac)}, identifiers={(DOMAIN, device.mac)}, manufacturer="TP-Link", diff --git a/homeassistant/components/tplink_omada/switch.py b/homeassistant/components/tplink_omada/switch.py index 9f9eeceb866..12d4d4039ee 100644 --- a/homeassistant/components/tplink_omada/switch.py +++ b/homeassistant/components/tplink_omada/switch.py @@ -74,7 +74,7 @@ async def async_setup_entry( if desc.exists_func(switch, port) ) - gateway_coordinator = await controller.get_gateway_coordinator() + gateway_coordinator = controller.gateway_coordinator if gateway_coordinator: for gateway in gateway_coordinator.data.values(): entities.extend( diff --git a/homeassistant/components/tplink_omada/update.py b/homeassistant/components/tplink_omada/update.py index a7552263ff1..82c694a5ae4 100644 --- a/homeassistant/components/tplink_omada/update.py +++ b/homeassistant/components/tplink_omada/update.py @@ -21,10 +21,9 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from .const import DOMAIN from .controller import OmadaSiteController -from .coordinator import OmadaCoordinator +from .coordinator import POLL_DEVICES, OmadaCoordinator, OmadaDevicesCoordinator from .entity import OmadaDeviceEntity -POLL_DELAY_IDLE = 6 * 60 * 60 POLL_DELAY_UPGRADE = 60 @@ -35,15 +34,28 @@ class FirmwareUpdateStatus(NamedTuple): firmware: OmadaFirmwareUpdate | None -class OmadaFirmwareUpdateCoodinator(OmadaCoordinator[FirmwareUpdateStatus]): # pylint: disable=hass-enforce-class-module - """Coordinator for getting details about ports on a switch.""" +class OmadaFirmwareUpdateCoordinator(OmadaCoordinator[FirmwareUpdateStatus]): # pylint: disable=hass-enforce-class-module + """Coordinator for getting details about available firmware updates for Omada devices.""" - def __init__(self, hass: HomeAssistant, omada_client: OmadaSiteClient) -> None: + def __init__( + self, + hass: HomeAssistant, + config_entry: ConfigEntry, + omada_client: OmadaSiteClient, + devices_coordinator: OmadaDevicesCoordinator, + ) -> None: """Initialize my coordinator.""" - super().__init__(hass, omada_client, "Firmware Updates", POLL_DELAY_IDLE) + super().__init__(hass, omada_client, "Firmware Updates", poll_delay=None) + + self._devices_coordinator = devices_coordinator + self._config_entry = config_entry + + config_entry.async_on_unload( + devices_coordinator.async_add_listener(self._handle_devices_update) + ) async def _get_firmware_updates(self) -> list[FirmwareUpdateStatus]: - devices = await self.omada_client.get_devices() + devices = self._devices_coordinator.data.values() updates = [ FirmwareUpdateStatus( @@ -55,12 +67,12 @@ class OmadaFirmwareUpdateCoodinator(OmadaCoordinator[FirmwareUpdateStatus]): # for d in devices ] - # During a firmware upgrade, poll more frequently - self.update_interval = timedelta( + # During a firmware upgrade, poll device list more frequently + self._devices_coordinator.update_interval = timedelta( seconds=( POLL_DELAY_UPGRADE if any(u.device.fw_download for u in updates) - else POLL_DELAY_IDLE + else POLL_DEVICES ) ) return updates @@ -69,6 +81,14 @@ class OmadaFirmwareUpdateCoodinator(OmadaCoordinator[FirmwareUpdateStatus]): # """Poll the state of Omada Devices firmware update availability.""" return {d.device.mac: d for d in await self._get_firmware_updates()} + @callback + def _handle_devices_update(self) -> None: + """Handle updated data from the devices coordinator.""" + # Trigger a refresh of our data, based on the updated device list + self._config_entry.async_create_background_task( + self.hass, self.async_request_refresh(), "Omada Firmware Update Refresh" + ) + async def async_setup_entry( hass: HomeAssistant, @@ -77,18 +97,21 @@ async def async_setup_entry( ) -> None: """Set up switches.""" controller: OmadaSiteController = hass.data[DOMAIN][config_entry.entry_id] - omada_client = controller.omada_client - devices = await omada_client.get_devices() + devices = controller.devices_coordinator.data - coordinator = OmadaFirmwareUpdateCoodinator(hass, omada_client) + coordinator = OmadaFirmwareUpdateCoordinator( + hass, config_entry, controller.omada_client, controller.devices_coordinator + ) - async_add_entities(OmadaDeviceUpdate(coordinator, device) for device in devices) + async_add_entities( + OmadaDeviceUpdate(coordinator, device) for device in devices.values() + ) await coordinator.async_request_refresh() class OmadaDeviceUpdate( - OmadaDeviceEntity[OmadaFirmwareUpdateCoodinator], + OmadaDeviceEntity[OmadaFirmwareUpdateCoordinator], UpdateEntity, ): """Firmware update status for Omada SDN devices.""" @@ -103,7 +126,7 @@ class OmadaDeviceUpdate( def __init__( self, - coordinator: OmadaFirmwareUpdateCoodinator, + coordinator: OmadaFirmwareUpdateCoordinator, device: OmadaListDevice, ) -> None: """Initialize the update entity.""" diff --git a/tests/components/tplink_omada/test_init.py b/tests/components/tplink_omada/test_init.py new file mode 100644 index 00000000000..762168df9d6 --- /dev/null +++ b/tests/components/tplink_omada/test_init.py @@ -0,0 +1,47 @@ +"""Tests for TP-Link Omada integration init.""" + +from unittest.mock import MagicMock + +from homeassistant.components.tplink_omada.const import DOMAIN +from homeassistant.core import HomeAssistant +from homeassistant.helpers import device_registry as dr + +from tests.common import MockConfigEntry + +MOCK_ENTRY_DATA = { + "host": "https://fake.omada.host", + "verify_ssl": True, + "site": "SiteId", + "username": "test-username", + "password": "test-password", +} + + +async def test_missing_devices_removed_at_startup( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + mock_omada_client: MagicMock, +) -> None: + """Test missing devices are removed at startup.""" + mock_config_entry = MockConfigEntry( + title="Test Omada Controller", + domain=DOMAIN, + data=dict(MOCK_ENTRY_DATA), + unique_id="12345", + ) + mock_config_entry.add_to_hass(hass) + + device_entry = device_registry.async_get_or_create( + config_entry_id=mock_config_entry.entry_id, + identifiers={(DOMAIN, "AA:BB:CC:DD:EE:FF")}, + manufacturer="TPLink", + name="Old Device", + model="Some old model", + ) + + assert device_registry.async_get(device_entry.id) == device_entry + + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + assert device_registry.async_get(device_entry.id) is None