diff --git a/homeassistant/components/unifi/device_tracker.py b/homeassistant/components/unifi/device_tracker.py index ea8db77e124..2dddfeba304 100644 --- a/homeassistant/components/unifi/device_tracker.py +++ b/homeassistant/components/unifi/device_tracker.py @@ -1,22 +1,32 @@ """Track both clients and devices using UniFi Network.""" +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass from datetime import timedelta import logging +from typing import Generic, TypeVar +import aiounifi +from aiounifi.interfaces.api_handlers import ItemEvent +from aiounifi.interfaces.devices import Devices from aiounifi.models.api import SOURCE_DATA, SOURCE_EVENT -from aiounifi.models.event import EventKey +from aiounifi.models.device import Device +from aiounifi.models.event import Event, EventKey from homeassistant.components.device_tracker import DOMAIN, ScannerEntity, SourceType from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import entity_registry as er from homeassistant.helpers.dispatcher import async_dispatcher_connect +from homeassistant.helpers.entity import EntityDescription from homeassistant.helpers.entity_platform import AddEntitiesCallback import homeassistant.util.dt as dt_util from .const import DOMAIN as UNIFI_DOMAIN from .controller import UniFiController from .unifi_client import UniFiClientBase -from .unifi_entity_base import UniFiBase LOGGER = logging.getLogger(__name__) @@ -59,6 +69,61 @@ WIRELESS_CONNECTION = ( ) +_DataT = TypeVar("_DataT", bound=Device) +_HandlerT = TypeVar("_HandlerT", bound=Devices) + + +@callback +def async_device_available_fn(controller: UniFiController, obj_id: str) -> bool: + """Check if device object is disabled.""" + device = controller.api.devices[obj_id] + return controller.available and not device.disabled + + +@dataclass +class UnifiEntityLoader(Generic[_HandlerT, _DataT]): + """Validate and load entities from different UniFi handlers.""" + + allowed_fn: Callable[[UniFiController, str], bool] + api_handler_fn: Callable[[aiounifi.Controller], _HandlerT] + available_fn: Callable[[UniFiController, str], bool] + event_is_on: tuple[EventKey, ...] | None + event_to_subscribe: tuple[EventKey, ...] | None + is_connected_fn: Callable[[aiounifi.Controller, _DataT], bool] + name_fn: Callable[[_DataT], str | None] + object_fn: Callable[[aiounifi.Controller, str], _DataT] + supported_fn: Callable[[aiounifi.Controller, str], bool | None] + unique_id_fn: Callable[[str], str] + ip_address_fn: Callable[[aiounifi.Controller, str], str] + hostname_fn: Callable[[aiounifi.Controller, str], str | None] + + +@dataclass +class UnifiEntityDescription(EntityDescription, UnifiEntityLoader[_HandlerT, _DataT]): + """Class describing UniFi switch entity.""" + + +ENTITY_DESCRIPTIONS: tuple[UnifiEntityDescription, ...] = ( + UnifiEntityDescription[Devices, Device]( + key="Device scanner", + has_entity_name=True, + icon="mdi:ethernet", + allowed_fn=lambda controller, obj_id: controller.option_track_devices, + api_handler_fn=lambda api: api.devices, + available_fn=async_device_available_fn, + event_is_on=None, + event_to_subscribe=None, + is_connected_fn=lambda api, device: device.state == 1, + name_fn=lambda device: device.name or device.model, + object_fn=lambda api, obj_id: api.devices[obj_id], + supported_fn=lambda api, obj_id: True, + unique_id_fn=lambda obj_id: obj_id, + ip_address_fn=lambda api, obj_id: api.devices[obj_id].ip, + hostname_fn=lambda api, obj_id: None, + ), +) + + async def async_setup_entry( hass: HomeAssistant, config_entry: ConfigEntry, @@ -76,9 +141,6 @@ async def async_setup_entry( if controller.option_track_clients: add_client_entities(controller, async_add_entities, clients) - if controller.option_track_devices: - add_device_entities(controller, async_add_entities, devices) - for signal in (controller.signal_update, controller.signal_options_update): config_entry.async_on_unload( async_dispatcher_connect(hass, signal, items_added) @@ -86,6 +148,35 @@ async def async_setup_entry( items_added() + @callback + def async_load_entities(description: UnifiEntityDescription) -> None: + """Load and subscribe to UniFi devices.""" + entities: list[ScannerEntity] = [] + api_handler = description.api_handler_fn(controller.api) + + @callback + def async_create_entity(event: ItemEvent, obj_id: str) -> None: + """Create UniFi entity.""" + if not description.allowed_fn( + controller, obj_id + ) or not description.supported_fn(controller.api, obj_id): + return + + entity = UnifiScannerEntity(obj_id, controller, description) + if event == ItemEvent.ADDED: + async_add_entities([entity]) + return + entities.append(entity) + + for obj_id in api_handler: + async_create_entity(ItemEvent.CHANGED, obj_id) + async_add_entities(entities) + + api_handler.subscribe(async_create_entity, ItemEvent.ADDED) + + for description in ENTITY_DESCRIPTIONS: + async_load_entities(description) + @callback def add_client_entities(controller, async_add_entities, clients): @@ -113,21 +204,6 @@ def add_client_entities(controller, async_add_entities, clients): async_add_entities(trackers) -@callback -def add_device_entities(controller, async_add_entities, devices): - """Add new device tracker entities from the controller.""" - trackers = [] - - for mac in devices: - if mac in controller.entities[DOMAIN][UniFiDeviceTracker.TYPE]: - continue - - device = controller.api.devices[mac] - trackers.append(UniFiDeviceTracker(device, controller)) - - async_add_entities(trackers) - - class UniFiClientTracker(UniFiClientBase, ScannerEntity): """Representation of a network client.""" @@ -313,46 +389,119 @@ class UniFiClientTracker(UniFiClientBase, ScannerEntity): await self.remove_item({self.client.mac}) -class UniFiDeviceTracker(UniFiBase, ScannerEntity): - """Representation of a network infrastructure device.""" +class UnifiScannerEntity(ScannerEntity, Generic[_HandlerT, _DataT]): + """Representation of a UniFi scanner.""" - DOMAIN = DOMAIN - TYPE = DEVICE_TRACKER + entity_description: UnifiEntityDescription[_HandlerT, _DataT] + _attr_should_poll = False - def __init__(self, device, controller): - """Set up tracked device.""" - super().__init__(device, controller) + def __init__( + self, + obj_id: str, + controller: UniFiController, + description: UnifiEntityDescription[_HandlerT, _DataT], + ) -> None: + """Set up UniFi scanner entity.""" + self._obj_id = obj_id + self.controller = controller + self.entity_description = description - self.device = self._item - self._is_connected = device.state == 1 self._controller_connection_state_changed = False + self._removed = False self.schedule_update = False + self._attr_available = description.available_fn(controller, obj_id) + self._attr_unique_id: str = description.unique_id_fn(obj_id) + + obj = description.object_fn(controller.api, obj_id) + self._is_connected = description.is_connected_fn(controller.api, obj) + self._attr_name = description.name_fn(obj) + + @property + def is_connected(self): + """Return true if the device is connected to the network.""" + return self._is_connected + + @property + def hostname(self) -> str | None: + """Return hostname of the device.""" + return self.entity_description.hostname_fn(self.controller.api, self._obj_id) + + @property + def ip_address(self) -> str: + """Return the primary ip address of the device.""" + return self.entity_description.ip_address_fn(self.controller.api, self._obj_id) + + @property + def mac_address(self) -> str: + """Return the mac address of the device.""" + return self._obj_id + + @property + def source_type(self) -> SourceType: + """Return the source type, eg gps or router, of the device.""" + return SourceType.ROUTER + + @property + def unique_id(self) -> str: + """Return a unique ID.""" + return self._attr_unique_id + async def async_added_to_hass(self) -> None: - """Watch object when added.""" + """Register callbacks.""" + description = self.entity_description + handler = description.api_handler_fn(self.controller.api) + self.async_on_remove( + handler.subscribe( + self.async_signalling_callback, + ) + ) self.async_on_remove( async_dispatcher_connect( self.hass, - f"{self.controller.signal_heartbeat_missed}_{self.unique_id}", + self.controller.signal_reachable, + self.async_signal_reachable_callback, + ) + ) + self.async_on_remove( + async_dispatcher_connect( + self.hass, + self.controller.signal_options_update, + self.options_updated, + ) + ) + self.async_on_remove( + async_dispatcher_connect( + self.hass, + self.controller.signal_remove, + self.remove_item, + ) + ) + self.async_on_remove( + async_dispatcher_connect( + self.hass, + f"{self.controller.signal_heartbeat_missed}_{self._obj_id}", self._make_disconnected, ) ) - await super().async_added_to_hass() - - async def async_will_remove_from_hass(self) -> None: - """Disconnect object when removed.""" - self.controller.async_heartbeat(self.unique_id) - await super().async_will_remove_from_hass() @callback - def async_signal_reachable_callback(self) -> None: - """Call when controller connection state change.""" - self._controller_connection_state_changed = True - super().async_signal_reachable_callback() + def _make_disconnected(self, *_): + """No heart beat by device.""" + self._is_connected = False + self.async_write_ha_state() @callback - def async_update_callback(self) -> None: - """Update the devices' state.""" + def async_signalling_callback(self, event: ItemEvent, obj_id: str) -> None: + """Update the switch state.""" + if event == ItemEvent.DELETED and obj_id == self._obj_id: + self.hass.async_create_task(self.remove_item({self._obj_id})) + return + + description = self.entity_description + if not description.supported_fn(self.controller.api, self._obj_id): + self.hass.async_create_task(self.remove_item({self._obj_id})) + return if self._controller_connection_state_changed: self._controller_connection_state_changed = False @@ -363,81 +512,54 @@ class UniFiDeviceTracker(UniFiBase, ScannerEntity): else: self.controller.async_heartbeat(self.unique_id) - - elif self.device.last_updated == SOURCE_DATA: + else: self._is_connected = True self.schedule_update = True if self.schedule_update: + device = self.entity_description.object_fn( + self.controller.api, self._obj_id + ) self.schedule_update = False self.controller.async_heartbeat( self.unique_id, - dt_util.utcnow() + timedelta(seconds=self.device.next_interval + 60), + dt_util.utcnow() + timedelta(seconds=device.next_interval + 60), ) - super().async_update_callback() - - @callback - def _make_disconnected(self, *_): - """No heart beat by device.""" - self._is_connected = False + self._attr_available = description.available_fn(self.controller, self._obj_id) self.async_write_ha_state() - @property - def is_connected(self): - """Return true if the device is connected to the network.""" - return self._is_connected + @callback + def async_signal_reachable_callback(self) -> None: + """Call when controller connection state change.""" + self.async_signalling_callback(ItemEvent.ADDED, self._obj_id) - @property - def source_type(self) -> SourceType: - """Return the source type of the device.""" - return SourceType.ROUTER + @callback + def async_event_callback(self, event: Event) -> None: + """Event subscription callback.""" + if event.mac != self._obj_id: + return - @property - def name(self) -> str: - """Return the name of the device.""" - return self.device.name or self.device.model + description = self.entity_description + assert isinstance(description.event_to_subscribe, tuple) + assert isinstance(description.event_is_on, tuple) - @property - def unique_id(self) -> str: - """Return a unique identifier for this device.""" - return self.device.mac - - @property - def available(self) -> bool: - """Return if controller is available.""" - return not self.device.disabled and self.controller.available - - @property - def extra_state_attributes(self): - """Return the device state attributes.""" - if self.device.state == 0: - return {} - - attributes = {} - - if self.device.has_fan: - attributes["fan_level"] = self.device.fan_level - - if self.device.overheating: - attributes["overheating"] = self.device.overheating - - if self.device.upgradable: - attributes["upgradable"] = self.device.upgradable - - return attributes - - @property - def ip_address(self) -> str: - """Return the primary ip address of the device.""" - return self.device.ip - - @property - def mac_address(self) -> str: - """Return the mac address of the device.""" - return self.device.mac + if event.key in description.event_to_subscribe: + self._is_connected = event.key in description.event_is_on + self._attr_available = description.available_fn(self.controller, self._obj_id) + self.async_write_ha_state() async def options_updated(self) -> None: """Config entry options are updated, remove entity if option is disabled.""" - if not self.controller.option_track_devices: - await self.remove_item({self.device.mac}) + if not self.entity_description.allowed_fn(self.controller, self._obj_id): + await self.remove_item({self._obj_id}) + + async def remove_item(self, keys: set) -> None: + """Remove entity if object ID is part of set.""" + if self._obj_id not in keys or self._removed: + return + self._removed = True + if self.registry_entry: + er.async_get(self.hass).async_remove(self.entity_id) + else: + await self.async_remove(force_remove=True) diff --git a/tests/components/unifi/test_device_tracker.py b/tests/components/unifi/test_device_tracker.py index b8f1aa771a4..596ecd46cd3 100644 --- a/tests/components/unifi/test_device_tracker.py +++ b/tests/components/unifi/test_device_tracker.py @@ -607,15 +607,6 @@ async def test_option_track_devices(hass, aioclient_mock, mock_device_registry): assert hass.states.get("device_tracker.client") assert not hass.states.get("device_tracker.device") - hass.config_entries.async_update_entry( - config_entry, - options={CONF_TRACK_DEVICES: True}, - ) - await hass.async_block_till_done() - - assert hass.states.get("device_tracker.client") - assert hass.states.get("device_tracker.device") - async def test_option_ssid_filter( hass, aioclient_mock, mock_unifi_websocket, mock_device_registry @@ -1007,7 +998,7 @@ async def test_dont_track_devices(hass, aioclient_mock, mock_device_registry): "version": "4.0.42.10433", } - config_entry = await setup_unifi_integration( + await setup_unifi_integration( hass, aioclient_mock, options={CONF_TRACK_DEVICES: False}, @@ -1019,16 +1010,6 @@ async def test_dont_track_devices(hass, aioclient_mock, mock_device_registry): assert hass.states.get("device_tracker.client") assert not hass.states.get("device_tracker.device") - hass.config_entries.async_update_entry( - config_entry, - options={CONF_TRACK_DEVICES: True}, - ) - await hass.async_block_till_done() - - assert len(hass.states.async_entity_ids(TRACKER_DOMAIN)) == 2 - assert hass.states.get("device_tracker.client") - assert hass.states.get("device_tracker.device") - async def test_dont_track_wired_clients(hass, aioclient_mock, mock_device_registry): """Test don't track wired clients config works."""