Rework UniFi device tracker to utilizing entity description (#81979)

* Rework UniFi device tracker to utilizing entity description

* Use bound

* Fix review comments from other PR
This commit is contained in:
Robert Svensson 2022-12-28 22:29:11 +01:00 committed by GitHub
parent cbcfeee322
commit de5c7b0414
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 227 additions and 124 deletions

View File

@ -1,22 +1,32 @@
"""Track both clients and devices using UniFi Network.""" """Track both clients and devices using UniFi Network."""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
import logging import logging
from typing import Generic, TypeVar
import aiounifi
from aiounifi.interfaces.api_handlers import ItemEvent
from aiounifi.interfaces.devices import Devices
from aiounifi.models.api import SOURCE_DATA, SOURCE_EVENT from aiounifi.models.api import SOURCE_DATA, SOURCE_EVENT
from aiounifi.models.event import EventKey from aiounifi.models.device import Device
from aiounifi.models.event import Event, EventKey
from homeassistant.components.device_tracker import DOMAIN, ScannerEntity, SourceType from homeassistant.components.device_tracker import DOMAIN, ScannerEntity, SourceType
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import EntityDescription
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .const import DOMAIN as UNIFI_DOMAIN from .const import DOMAIN as UNIFI_DOMAIN
from .controller import UniFiController from .controller import UniFiController
from .unifi_client import UniFiClientBase from .unifi_client import UniFiClientBase
from .unifi_entity_base import UniFiBase
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
@ -59,6 +69,61 @@ WIRELESS_CONNECTION = (
) )
_DataT = TypeVar("_DataT", bound=Device)
_HandlerT = TypeVar("_HandlerT", bound=Devices)
@callback
def async_device_available_fn(controller: UniFiController, obj_id: str) -> bool:
"""Check if device object is disabled."""
device = controller.api.devices[obj_id]
return controller.available and not device.disabled
@dataclass
class UnifiEntityLoader(Generic[_HandlerT, _DataT]):
"""Validate and load entities from different UniFi handlers."""
allowed_fn: Callable[[UniFiController, str], bool]
api_handler_fn: Callable[[aiounifi.Controller], _HandlerT]
available_fn: Callable[[UniFiController, str], bool]
event_is_on: tuple[EventKey, ...] | None
event_to_subscribe: tuple[EventKey, ...] | None
is_connected_fn: Callable[[aiounifi.Controller, _DataT], bool]
name_fn: Callable[[_DataT], str | None]
object_fn: Callable[[aiounifi.Controller, str], _DataT]
supported_fn: Callable[[aiounifi.Controller, str], bool | None]
unique_id_fn: Callable[[str], str]
ip_address_fn: Callable[[aiounifi.Controller, str], str]
hostname_fn: Callable[[aiounifi.Controller, str], str | None]
@dataclass
class UnifiEntityDescription(EntityDescription, UnifiEntityLoader[_HandlerT, _DataT]):
"""Class describing UniFi switch entity."""
ENTITY_DESCRIPTIONS: tuple[UnifiEntityDescription, ...] = (
UnifiEntityDescription[Devices, Device](
key="Device scanner",
has_entity_name=True,
icon="mdi:ethernet",
allowed_fn=lambda controller, obj_id: controller.option_track_devices,
api_handler_fn=lambda api: api.devices,
available_fn=async_device_available_fn,
event_is_on=None,
event_to_subscribe=None,
is_connected_fn=lambda api, device: device.state == 1,
name_fn=lambda device: device.name or device.model,
object_fn=lambda api, obj_id: api.devices[obj_id],
supported_fn=lambda api, obj_id: True,
unique_id_fn=lambda obj_id: obj_id,
ip_address_fn=lambda api, obj_id: api.devices[obj_id].ip,
hostname_fn=lambda api, obj_id: None,
),
)
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistant, hass: HomeAssistant,
config_entry: ConfigEntry, config_entry: ConfigEntry,
@ -76,9 +141,6 @@ async def async_setup_entry(
if controller.option_track_clients: if controller.option_track_clients:
add_client_entities(controller, async_add_entities, clients) add_client_entities(controller, async_add_entities, clients)
if controller.option_track_devices:
add_device_entities(controller, async_add_entities, devices)
for signal in (controller.signal_update, controller.signal_options_update): for signal in (controller.signal_update, controller.signal_options_update):
config_entry.async_on_unload( config_entry.async_on_unload(
async_dispatcher_connect(hass, signal, items_added) async_dispatcher_connect(hass, signal, items_added)
@ -86,6 +148,35 @@ async def async_setup_entry(
items_added() items_added()
@callback
def async_load_entities(description: UnifiEntityDescription) -> None:
"""Load and subscribe to UniFi devices."""
entities: list[ScannerEntity] = []
api_handler = description.api_handler_fn(controller.api)
@callback
def async_create_entity(event: ItemEvent, obj_id: str) -> None:
"""Create UniFi entity."""
if not description.allowed_fn(
controller, obj_id
) or not description.supported_fn(controller.api, obj_id):
return
entity = UnifiScannerEntity(obj_id, controller, description)
if event == ItemEvent.ADDED:
async_add_entities([entity])
return
entities.append(entity)
for obj_id in api_handler:
async_create_entity(ItemEvent.CHANGED, obj_id)
async_add_entities(entities)
api_handler.subscribe(async_create_entity, ItemEvent.ADDED)
for description in ENTITY_DESCRIPTIONS:
async_load_entities(description)
@callback @callback
def add_client_entities(controller, async_add_entities, clients): def add_client_entities(controller, async_add_entities, clients):
@ -113,21 +204,6 @@ def add_client_entities(controller, async_add_entities, clients):
async_add_entities(trackers) async_add_entities(trackers)
@callback
def add_device_entities(controller, async_add_entities, devices):
"""Add new device tracker entities from the controller."""
trackers = []
for mac in devices:
if mac in controller.entities[DOMAIN][UniFiDeviceTracker.TYPE]:
continue
device = controller.api.devices[mac]
trackers.append(UniFiDeviceTracker(device, controller))
async_add_entities(trackers)
class UniFiClientTracker(UniFiClientBase, ScannerEntity): class UniFiClientTracker(UniFiClientBase, ScannerEntity):
"""Representation of a network client.""" """Representation of a network client."""
@ -313,46 +389,119 @@ class UniFiClientTracker(UniFiClientBase, ScannerEntity):
await self.remove_item({self.client.mac}) await self.remove_item({self.client.mac})
class UniFiDeviceTracker(UniFiBase, ScannerEntity): class UnifiScannerEntity(ScannerEntity, Generic[_HandlerT, _DataT]):
"""Representation of a network infrastructure device.""" """Representation of a UniFi scanner."""
DOMAIN = DOMAIN entity_description: UnifiEntityDescription[_HandlerT, _DataT]
TYPE = DEVICE_TRACKER _attr_should_poll = False
def __init__(self, device, controller): def __init__(
"""Set up tracked device.""" self,
super().__init__(device, controller) obj_id: str,
controller: UniFiController,
description: UnifiEntityDescription[_HandlerT, _DataT],
) -> None:
"""Set up UniFi scanner entity."""
self._obj_id = obj_id
self.controller = controller
self.entity_description = description
self.device = self._item
self._is_connected = device.state == 1
self._controller_connection_state_changed = False self._controller_connection_state_changed = False
self._removed = False
self.schedule_update = False self.schedule_update = False
self._attr_available = description.available_fn(controller, obj_id)
self._attr_unique_id: str = description.unique_id_fn(obj_id)
obj = description.object_fn(controller.api, obj_id)
self._is_connected = description.is_connected_fn(controller.api, obj)
self._attr_name = description.name_fn(obj)
@property
def is_connected(self):
"""Return true if the device is connected to the network."""
return self._is_connected
@property
def hostname(self) -> str | None:
"""Return hostname of the device."""
return self.entity_description.hostname_fn(self.controller.api, self._obj_id)
@property
def ip_address(self) -> str:
"""Return the primary ip address of the device."""
return self.entity_description.ip_address_fn(self.controller.api, self._obj_id)
@property
def mac_address(self) -> str:
"""Return the mac address of the device."""
return self._obj_id
@property
def source_type(self) -> SourceType:
"""Return the source type, eg gps or router, of the device."""
return SourceType.ROUTER
@property
def unique_id(self) -> str:
"""Return a unique ID."""
return self._attr_unique_id
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Watch object when added.""" """Register callbacks."""
description = self.entity_description
handler = description.api_handler_fn(self.controller.api)
self.async_on_remove(
handler.subscribe(
self.async_signalling_callback,
)
)
self.async_on_remove( self.async_on_remove(
async_dispatcher_connect( async_dispatcher_connect(
self.hass, self.hass,
f"{self.controller.signal_heartbeat_missed}_{self.unique_id}", self.controller.signal_reachable,
self.async_signal_reachable_callback,
)
)
self.async_on_remove(
async_dispatcher_connect(
self.hass,
self.controller.signal_options_update,
self.options_updated,
)
)
self.async_on_remove(
async_dispatcher_connect(
self.hass,
self.controller.signal_remove,
self.remove_item,
)
)
self.async_on_remove(
async_dispatcher_connect(
self.hass,
f"{self.controller.signal_heartbeat_missed}_{self._obj_id}",
self._make_disconnected, self._make_disconnected,
) )
) )
await super().async_added_to_hass()
async def async_will_remove_from_hass(self) -> None:
"""Disconnect object when removed."""
self.controller.async_heartbeat(self.unique_id)
await super().async_will_remove_from_hass()
@callback @callback
def async_signal_reachable_callback(self) -> None: def _make_disconnected(self, *_):
"""Call when controller connection state change.""" """No heart beat by device."""
self._controller_connection_state_changed = True self._is_connected = False
super().async_signal_reachable_callback() self.async_write_ha_state()
@callback @callback
def async_update_callback(self) -> None: def async_signalling_callback(self, event: ItemEvent, obj_id: str) -> None:
"""Update the devices' state.""" """Update the switch state."""
if event == ItemEvent.DELETED and obj_id == self._obj_id:
self.hass.async_create_task(self.remove_item({self._obj_id}))
return
description = self.entity_description
if not description.supported_fn(self.controller.api, self._obj_id):
self.hass.async_create_task(self.remove_item({self._obj_id}))
return
if self._controller_connection_state_changed: if self._controller_connection_state_changed:
self._controller_connection_state_changed = False self._controller_connection_state_changed = False
@ -363,81 +512,54 @@ class UniFiDeviceTracker(UniFiBase, ScannerEntity):
else: else:
self.controller.async_heartbeat(self.unique_id) self.controller.async_heartbeat(self.unique_id)
else:
elif self.device.last_updated == SOURCE_DATA:
self._is_connected = True self._is_connected = True
self.schedule_update = True self.schedule_update = True
if self.schedule_update: if self.schedule_update:
device = self.entity_description.object_fn(
self.controller.api, self._obj_id
)
self.schedule_update = False self.schedule_update = False
self.controller.async_heartbeat( self.controller.async_heartbeat(
self.unique_id, self.unique_id,
dt_util.utcnow() + timedelta(seconds=self.device.next_interval + 60), dt_util.utcnow() + timedelta(seconds=device.next_interval + 60),
) )
super().async_update_callback() self._attr_available = description.available_fn(self.controller, self._obj_id)
@callback
def _make_disconnected(self, *_):
"""No heart beat by device."""
self._is_connected = False
self.async_write_ha_state() self.async_write_ha_state()
@property @callback
def is_connected(self): def async_signal_reachable_callback(self) -> None:
"""Return true if the device is connected to the network.""" """Call when controller connection state change."""
return self._is_connected self.async_signalling_callback(ItemEvent.ADDED, self._obj_id)
@property @callback
def source_type(self) -> SourceType: def async_event_callback(self, event: Event) -> None:
"""Return the source type of the device.""" """Event subscription callback."""
return SourceType.ROUTER if event.mac != self._obj_id:
return
@property description = self.entity_description
def name(self) -> str: assert isinstance(description.event_to_subscribe, tuple)
"""Return the name of the device.""" assert isinstance(description.event_is_on, tuple)
return self.device.name or self.device.model
@property if event.key in description.event_to_subscribe:
def unique_id(self) -> str: self._is_connected = event.key in description.event_is_on
"""Return a unique identifier for this device.""" self._attr_available = description.available_fn(self.controller, self._obj_id)
return self.device.mac self.async_write_ha_state()
@property
def available(self) -> bool:
"""Return if controller is available."""
return not self.device.disabled and self.controller.available
@property
def extra_state_attributes(self):
"""Return the device state attributes."""
if self.device.state == 0:
return {}
attributes = {}
if self.device.has_fan:
attributes["fan_level"] = self.device.fan_level
if self.device.overheating:
attributes["overheating"] = self.device.overheating
if self.device.upgradable:
attributes["upgradable"] = self.device.upgradable
return attributes
@property
def ip_address(self) -> str:
"""Return the primary ip address of the device."""
return self.device.ip
@property
def mac_address(self) -> str:
"""Return the mac address of the device."""
return self.device.mac
async def options_updated(self) -> None: async def options_updated(self) -> None:
"""Config entry options are updated, remove entity if option is disabled.""" """Config entry options are updated, remove entity if option is disabled."""
if not self.controller.option_track_devices: if not self.entity_description.allowed_fn(self.controller, self._obj_id):
await self.remove_item({self.device.mac}) await self.remove_item({self._obj_id})
async def remove_item(self, keys: set) -> None:
"""Remove entity if object ID is part of set."""
if self._obj_id not in keys or self._removed:
return
self._removed = True
if self.registry_entry:
er.async_get(self.hass).async_remove(self.entity_id)
else:
await self.async_remove(force_remove=True)

View File

@ -607,15 +607,6 @@ async def test_option_track_devices(hass, aioclient_mock, mock_device_registry):
assert hass.states.get("device_tracker.client") assert hass.states.get("device_tracker.client")
assert not hass.states.get("device_tracker.device") assert not hass.states.get("device_tracker.device")
hass.config_entries.async_update_entry(
config_entry,
options={CONF_TRACK_DEVICES: True},
)
await hass.async_block_till_done()
assert hass.states.get("device_tracker.client")
assert hass.states.get("device_tracker.device")
async def test_option_ssid_filter( async def test_option_ssid_filter(
hass, aioclient_mock, mock_unifi_websocket, mock_device_registry hass, aioclient_mock, mock_unifi_websocket, mock_device_registry
@ -1007,7 +998,7 @@ async def test_dont_track_devices(hass, aioclient_mock, mock_device_registry):
"version": "4.0.42.10433", "version": "4.0.42.10433",
} }
config_entry = await setup_unifi_integration( await setup_unifi_integration(
hass, hass,
aioclient_mock, aioclient_mock,
options={CONF_TRACK_DEVICES: False}, options={CONF_TRACK_DEVICES: False},
@ -1019,16 +1010,6 @@ async def test_dont_track_devices(hass, aioclient_mock, mock_device_registry):
assert hass.states.get("device_tracker.client") assert hass.states.get("device_tracker.client")
assert not hass.states.get("device_tracker.device") assert not hass.states.get("device_tracker.device")
hass.config_entries.async_update_entry(
config_entry,
options={CONF_TRACK_DEVICES: True},
)
await hass.async_block_till_done()
assert len(hass.states.async_entity_ids(TRACKER_DOMAIN)) == 2
assert hass.states.get("device_tracker.client")
assert hass.states.get("device_tracker.device")
async def test_dont_track_wired_clients(hass, aioclient_mock, mock_device_registry): async def test_dont_track_wired_clients(hass, aioclient_mock, mock_device_registry):
"""Test don't track wired clients config works.""" """Test don't track wired clients config works."""