Automatically remove unregistered TP-Link Omada devices at start up (#124153)

* Adding coordinator for omada device list

* Remove dead omada devices at startup

* Tidy up tests

* Address PR feedback

* Returned to use of read-only properties for coordinators. Tidied up parameters some more

* Update homeassistant/components/tplink_omada/controller.py

* Update homeassistant/components/tplink_omada/controller.py

* Update homeassistant/components/tplink_omada/controller.py

---------

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
This commit is contained in:
MarkGodwin 2024-09-22 16:05:29 +01:00 committed by GitHub
parent 8158ca7c69
commit 2a36ec3e21
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 164 additions and 51 deletions

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from tplink_omada_client import OmadaSite from tplink_omada_client import OmadaSite
from tplink_omada_client.devices import OmadaListDevice
from tplink_omada_client.exceptions import ( from tplink_omada_client.exceptions import (
ConnectionFailed, ConnectionFailed,
LoginFailed, LoginFailed,
@ -14,6 +15,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import device_registry as dr
from .config_flow import CONF_SITE, create_omada_client from .config_flow import CONF_SITE, create_omada_client
from .const import DOMAIN from .const import DOMAIN
@ -52,13 +54,12 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
site_client = await client.get_site_client(OmadaSite("", entry.data[CONF_SITE])) site_client = await client.get_site_client(OmadaSite("", entry.data[CONF_SITE]))
controller = OmadaSiteController(hass, site_client) controller = OmadaSiteController(hass, site_client)
gateway_coordinator = await controller.get_gateway_coordinator() await controller.initialize_first_refresh()
if gateway_coordinator:
await gateway_coordinator.async_config_entry_first_refresh()
await controller.get_clients_coordinator().async_config_entry_first_refresh()
hass.data[DOMAIN][entry.entry_id] = controller hass.data[DOMAIN][entry.entry_id] = controller
_remove_old_devices(hass, entry, controller.devices_coordinator.data)
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
return True return True
@ -70,3 +71,20 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.data[DOMAIN].pop(entry.entry_id) hass.data[DOMAIN].pop(entry.entry_id)
return unload_ok return unload_ok
def _remove_old_devices(
hass: HomeAssistant, entry: ConfigEntry, omada_devices: dict[str, OmadaListDevice]
) -> None:
device_registry = dr.async_get(hass)
for registered_device in device_registry.devices.get_devices_for_config_entry_id(
entry.entry_id
):
mac = next(
(i[1] for i in registered_device.identifiers if i[0] == DOMAIN), None
)
if mac and mac not in omada_devices:
device_registry.async_update_device(
registered_device.id, remove_config_entry_id=entry.entry_id
)

View File

@ -34,7 +34,7 @@ async def async_setup_entry(
"""Set up binary sensors.""" """Set up binary sensors."""
controller: OmadaSiteController = hass.data[DOMAIN][config_entry.entry_id] controller: OmadaSiteController = hass.data[DOMAIN][config_entry.entry_id]
gateway_coordinator = await controller.get_gateway_coordinator() gateway_coordinator = controller.gateway_coordinator
if not gateway_coordinator: if not gateway_coordinator:
return return

View File

@ -7,6 +7,7 @@ from homeassistant.core import HomeAssistant
from .coordinator import ( from .coordinator import (
OmadaClientsCoordinator, OmadaClientsCoordinator,
OmadaDevicesCoordinator,
OmadaGatewayCoordinator, OmadaGatewayCoordinator,
OmadaSwitchPortCoordinator, OmadaSwitchPortCoordinator,
) )
@ -16,15 +17,33 @@ class OmadaSiteController:
"""Controller for the Omada SDN site.""" """Controller for the Omada SDN site."""
_gateway_coordinator: OmadaGatewayCoordinator | None = None _gateway_coordinator: OmadaGatewayCoordinator | None = None
_initialized_gateway_coordinator = False
_clients_coordinator: OmadaClientsCoordinator | None = None
def __init__(self, hass: HomeAssistant, omada_client: OmadaSiteClient) -> None: def __init__(
self,
hass: HomeAssistant,
omada_client: OmadaSiteClient,
) -> None:
"""Create the controller.""" """Create the controller."""
self._hass = hass self._hass = hass
self._omada_client = omada_client self._omada_client = omada_client
self._switch_port_coordinators: dict[str, OmadaSwitchPortCoordinator] = {} self._switch_port_coordinators: dict[str, OmadaSwitchPortCoordinator] = {}
self._devices_coordinator = OmadaDevicesCoordinator(hass, omada_client)
self._clients_coordinator = OmadaClientsCoordinator(hass, omada_client)
async def initialize_first_refresh(self) -> None:
"""Initialize the all coordinators, and perform first refresh."""
await self._devices_coordinator.async_config_entry_first_refresh()
devices = self._devices_coordinator.data.values()
gateway = next((d for d in devices if d.type == "gateway"), None)
if gateway:
self._gateway_coordinator = OmadaGatewayCoordinator(
self._hass, self._omada_client, gateway.mac
)
await self._gateway_coordinator.async_config_entry_first_refresh()
await self.clients_coordinator.async_config_entry_first_refresh()
@property @property
def omada_client(self) -> OmadaSiteClient: def omada_client(self) -> OmadaSiteClient:
@ -42,26 +61,17 @@ class OmadaSiteController:
return self._switch_port_coordinators[switch.mac] return self._switch_port_coordinators[switch.mac]
async def get_gateway_coordinator(self) -> OmadaGatewayCoordinator | None: @property
"""Get coordinator for site's gateway, or None if there is no gateway.""" def gateway_coordinator(self) -> OmadaGatewayCoordinator | None:
if not self._initialized_gateway_coordinator: """Gets the coordinator for site's gateway, or None if there is no gateway."""
self._initialized_gateway_coordinator = True
devices = await self._omada_client.get_devices()
gateway = next((d for d in devices if d.type == "gateway"), None)
if not gateway:
return None
self._gateway_coordinator = OmadaGatewayCoordinator(
self._hass, self._omada_client, gateway.mac
)
return self._gateway_coordinator return self._gateway_coordinator
def get_clients_coordinator(self) -> OmadaClientsCoordinator: @property
"""Get coordinator for site's clients.""" def devices_coordinator(self) -> OmadaDevicesCoordinator:
if not self._clients_coordinator: """Gets the coordinator for site's devices."""
self._clients_coordinator = OmadaClientsCoordinator( return self._devices_coordinator
self._hass, self._omada_client
)
@property
def clients_coordinator(self) -> OmadaClientsCoordinator:
"""Gets the coordinator for site's clients."""
return self._clients_coordinator return self._clients_coordinator

View File

@ -6,7 +6,7 @@ import logging
from tplink_omada_client import OmadaSiteClient, OmadaSwitchPortDetails from tplink_omada_client import OmadaSiteClient, OmadaSwitchPortDetails
from tplink_omada_client.clients import OmadaWirelessClient from tplink_omada_client.clients import OmadaWirelessClient
from tplink_omada_client.devices import OmadaGateway, OmadaSwitch from tplink_omada_client.devices import OmadaGateway, OmadaListDevice, OmadaSwitch
from tplink_omada_client.exceptions import OmadaClientException from tplink_omada_client.exceptions import OmadaClientException
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -17,6 +17,7 @@ _LOGGER = logging.getLogger(__name__)
POLL_SWITCH_PORT = 300 POLL_SWITCH_PORT = 300
POLL_GATEWAY = 300 POLL_GATEWAY = 300
POLL_CLIENTS = 300 POLL_CLIENTS = 300
POLL_DEVICES = 900
class OmadaCoordinator[_T](DataUpdateCoordinator[dict[str, _T]]): class OmadaCoordinator[_T](DataUpdateCoordinator[dict[str, _T]]):
@ -27,14 +28,14 @@ class OmadaCoordinator[_T](DataUpdateCoordinator[dict[str, _T]]):
hass: HomeAssistant, hass: HomeAssistant,
omada_client: OmadaSiteClient, omada_client: OmadaSiteClient,
name: str, name: str,
poll_delay: int = 300, poll_delay: int | None = 300,
) -> None: ) -> None:
"""Initialize my coordinator.""" """Initialize my coordinator."""
super().__init__( super().__init__(
hass, hass,
_LOGGER, _LOGGER,
name=f"Omada API Data - {name}", name=f"Omada API Data - {name}",
update_interval=timedelta(seconds=poll_delay), update_interval=timedelta(seconds=poll_delay) if poll_delay else None,
) )
self.omada_client = omada_client self.omada_client = omada_client
@ -91,6 +92,22 @@ class OmadaGatewayCoordinator(OmadaCoordinator[OmadaGateway]):
return {self.mac: gateway} return {self.mac: gateway}
class OmadaDevicesCoordinator(OmadaCoordinator[OmadaListDevice]):
"""Coordinator for generic device lists from the controller."""
def __init__(
self,
hass: HomeAssistant,
omada_client: OmadaSiteClient,
) -> None:
"""Initialize my coordinator."""
super().__init__(hass, omada_client, "DeviceList", POLL_CLIENTS)
async def poll_update(self) -> dict[str, OmadaListDevice]:
"""Poll the site's current registered Omada devices."""
return {d.mac: d for d in await self.omada_client.get_devices()}
class OmadaClientsCoordinator(OmadaCoordinator[OmadaWirelessClient]): class OmadaClientsCoordinator(OmadaCoordinator[OmadaWirelessClient]):
"""Coordinator for getting details about the site's connected clients.""" """Coordinator for getting details about the site's connected clients."""

View File

@ -26,7 +26,6 @@ async def async_setup_entry(
controller: OmadaSiteController = hass.data[DOMAIN][config_entry.entry_id] controller: OmadaSiteController = hass.data[DOMAIN][config_entry.entry_id]
clients_coordinator = controller.get_clients_coordinator()
site_id = config_entry.data[CONF_SITE] site_id = config_entry.data[CONF_SITE]
# Add all known WiFi devices as potentially tracked devices. They will only be # Add all known WiFi devices as potentially tracked devices. They will only be
@ -34,7 +33,7 @@ async def async_setup_entry(
async_add_entities( async_add_entities(
[ [
OmadaClientScannerEntity( OmadaClientScannerEntity(
site_id, client.mac, client.name, clients_coordinator site_id, client.mac, client.name, controller.clients_coordinator
) )
async for client in controller.omada_client.get_known_clients() async for client in controller.omada_client.get_known_clients()
if isinstance(client, OmadaWirelessClient) if isinstance(client, OmadaWirelessClient)

View File

@ -5,7 +5,6 @@ from typing import Any
from tplink_omada_client.devices import OmadaDevice from tplink_omada_client.devices import OmadaDevice
from homeassistant.helpers import device_registry as dr from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from .const import DOMAIN from .const import DOMAIN
@ -19,7 +18,7 @@ class OmadaDeviceEntity[_T: OmadaCoordinator[Any]](CoordinatorEntity[_T]):
"""Initialize the device.""" """Initialize the device."""
super().__init__(coordinator) super().__init__(coordinator)
self.device = device self.device = device
self._attr_device_info = DeviceInfo( self._attr_device_info = dr.DeviceInfo(
connections={(dr.CONNECTION_NETWORK_MAC, device.mac)}, connections={(dr.CONNECTION_NETWORK_MAC, device.mac)},
identifiers={(DOMAIN, device.mac)}, identifiers={(DOMAIN, device.mac)},
manufacturer="TP-Link", manufacturer="TP-Link",

View File

@ -74,7 +74,7 @@ async def async_setup_entry(
if desc.exists_func(switch, port) if desc.exists_func(switch, port)
) )
gateway_coordinator = await controller.get_gateway_coordinator() gateway_coordinator = controller.gateway_coordinator
if gateway_coordinator: if gateway_coordinator:
for gateway in gateway_coordinator.data.values(): for gateway in gateway_coordinator.data.values():
entities.extend( entities.extend(

View File

@ -21,10 +21,9 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN from .const import DOMAIN
from .controller import OmadaSiteController from .controller import OmadaSiteController
from .coordinator import OmadaCoordinator from .coordinator import POLL_DEVICES, OmadaCoordinator, OmadaDevicesCoordinator
from .entity import OmadaDeviceEntity from .entity import OmadaDeviceEntity
POLL_DELAY_IDLE = 6 * 60 * 60
POLL_DELAY_UPGRADE = 60 POLL_DELAY_UPGRADE = 60
@ -35,15 +34,28 @@ class FirmwareUpdateStatus(NamedTuple):
firmware: OmadaFirmwareUpdate | None firmware: OmadaFirmwareUpdate | None
class OmadaFirmwareUpdateCoodinator(OmadaCoordinator[FirmwareUpdateStatus]): # pylint: disable=hass-enforce-class-module class OmadaFirmwareUpdateCoordinator(OmadaCoordinator[FirmwareUpdateStatus]): # pylint: disable=hass-enforce-class-module
"""Coordinator for getting details about ports on a switch.""" """Coordinator for getting details about available firmware updates for Omada devices."""
def __init__(self, hass: HomeAssistant, omada_client: OmadaSiteClient) -> None: def __init__(
self,
hass: HomeAssistant,
config_entry: ConfigEntry,
omada_client: OmadaSiteClient,
devices_coordinator: OmadaDevicesCoordinator,
) -> None:
"""Initialize my coordinator.""" """Initialize my coordinator."""
super().__init__(hass, omada_client, "Firmware Updates", POLL_DELAY_IDLE) super().__init__(hass, omada_client, "Firmware Updates", poll_delay=None)
self._devices_coordinator = devices_coordinator
self._config_entry = config_entry
config_entry.async_on_unload(
devices_coordinator.async_add_listener(self._handle_devices_update)
)
async def _get_firmware_updates(self) -> list[FirmwareUpdateStatus]: async def _get_firmware_updates(self) -> list[FirmwareUpdateStatus]:
devices = await self.omada_client.get_devices() devices = self._devices_coordinator.data.values()
updates = [ updates = [
FirmwareUpdateStatus( FirmwareUpdateStatus(
@ -55,12 +67,12 @@ class OmadaFirmwareUpdateCoodinator(OmadaCoordinator[FirmwareUpdateStatus]): #
for d in devices for d in devices
] ]
# During a firmware upgrade, poll more frequently # During a firmware upgrade, poll device list more frequently
self.update_interval = timedelta( self._devices_coordinator.update_interval = timedelta(
seconds=( seconds=(
POLL_DELAY_UPGRADE POLL_DELAY_UPGRADE
if any(u.device.fw_download for u in updates) if any(u.device.fw_download for u in updates)
else POLL_DELAY_IDLE else POLL_DEVICES
) )
) )
return updates return updates
@ -69,6 +81,14 @@ class OmadaFirmwareUpdateCoodinator(OmadaCoordinator[FirmwareUpdateStatus]): #
"""Poll the state of Omada Devices firmware update availability.""" """Poll the state of Omada Devices firmware update availability."""
return {d.device.mac: d for d in await self._get_firmware_updates()} return {d.device.mac: d for d in await self._get_firmware_updates()}
@callback
def _handle_devices_update(self) -> None:
"""Handle updated data from the devices coordinator."""
# Trigger a refresh of our data, based on the updated device list
self._config_entry.async_create_background_task(
self.hass, self.async_request_refresh(), "Omada Firmware Update Refresh"
)
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistant, hass: HomeAssistant,
@ -77,18 +97,21 @@ async def async_setup_entry(
) -> None: ) -> None:
"""Set up switches.""" """Set up switches."""
controller: OmadaSiteController = hass.data[DOMAIN][config_entry.entry_id] controller: OmadaSiteController = hass.data[DOMAIN][config_entry.entry_id]
omada_client = controller.omada_client
devices = await omada_client.get_devices() devices = controller.devices_coordinator.data
coordinator = OmadaFirmwareUpdateCoodinator(hass, omada_client) coordinator = OmadaFirmwareUpdateCoordinator(
hass, config_entry, controller.omada_client, controller.devices_coordinator
)
async_add_entities(OmadaDeviceUpdate(coordinator, device) for device in devices) async_add_entities(
OmadaDeviceUpdate(coordinator, device) for device in devices.values()
)
await coordinator.async_request_refresh() await coordinator.async_request_refresh()
class OmadaDeviceUpdate( class OmadaDeviceUpdate(
OmadaDeviceEntity[OmadaFirmwareUpdateCoodinator], OmadaDeviceEntity[OmadaFirmwareUpdateCoordinator],
UpdateEntity, UpdateEntity,
): ):
"""Firmware update status for Omada SDN devices.""" """Firmware update status for Omada SDN devices."""
@ -103,7 +126,7 @@ class OmadaDeviceUpdate(
def __init__( def __init__(
self, self,
coordinator: OmadaFirmwareUpdateCoodinator, coordinator: OmadaFirmwareUpdateCoordinator,
device: OmadaListDevice, device: OmadaListDevice,
) -> None: ) -> None:
"""Initialize the update entity.""" """Initialize the update entity."""

View File

@ -0,0 +1,47 @@
"""Tests for TP-Link Omada integration init."""
from unittest.mock import MagicMock
from homeassistant.components.tplink_omada.const import DOMAIN
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr
from tests.common import MockConfigEntry
MOCK_ENTRY_DATA = {
"host": "https://fake.omada.host",
"verify_ssl": True,
"site": "SiteId",
"username": "test-username",
"password": "test-password",
}
async def test_missing_devices_removed_at_startup(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
mock_omada_client: MagicMock,
) -> None:
"""Test missing devices are removed at startup."""
mock_config_entry = MockConfigEntry(
title="Test Omada Controller",
domain=DOMAIN,
data=dict(MOCK_ENTRY_DATA),
unique_id="12345",
)
mock_config_entry.add_to_hass(hass)
device_entry = device_registry.async_get_or_create(
config_entry_id=mock_config_entry.entry_id,
identifiers={(DOMAIN, "AA:BB:CC:DD:EE:FF")},
manufacturer="TPLink",
name="Old Device",
model="Some old model",
)
assert device_registry.async_get(device_entry.id) == device_entry
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
assert device_registry.async_get(device_entry.id) is None