diff --git a/homeassistant/components/unifi/device_tracker.py b/homeassistant/components/unifi/device_tracker.py index 2dddfeba304..fe323518bc0 100644 --- a/homeassistant/components/unifi/device_tracker.py +++ b/homeassistant/components/unifi/device_tracker.py @@ -13,19 +13,18 @@ from aiounifi.interfaces.api_handlers import ItemEvent from aiounifi.interfaces.devices import Devices from aiounifi.models.api import SOURCE_DATA, SOURCE_EVENT from aiounifi.models.device import Device -from aiounifi.models.event import Event, EventKey +from aiounifi.models.event import EventKey from homeassistant.components.device_tracker import DOMAIN, ScannerEntity, SourceType from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback -from homeassistant.helpers import entity_registry as er from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.helpers.entity import EntityDescription from homeassistant.helpers.entity_platform import AddEntitiesCallback import homeassistant.util.dt as dt_util from .const import DOMAIN as UNIFI_DOMAIN from .controller import UniFiController +from .entity import UnifiEntity, UnifiEntityDescription from .unifi_client import UniFiClientBase LOGGER = logging.getLogger(__name__) @@ -80,44 +79,50 @@ def async_device_available_fn(controller: UniFiController, obj_id: str) -> bool: return controller.available and not device.disabled -@dataclass -class UnifiEntityLoader(Generic[_HandlerT, _DataT]): - """Validate and load entities from different UniFi handlers.""" +@callback +def async_device_heartbeat_timedelta_fn( + controller: UniFiController, obj_id: str +) -> timedelta: + """Check if device object is disabled.""" + device = controller.api.devices[obj_id] + return timedelta(seconds=device.next_interval + 60) - allowed_fn: Callable[[UniFiController, str], bool] - api_handler_fn: Callable[[aiounifi.Controller], _HandlerT] - available_fn: Callable[[UniFiController, str], bool] - event_is_on: tuple[EventKey, ...] | None - event_to_subscribe: tuple[EventKey, ...] | None - is_connected_fn: Callable[[aiounifi.Controller, _DataT], bool] - name_fn: Callable[[_DataT], str | None] - object_fn: Callable[[aiounifi.Controller, str], _DataT] - supported_fn: Callable[[aiounifi.Controller, str], bool | None] - unique_id_fn: Callable[[str], str] + +@dataclass +class UnifiEntityTrackerDescriptionMixin(Generic[_HandlerT, _DataT]): + """Device tracker local functions.""" + + heartbeat_timedelta_fn: Callable[[UniFiController, str], timedelta] ip_address_fn: Callable[[aiounifi.Controller, str], str] + is_connected_fn: Callable[[UniFiController, str], bool] hostname_fn: Callable[[aiounifi.Controller, str], str | None] @dataclass -class UnifiEntityDescription(EntityDescription, UnifiEntityLoader[_HandlerT, _DataT]): - """Class describing UniFi switch entity.""" +class UnifiTrackerEntityDescription( + UnifiEntityDescription[_HandlerT, _DataT], + UnifiEntityTrackerDescriptionMixin[_HandlerT, _DataT], +): + """Class describing UniFi device tracker entity.""" -ENTITY_DESCRIPTIONS: tuple[UnifiEntityDescription, ...] = ( - UnifiEntityDescription[Devices, Device]( +ENTITY_DESCRIPTIONS: tuple[UnifiTrackerEntityDescription, ...] = ( + UnifiTrackerEntityDescription[Devices, Device]( key="Device scanner", has_entity_name=True, icon="mdi:ethernet", allowed_fn=lambda controller, obj_id: controller.option_track_devices, api_handler_fn=lambda api: api.devices, available_fn=async_device_available_fn, + device_info_fn=lambda api, obj_id: None, event_is_on=None, event_to_subscribe=None, - is_connected_fn=lambda api, device: device.state == 1, + heartbeat_timedelta_fn=async_device_heartbeat_timedelta_fn, + is_connected_fn=lambda ctrlr, obj_id: ctrlr.api.devices[obj_id].state == 1, name_fn=lambda device: device.name or device.model, object_fn=lambda api, obj_id: api.devices[obj_id], - supported_fn=lambda api, obj_id: True, - unique_id_fn=lambda obj_id: obj_id, + supported_fn=lambda controller, obj_id: True, + unique_id_fn=lambda controller, obj_id: obj_id, ip_address_fn=lambda api, obj_id: api.devices[obj_id].ip, hostname_fn=lambda api, obj_id: None, ), @@ -389,36 +394,26 @@ class UniFiClientTracker(UniFiClientBase, ScannerEntity): await self.remove_item({self.client.mac}) -class UnifiScannerEntity(ScannerEntity, Generic[_HandlerT, _DataT]): +class UnifiScannerEntity(UnifiEntity[_HandlerT, _DataT], ScannerEntity): """Representation of a UniFi scanner.""" - entity_description: UnifiEntityDescription[_HandlerT, _DataT] - _attr_should_poll = False + entity_description: UnifiTrackerEntityDescription - def __init__( - self, - obj_id: str, - controller: UniFiController, - description: UnifiEntityDescription[_HandlerT, _DataT], - ) -> None: - """Set up UniFi scanner entity.""" - self._obj_id = obj_id - self.controller = controller - self.entity_description = description + _ignore_events: bool + _is_connected: bool - self._controller_connection_state_changed = False - self._removed = False - self.schedule_update = False + @callback + def async_initiate_state(self) -> None: + """Initiate entity state. - self._attr_available = description.available_fn(controller, obj_id) - self._attr_unique_id: str = description.unique_id_fn(obj_id) - - obj = description.object_fn(controller.api, obj_id) - self._is_connected = description.is_connected_fn(controller.api, obj) - self._attr_name = description.name_fn(obj) + Initiate is_connected. + """ + description = self.entity_description + self._ignore_events = False + self._is_connected = description.is_connected_fn(self.controller, self._obj_id) @property - def is_connected(self): + def is_connected(self) -> bool: """Return true if the device is connected to the network.""" return self._is_connected @@ -447,36 +442,45 @@ class UnifiScannerEntity(ScannerEntity, Generic[_HandlerT, _DataT]): """Return a unique ID.""" return self._attr_unique_id + @callback + def _make_disconnected(self, *_) -> None: + """No heart beat by device.""" + self._is_connected = False + self.async_write_ha_state() + + @callback + def async_update_state(self, event: ItemEvent, obj_id: str) -> None: + """Update entity state. + + Remove heartbeat check if controller state has changed + and entity is unavailable. + Update is_connected. + Schedule new heartbeat check if connected. + """ + description = self.entity_description + + if event == ItemEvent.CHANGED: + # Prioritize normal data updates over events + self._ignore_events = True + + elif event == ItemEvent.ADDED and not self.available: + # From unifi.entity.async_signal_reachable_callback + # Controller connection state has changed and entity is unavailable + # Cancel heartbeat + self.controller.async_heartbeat(self.unique_id) + return + + if is_connected := description.is_connected_fn(self.controller, self._obj_id): + self._is_connected = is_connected + self.controller.async_heartbeat( + self.unique_id, + dt_util.utcnow() + + description.heartbeat_timedelta_fn(self.controller, self._obj_id), + ) + async def async_added_to_hass(self) -> None: """Register callbacks.""" - description = self.entity_description - handler = description.api_handler_fn(self.controller.api) - self.async_on_remove( - handler.subscribe( - self.async_signalling_callback, - ) - ) - self.async_on_remove( - async_dispatcher_connect( - self.hass, - self.controller.signal_reachable, - self.async_signal_reachable_callback, - ) - ) - self.async_on_remove( - async_dispatcher_connect( - self.hass, - self.controller.signal_options_update, - self.options_updated, - ) - ) - self.async_on_remove( - async_dispatcher_connect( - self.hass, - self.controller.signal_remove, - self.remove_item, - ) - ) + await super().async_added_to_hass() self.async_on_remove( async_dispatcher_connect( self.hass, @@ -485,81 +489,7 @@ class UnifiScannerEntity(ScannerEntity, Generic[_HandlerT, _DataT]): ) ) - @callback - def _make_disconnected(self, *_): - """No heart beat by device.""" - self._is_connected = False - self.async_write_ha_state() - - @callback - def async_signalling_callback(self, event: ItemEvent, obj_id: str) -> None: - """Update the switch state.""" - if event == ItemEvent.DELETED and obj_id == self._obj_id: - self.hass.async_create_task(self.remove_item({self._obj_id})) - return - - description = self.entity_description - if not description.supported_fn(self.controller.api, self._obj_id): - self.hass.async_create_task(self.remove_item({self._obj_id})) - return - - if self._controller_connection_state_changed: - self._controller_connection_state_changed = False - - if self.controller.available: - if self._is_connected: - self.schedule_update = True - - else: - self.controller.async_heartbeat(self.unique_id) - else: - self._is_connected = True - self.schedule_update = True - - if self.schedule_update: - device = self.entity_description.object_fn( - self.controller.api, self._obj_id - ) - self.schedule_update = False - self.controller.async_heartbeat( - self.unique_id, - dt_util.utcnow() + timedelta(seconds=device.next_interval + 60), - ) - - self._attr_available = description.available_fn(self.controller, self._obj_id) - self.async_write_ha_state() - - @callback - def async_signal_reachable_callback(self) -> None: - """Call when controller connection state change.""" - self.async_signalling_callback(ItemEvent.ADDED, self._obj_id) - - @callback - def async_event_callback(self, event: Event) -> None: - """Event subscription callback.""" - if event.mac != self._obj_id: - return - - description = self.entity_description - assert isinstance(description.event_to_subscribe, tuple) - assert isinstance(description.event_is_on, tuple) - - if event.key in description.event_to_subscribe: - self._is_connected = event.key in description.event_is_on - self._attr_available = description.available_fn(self.controller, self._obj_id) - self.async_write_ha_state() - - async def options_updated(self) -> None: - """Config entry options are updated, remove entity if option is disabled.""" - if not self.entity_description.allowed_fn(self.controller, self._obj_id): - await self.remove_item({self._obj_id}) - - async def remove_item(self, keys: set) -> None: - """Remove entity if object ID is part of set.""" - if self._obj_id not in keys or self._removed: - return - self._removed = True - if self.registry_entry: - er.async_get(self.hass).async_remove(self.entity_id) - else: - await self.async_remove(force_remove=True) + async def async_will_remove_from_hass(self) -> None: + """Disconnect object when removed.""" + await super().async_will_remove_from_hass() + self.controller.async_heartbeat(self.unique_id) diff --git a/tests/components/unifi/test_device_tracker.py b/tests/components/unifi/test_device_tracker.py index 596ecd46cd3..4aacc239b22 100644 --- a/tests/components/unifi/test_device_tracker.py +++ b/tests/components/unifi/test_device_tracker.py @@ -313,6 +313,7 @@ async def test_tracked_devices( # State change signalling work device_1["next_interval"] = 20 + device_2["state"] = 1 device_2["next_interval"] = 50 mock_unifi_websocket(message=MessageKey.DEVICE, data=[device_1, device_2]) await hass.async_block_till_done()