From 115a1ceea0c568ffb466e03b4ab322f253ccf151 Mon Sep 17 00:00:00 2001 From: Robert Svensson Date: Tue, 25 Oct 2022 22:36:51 +0200 Subject: [PATCH] Rewrite UniFi block client switch (#80969) * Refactor UniFi block client switch entities * Use new switch loader * Rename lambdas * Use is_on rather than _attr_is_on when applicable --- homeassistant/components/unifi/switch.py | 176 +++++++++++++++-------- tests/components/unifi/test_switch.py | 28 ++-- 2 files changed, 124 insertions(+), 80 deletions(-) diff --git a/homeassistant/components/unifi/switch.py b/homeassistant/components/unifi/switch.py index fbfb6b7335a..65d0041187e 100644 --- a/homeassistant/components/unifi/switch.py +++ b/homeassistant/components/unifi/switch.py @@ -12,17 +12,17 @@ from dataclasses import dataclass from typing import Any, Generic, TypeVar from aiounifi.interfaces.api_handlers import ItemEvent +from aiounifi.interfaces.clients import Clients from aiounifi.interfaces.dpi_restriction_groups import DPIRestrictionGroups from aiounifi.interfaces.outlets import Outlets from aiounifi.interfaces.ports import Ports -from aiounifi.models.api import SOURCE_EVENT from aiounifi.models.client import ClientBlockRequest from aiounifi.models.device import ( DeviceSetOutletRelayRequest, DeviceSetPoePortModeRequest, ) from aiounifi.models.dpi_restriction_app import DPIRestrictionAppEnableRequest -from aiounifi.models.event import EventKey +from aiounifi.models.event import Event, EventKey from homeassistant.components.switch import DOMAIN, SwitchDeviceClass, SwitchEntity from homeassistant.config_entries import ConfigEntry @@ -58,12 +58,12 @@ T = TypeVar("T") class UnifiEntityLoader(Generic[T]): """Validate and load entities from different UniFi handlers.""" - config_option_fn: Callable[[UniFiController], bool] - entity_cls: type[UnifiDPIRestrictionSwitch] | type[UnifiOutletSwitch] | type[ - UnifiPoePortSwitch - ] | type[UnifiDPIRestrictionSwitch] + allowed_fn: Callable[[UniFiController, str], bool] + entity_cls: type[UnifiBlockClientSwitch] | type[UnifiDPIRestrictionSwitch] | type[ + UnifiOutletSwitch + ] | type[UnifiPoePortSwitch] | type[UnifiDPIRestrictionSwitch] handler_fn: Callable[[UniFiController], T] - value_fn: Callable[[T, str], bool | None] + supported_fn: Callable[[T, str], bool | None] async def async_setup_entry( @@ -113,9 +113,6 @@ async def async_setup_entry( devices: set = controller.api.devices, ) -> None: """Update the values of the controller.""" - if controller.option_block_clients: - add_block_entities(controller, async_add_entities, clients) - if controller.option_poe_clients: add_poe_entities(controller, async_add_entities, clients, known_poe_clients) @@ -136,14 +133,14 @@ async def async_setup_entry( @callback def async_create_entity(event: ItemEvent, obj_id: str) -> None: """Create UniFi entity.""" - if not loader.config_option_fn(controller) or not loader.value_fn( + if not loader.allowed_fn(controller, obj_id) or not loader.supported_fn( api_handler, obj_id ): return entity = loader.entity_cls(obj_id, controller) if event == ItemEvent.ADDED: - async_add_entities(entities) + async_add_entities([entity]) return entities.append(entity) @@ -157,21 +154,6 @@ async def async_setup_entry( async_load_entities(unifi_loader) -@callback -def add_block_entities(controller, async_add_entities, clients): - """Add new switch entities from the controller.""" - switches = [] - - for mac in controller.option_block_clients: - if mac in controller.entities[DOMAIN][BLOCK_SWITCH] or mac not in clients: - continue - - client = controller.api.clients[mac] - switches.append(UniFiBlockClientSwitch(client, controller)) - - async_add_entities(switches) - - @callback def add_poe_entities(controller, async_add_entities, clients, known_poe_clients): """Add new switch entities from the controller.""" @@ -319,59 +301,123 @@ class UniFiPOEClientSwitch(UniFiClient, SwitchEntity, RestoreEntity): await self.remove_item({self.client.mac}) -class UniFiBlockClientSwitch(UniFiClient, SwitchEntity): +class UnifiBlockClientSwitch(SwitchEntity): """Representation of a blockable client.""" - DOMAIN = DOMAIN - TYPE = BLOCK_SWITCH - + _attr_device_class = SwitchDeviceClass.SWITCH _attr_entity_category = EntityCategory.CONFIG + _attr_has_entity_name = True + _attr_icon = "mdi:ethernet" + _attr_should_poll = False - def __init__(self, client, controller): + def __init__(self, obj_id: str, controller: UniFiController) -> None: """Set up block switch.""" - super().__init__(client, controller) + controller.entities[DOMAIN][BLOCK_SWITCH].add(obj_id) + self._obj_id = obj_id + self.controller = controller - self._is_blocked = client.blocked + self._removed = False + + client = controller.api.clients[obj_id] + self._attr_available = controller.available + self._attr_is_on = not client.blocked + self._attr_unique_id = f"{BLOCK_SWITCH}-{obj_id}" + self._attr_device_info = DeviceInfo( + connections={(CONNECTION_NETWORK_MAC, obj_id)}, + default_manufacturer=client.oui, + default_name=client.name or client.hostname, + ) + + async def async_added_to_hass(self) -> None: + """Entity created.""" + self.async_on_remove( + self.controller.api.clients.subscribe(self.async_signalling_callback) + ) + self.async_on_remove( + self.controller.api.events.subscribe( + self.async_event_callback, CLIENT_BLOCKED + CLIENT_UNBLOCKED + ) + ) + self.async_on_remove( + async_dispatcher_connect( + self.hass, self.controller.signal_remove, self.remove_item + ) + ) + self.async_on_remove( + async_dispatcher_connect( + self.hass, self.controller.signal_options_update, self.options_updated + ) + ) + self.async_on_remove( + async_dispatcher_connect( + self.hass, + self.controller.signal_reachable, + self.async_signal_reachable_callback, + ) + ) + + async def async_will_remove_from_hass(self) -> None: + """Disconnect object when removed.""" + self.controller.entities[DOMAIN][BLOCK_SWITCH].remove(self._obj_id) @callback - def async_update_callback(self) -> None: + def async_signalling_callback(self, event: ItemEvent, obj_id: str) -> None: """Update the clients state.""" - if ( - self.client.last_updated == SOURCE_EVENT - and self.client.event.key in CLIENT_BLOCKED + CLIENT_UNBLOCKED - ): - self._is_blocked = self.client.event.key in CLIENT_BLOCKED + if event == ItemEvent.DELETED: + self.hass.async_create_task(self.remove_item({self._obj_id})) + return - super().async_update_callback() + client = self.controller.api.clients[self._obj_id] + self._attr_is_on = not client.blocked + self._attr_available = self.controller.available + self.async_write_ha_state() - @property - def is_on(self): - """Return true if client is allowed to connect.""" - return not self._is_blocked + @callback + def async_event_callback(self, event: Event) -> None: + """Event subscription callback.""" + if event.mac != self._obj_id: + return + if event.key in CLIENT_BLOCKED + CLIENT_UNBLOCKED: + self._attr_is_on = event.key in CLIENT_UNBLOCKED + self._attr_available = self.controller.available + self.async_write_ha_state() + + @callback + def async_signal_reachable_callback(self) -> None: + """Call when controller connection state change.""" + self.async_signalling_callback(ItemEvent.ADDED, self._obj_id) async def async_turn_on(self, **kwargs: Any) -> None: """Turn on connectivity for client.""" await self.controller.api.request( - ClientBlockRequest.create(self.client.mac, False) + ClientBlockRequest.create(self._obj_id, False) ) async def async_turn_off(self, **kwargs: Any) -> None: """Turn off connectivity for client.""" - await self.controller.api.request( - ClientBlockRequest.create(self.client.mac, True) - ) + await self.controller.api.request(ClientBlockRequest.create(self._obj_id, True)) @property def icon(self) -> str: """Return the icon to use in the frontend.""" - if self._is_blocked: + if not self.is_on: return "mdi:network-off" return "mdi:network" async def options_updated(self) -> None: """Config entry options are updated, remove entity if option is disabled.""" - if self.client.mac not in self.controller.option_block_clients: - await self.remove_item({self.client.mac}) + if self._obj_id not in self.controller.option_block_clients: + await self.remove_item({self._obj_id}) + + async def remove_item(self, keys: set) -> None: + """Remove entity if key is part of set.""" + if self._obj_id not in keys or self._removed: + return + self._removed = True + if self.registry_entry: + er.async_get(self.hass).async_remove(self.entity_id) + else: + await self.async_remove(force_remove=True) class UnifiDPIRestrictionSwitch(SwitchEntity): @@ -379,7 +425,7 @@ class UnifiDPIRestrictionSwitch(SwitchEntity): _attr_entity_category = EntityCategory.CONFIG - def __init__(self, obj_id: str, controller): + def __init__(self, obj_id: str, controller: UniFiController) -> None: """Set up dpi switch.""" controller.entities[DOMAIN][DPI_SWITCH].add(obj_id) self._obj_id = obj_id @@ -456,7 +502,7 @@ class UnifiDPIRestrictionSwitch(SwitchEntity): @property def icon(self): """Return the icon to use in the frontend.""" - if self._attr_is_on: + if self.is_on: return "mdi:network" return "mdi:network-off" @@ -516,7 +562,7 @@ class UnifiOutletSwitch(SwitchEntity): _attr_has_entity_name = True _attr_should_poll = False - def __init__(self, obj_id: str, controller) -> None: + def __init__(self, obj_id: str, controller: UniFiController) -> None: """Set up UniFi Network entity base.""" self._device_mac, index = obj_id.split("_", 1) self._index = int(index) @@ -591,7 +637,7 @@ class UnifiPoePortSwitch(SwitchEntity): _attr_icon = "mdi:ethernet" _attr_should_poll = False - def __init__(self, obj_id: str, controller) -> None: + def __init__(self, obj_id: str, controller: UniFiController) -> None: """Set up UniFi Network entity base.""" self._device_mac, index = obj_id.split("_", 1) self._index = int(index) @@ -657,22 +703,28 @@ class UnifiPoePortSwitch(SwitchEntity): UNIFI_LOADERS: tuple[UnifiEntityLoader, ...] = ( + UnifiEntityLoader[Clients]( + allowed_fn=lambda controller, obj_id: obj_id in controller.option_block_clients, + entity_cls=UnifiBlockClientSwitch, + handler_fn=lambda contrlr: contrlr.api.clients, + supported_fn=lambda handler, obj_id: True, + ), UnifiEntityLoader[DPIRestrictionGroups]( - config_option_fn=lambda controller: controller.option_dpi_restrictions, + allowed_fn=lambda controller, obj_id: controller.option_dpi_restrictions, entity_cls=UnifiDPIRestrictionSwitch, handler_fn=lambda controller: controller.api.dpi_groups, - value_fn=lambda handler, index: bool(handler[index].dpiapp_ids), + supported_fn=lambda handler, obj_id: bool(handler[obj_id].dpiapp_ids), ), UnifiEntityLoader[Outlets]( - config_option_fn=lambda controller: True, + allowed_fn=lambda controller, obj_id: True, entity_cls=UnifiOutletSwitch, handler_fn=lambda controller: controller.api.outlets, - value_fn=lambda handler, index: handler[index].has_relay, + supported_fn=lambda handler, obj_id: handler[obj_id].has_relay, ), UnifiEntityLoader[Ports]( - config_option_fn=lambda controller: True, + allowed_fn=lambda controller, obj_id: True, entity_cls=UnifiPoePortSwitch, handler_fn=lambda controller: controller.api.ports, - value_fn=lambda handler, index: handler[index].port_poe, + supported_fn=lambda handler, obj_id: handler[obj_id].port_poe, ), ) diff --git a/tests/components/unifi/test_switch.py b/tests/components/unifi/test_switch.py index db0b358179c..e6357b03172 100644 --- a/tests/components/unifi/test_switch.py +++ b/tests/components/unifi/test_switch.py @@ -54,7 +54,7 @@ CLIENT_1 = { "mac": "00:00:00:00:00:01", "name": "POE Client 1", "oui": "Producer", - "sw_mac": "00:00:00:00:01:01", + "sw_mac": "10:00:00:00:01:01", "sw_port": 1, "wired-rx_bytes": 1234000000, "wired-tx_bytes": 5678000000, @@ -67,7 +67,7 @@ CLIENT_2 = { "mac": "00:00:00:00:00:02", "name": "POE Client 2", "oui": "Producer", - "sw_mac": "00:00:00:00:01:01", + "sw_mac": "10:00:00:00:01:01", "sw_port": 2, "wired-rx_bytes": 1234000000, "wired-tx_bytes": 5678000000, @@ -80,7 +80,7 @@ CLIENT_3 = { "mac": "00:00:00:00:00:03", "name": "Non-POE Client 3", "oui": "Producer", - "sw_mac": "00:00:00:00:01:01", + "sw_mac": "10:00:00:00:01:01", "sw_port": 3, "wired-rx_bytes": 1234000000, "wired-tx_bytes": 5678000000, @@ -93,7 +93,7 @@ CLIENT_4 = { "mac": "00:00:00:00:00:04", "name": "Non-POE Client 4", "oui": "Producer", - "sw_mac": "00:00:00:00:01:01", + "sw_mac": "10:00:00:00:01:01", "sw_port": 4, "wired-rx_bytes": 1234000000, "wired-tx_bytes": 5678000000, @@ -107,7 +107,7 @@ POE_SWITCH_CLIENTS = [ "mac": "00:00:00:00:00:01", "name": "POE Client 1", "oui": "Producer", - "sw_mac": "00:00:00:00:01:01", + "sw_mac": "10:00:00:00:01:01", "sw_port": 1, "wired-rx_bytes": 1234000000, "wired-tx_bytes": 5678000000, @@ -120,7 +120,7 @@ POE_SWITCH_CLIENTS = [ "mac": "00:00:00:00:00:02", "name": "POE Client 2", "oui": "Producer", - "sw_mac": "00:00:00:00:01:01", + "sw_mac": "10:00:00:00:01:01", "sw_port": 1, "wired-rx_bytes": 1234000000, "wired-tx_bytes": 5678000000, @@ -131,7 +131,7 @@ DEVICE_1 = { "board_rev": 2, "device_id": "mock-id", "ip": "10.0.1.1", - "mac": "00:00:00:00:01:01", + "mac": "10:00:00:00:01:01", "last_seen": 1562600145, "model": "US16P150", "name": "mock-name", @@ -650,7 +650,7 @@ async def test_switches(hass, aioclient_mock): assert switch_1 is not None assert switch_1.state == "on" assert switch_1.attributes["power"] == "2.56" - assert switch_1.attributes[SWITCH_DOMAIN] == "00:00:00:00:01:01" + assert switch_1.attributes[SWITCH_DOMAIN] == "10:00:00:00:01:01" assert switch_1.attributes["port"] == 1 assert switch_1.attributes["poe_mode"] == "auto" @@ -1027,21 +1027,13 @@ async def test_new_client_discovered_on_block_control( ) assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0 - - blocked = hass.states.get("switch.block_client_1") - assert blocked is None + assert hass.states.get("switch.block_client_1") is None mock_unifi_websocket(message=MessageKey.CLIENT, data=BLOCKED) await hass.async_block_till_done() - assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0 - - mock_unifi_websocket(message=MessageKey.EVENT, data=EVENT_BLOCKED_CLIENT_CONNECTED) - await hass.async_block_till_done() - assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 1 - blocked = hass.states.get("switch.block_client_1") - assert blocked is not None + assert hass.states.get("switch.block_client_1") is not None async def test_option_block_clients(hass, aioclient_mock):