Rewrite UniFi block client switch (#80969)

* Refactor UniFi block client switch entities

* Use new switch loader

* Rename lambdas

* Use is_on rather than _attr_is_on when applicable
This commit is contained in:
Robert Svensson 2022-10-25 22:36:51 +02:00 committed by GitHub
parent 2af58ad609
commit 115a1ceea0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 124 additions and 80 deletions

View File

@ -12,17 +12,17 @@ from dataclasses import dataclass
from typing import Any, Generic, TypeVar from typing import Any, Generic, TypeVar
from aiounifi.interfaces.api_handlers import ItemEvent from aiounifi.interfaces.api_handlers import ItemEvent
from aiounifi.interfaces.clients import Clients
from aiounifi.interfaces.dpi_restriction_groups import DPIRestrictionGroups from aiounifi.interfaces.dpi_restriction_groups import DPIRestrictionGroups
from aiounifi.interfaces.outlets import Outlets from aiounifi.interfaces.outlets import Outlets
from aiounifi.interfaces.ports import Ports from aiounifi.interfaces.ports import Ports
from aiounifi.models.api import SOURCE_EVENT
from aiounifi.models.client import ClientBlockRequest from aiounifi.models.client import ClientBlockRequest
from aiounifi.models.device import ( from aiounifi.models.device import (
DeviceSetOutletRelayRequest, DeviceSetOutletRelayRequest,
DeviceSetPoePortModeRequest, DeviceSetPoePortModeRequest,
) )
from aiounifi.models.dpi_restriction_app import DPIRestrictionAppEnableRequest from aiounifi.models.dpi_restriction_app import DPIRestrictionAppEnableRequest
from aiounifi.models.event import EventKey from aiounifi.models.event import Event, EventKey
from homeassistant.components.switch import DOMAIN, SwitchDeviceClass, SwitchEntity from homeassistant.components.switch import DOMAIN, SwitchDeviceClass, SwitchEntity
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
@ -58,12 +58,12 @@ T = TypeVar("T")
class UnifiEntityLoader(Generic[T]): class UnifiEntityLoader(Generic[T]):
"""Validate and load entities from different UniFi handlers.""" """Validate and load entities from different UniFi handlers."""
config_option_fn: Callable[[UniFiController], bool] allowed_fn: Callable[[UniFiController, str], bool]
entity_cls: type[UnifiDPIRestrictionSwitch] | type[UnifiOutletSwitch] | type[ entity_cls: type[UnifiBlockClientSwitch] | type[UnifiDPIRestrictionSwitch] | type[
UnifiPoePortSwitch UnifiOutletSwitch
] | type[UnifiDPIRestrictionSwitch] ] | type[UnifiPoePortSwitch] | type[UnifiDPIRestrictionSwitch]
handler_fn: Callable[[UniFiController], T] handler_fn: Callable[[UniFiController], T]
value_fn: Callable[[T, str], bool | None] supported_fn: Callable[[T, str], bool | None]
async def async_setup_entry( async def async_setup_entry(
@ -113,9 +113,6 @@ async def async_setup_entry(
devices: set = controller.api.devices, devices: set = controller.api.devices,
) -> None: ) -> None:
"""Update the values of the controller.""" """Update the values of the controller."""
if controller.option_block_clients:
add_block_entities(controller, async_add_entities, clients)
if controller.option_poe_clients: if controller.option_poe_clients:
add_poe_entities(controller, async_add_entities, clients, known_poe_clients) add_poe_entities(controller, async_add_entities, clients, known_poe_clients)
@ -136,14 +133,14 @@ async def async_setup_entry(
@callback @callback
def async_create_entity(event: ItemEvent, obj_id: str) -> None: def async_create_entity(event: ItemEvent, obj_id: str) -> None:
"""Create UniFi entity.""" """Create UniFi entity."""
if not loader.config_option_fn(controller) or not loader.value_fn( if not loader.allowed_fn(controller, obj_id) or not loader.supported_fn(
api_handler, obj_id api_handler, obj_id
): ):
return return
entity = loader.entity_cls(obj_id, controller) entity = loader.entity_cls(obj_id, controller)
if event == ItemEvent.ADDED: if event == ItemEvent.ADDED:
async_add_entities(entities) async_add_entities([entity])
return return
entities.append(entity) entities.append(entity)
@ -157,21 +154,6 @@ async def async_setup_entry(
async_load_entities(unifi_loader) async_load_entities(unifi_loader)
@callback
def add_block_entities(controller, async_add_entities, clients):
"""Add new switch entities from the controller."""
switches = []
for mac in controller.option_block_clients:
if mac in controller.entities[DOMAIN][BLOCK_SWITCH] or mac not in clients:
continue
client = controller.api.clients[mac]
switches.append(UniFiBlockClientSwitch(client, controller))
async_add_entities(switches)
@callback @callback
def add_poe_entities(controller, async_add_entities, clients, known_poe_clients): def add_poe_entities(controller, async_add_entities, clients, known_poe_clients):
"""Add new switch entities from the controller.""" """Add new switch entities from the controller."""
@ -319,59 +301,123 @@ class UniFiPOEClientSwitch(UniFiClient, SwitchEntity, RestoreEntity):
await self.remove_item({self.client.mac}) await self.remove_item({self.client.mac})
class UniFiBlockClientSwitch(UniFiClient, SwitchEntity): class UnifiBlockClientSwitch(SwitchEntity):
"""Representation of a blockable client.""" """Representation of a blockable client."""
DOMAIN = DOMAIN _attr_device_class = SwitchDeviceClass.SWITCH
TYPE = BLOCK_SWITCH
_attr_entity_category = EntityCategory.CONFIG _attr_entity_category = EntityCategory.CONFIG
_attr_has_entity_name = True
_attr_icon = "mdi:ethernet"
_attr_should_poll = False
def __init__(self, client, controller): def __init__(self, obj_id: str, controller: UniFiController) -> None:
"""Set up block switch.""" """Set up block switch."""
super().__init__(client, controller) controller.entities[DOMAIN][BLOCK_SWITCH].add(obj_id)
self._obj_id = obj_id
self.controller = controller
self._is_blocked = client.blocked self._removed = False
client = controller.api.clients[obj_id]
self._attr_available = controller.available
self._attr_is_on = not client.blocked
self._attr_unique_id = f"{BLOCK_SWITCH}-{obj_id}"
self._attr_device_info = DeviceInfo(
connections={(CONNECTION_NETWORK_MAC, obj_id)},
default_manufacturer=client.oui,
default_name=client.name or client.hostname,
)
async def async_added_to_hass(self) -> None:
"""Entity created."""
self.async_on_remove(
self.controller.api.clients.subscribe(self.async_signalling_callback)
)
self.async_on_remove(
self.controller.api.events.subscribe(
self.async_event_callback, CLIENT_BLOCKED + CLIENT_UNBLOCKED
)
)
self.async_on_remove(
async_dispatcher_connect(
self.hass, self.controller.signal_remove, self.remove_item
)
)
self.async_on_remove(
async_dispatcher_connect(
self.hass, self.controller.signal_options_update, self.options_updated
)
)
self.async_on_remove(
async_dispatcher_connect(
self.hass,
self.controller.signal_reachable,
self.async_signal_reachable_callback,
)
)
async def async_will_remove_from_hass(self) -> None:
"""Disconnect object when removed."""
self.controller.entities[DOMAIN][BLOCK_SWITCH].remove(self._obj_id)
@callback @callback
def async_update_callback(self) -> None: def async_signalling_callback(self, event: ItemEvent, obj_id: str) -> None:
"""Update the clients state.""" """Update the clients state."""
if ( if event == ItemEvent.DELETED:
self.client.last_updated == SOURCE_EVENT self.hass.async_create_task(self.remove_item({self._obj_id}))
and self.client.event.key in CLIENT_BLOCKED + CLIENT_UNBLOCKED return
):
self._is_blocked = self.client.event.key in CLIENT_BLOCKED
super().async_update_callback() client = self.controller.api.clients[self._obj_id]
self._attr_is_on = not client.blocked
self._attr_available = self.controller.available
self.async_write_ha_state()
@property @callback
def is_on(self): def async_event_callback(self, event: Event) -> None:
"""Return true if client is allowed to connect.""" """Event subscription callback."""
return not self._is_blocked if event.mac != self._obj_id:
return
if event.key in CLIENT_BLOCKED + CLIENT_UNBLOCKED:
self._attr_is_on = event.key in CLIENT_UNBLOCKED
self._attr_available = self.controller.available
self.async_write_ha_state()
@callback
def async_signal_reachable_callback(self) -> None:
"""Call when controller connection state change."""
self.async_signalling_callback(ItemEvent.ADDED, self._obj_id)
async def async_turn_on(self, **kwargs: Any) -> None: async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn on connectivity for client.""" """Turn on connectivity for client."""
await self.controller.api.request( await self.controller.api.request(
ClientBlockRequest.create(self.client.mac, False) ClientBlockRequest.create(self._obj_id, False)
) )
async def async_turn_off(self, **kwargs: Any) -> None: async def async_turn_off(self, **kwargs: Any) -> None:
"""Turn off connectivity for client.""" """Turn off connectivity for client."""
await self.controller.api.request( await self.controller.api.request(ClientBlockRequest.create(self._obj_id, True))
ClientBlockRequest.create(self.client.mac, True)
)
@property @property
def icon(self) -> str: def icon(self) -> str:
"""Return the icon to use in the frontend.""" """Return the icon to use in the frontend."""
if self._is_blocked: if not self.is_on:
return "mdi:network-off" return "mdi:network-off"
return "mdi:network" return "mdi:network"
async def options_updated(self) -> None: async def options_updated(self) -> None:
"""Config entry options are updated, remove entity if option is disabled.""" """Config entry options are updated, remove entity if option is disabled."""
if self.client.mac not in self.controller.option_block_clients: if self._obj_id not in self.controller.option_block_clients:
await self.remove_item({self.client.mac}) await self.remove_item({self._obj_id})
async def remove_item(self, keys: set) -> None:
"""Remove entity if key is part of set."""
if self._obj_id not in keys or self._removed:
return
self._removed = True
if self.registry_entry:
er.async_get(self.hass).async_remove(self.entity_id)
else:
await self.async_remove(force_remove=True)
class UnifiDPIRestrictionSwitch(SwitchEntity): class UnifiDPIRestrictionSwitch(SwitchEntity):
@ -379,7 +425,7 @@ class UnifiDPIRestrictionSwitch(SwitchEntity):
_attr_entity_category = EntityCategory.CONFIG _attr_entity_category = EntityCategory.CONFIG
def __init__(self, obj_id: str, controller): def __init__(self, obj_id: str, controller: UniFiController) -> None:
"""Set up dpi switch.""" """Set up dpi switch."""
controller.entities[DOMAIN][DPI_SWITCH].add(obj_id) controller.entities[DOMAIN][DPI_SWITCH].add(obj_id)
self._obj_id = obj_id self._obj_id = obj_id
@ -456,7 +502,7 @@ class UnifiDPIRestrictionSwitch(SwitchEntity):
@property @property
def icon(self): def icon(self):
"""Return the icon to use in the frontend.""" """Return the icon to use in the frontend."""
if self._attr_is_on: if self.is_on:
return "mdi:network" return "mdi:network"
return "mdi:network-off" return "mdi:network-off"
@ -516,7 +562,7 @@ class UnifiOutletSwitch(SwitchEntity):
_attr_has_entity_name = True _attr_has_entity_name = True
_attr_should_poll = False _attr_should_poll = False
def __init__(self, obj_id: str, controller) -> None: def __init__(self, obj_id: str, controller: UniFiController) -> None:
"""Set up UniFi Network entity base.""" """Set up UniFi Network entity base."""
self._device_mac, index = obj_id.split("_", 1) self._device_mac, index = obj_id.split("_", 1)
self._index = int(index) self._index = int(index)
@ -591,7 +637,7 @@ class UnifiPoePortSwitch(SwitchEntity):
_attr_icon = "mdi:ethernet" _attr_icon = "mdi:ethernet"
_attr_should_poll = False _attr_should_poll = False
def __init__(self, obj_id: str, controller) -> None: def __init__(self, obj_id: str, controller: UniFiController) -> None:
"""Set up UniFi Network entity base.""" """Set up UniFi Network entity base."""
self._device_mac, index = obj_id.split("_", 1) self._device_mac, index = obj_id.split("_", 1)
self._index = int(index) self._index = int(index)
@ -657,22 +703,28 @@ class UnifiPoePortSwitch(SwitchEntity):
UNIFI_LOADERS: tuple[UnifiEntityLoader, ...] = ( UNIFI_LOADERS: tuple[UnifiEntityLoader, ...] = (
UnifiEntityLoader[Clients](
allowed_fn=lambda controller, obj_id: obj_id in controller.option_block_clients,
entity_cls=UnifiBlockClientSwitch,
handler_fn=lambda contrlr: contrlr.api.clients,
supported_fn=lambda handler, obj_id: True,
),
UnifiEntityLoader[DPIRestrictionGroups]( UnifiEntityLoader[DPIRestrictionGroups](
config_option_fn=lambda controller: controller.option_dpi_restrictions, allowed_fn=lambda controller, obj_id: controller.option_dpi_restrictions,
entity_cls=UnifiDPIRestrictionSwitch, entity_cls=UnifiDPIRestrictionSwitch,
handler_fn=lambda controller: controller.api.dpi_groups, handler_fn=lambda controller: controller.api.dpi_groups,
value_fn=lambda handler, index: bool(handler[index].dpiapp_ids), supported_fn=lambda handler, obj_id: bool(handler[obj_id].dpiapp_ids),
), ),
UnifiEntityLoader[Outlets]( UnifiEntityLoader[Outlets](
config_option_fn=lambda controller: True, allowed_fn=lambda controller, obj_id: True,
entity_cls=UnifiOutletSwitch, entity_cls=UnifiOutletSwitch,
handler_fn=lambda controller: controller.api.outlets, handler_fn=lambda controller: controller.api.outlets,
value_fn=lambda handler, index: handler[index].has_relay, supported_fn=lambda handler, obj_id: handler[obj_id].has_relay,
), ),
UnifiEntityLoader[Ports]( UnifiEntityLoader[Ports](
config_option_fn=lambda controller: True, allowed_fn=lambda controller, obj_id: True,
entity_cls=UnifiPoePortSwitch, entity_cls=UnifiPoePortSwitch,
handler_fn=lambda controller: controller.api.ports, handler_fn=lambda controller: controller.api.ports,
value_fn=lambda handler, index: handler[index].port_poe, supported_fn=lambda handler, obj_id: handler[obj_id].port_poe,
), ),
) )

View File

@ -54,7 +54,7 @@ CLIENT_1 = {
"mac": "00:00:00:00:00:01", "mac": "00:00:00:00:00:01",
"name": "POE Client 1", "name": "POE Client 1",
"oui": "Producer", "oui": "Producer",
"sw_mac": "00:00:00:00:01:01", "sw_mac": "10:00:00:00:01:01",
"sw_port": 1, "sw_port": 1,
"wired-rx_bytes": 1234000000, "wired-rx_bytes": 1234000000,
"wired-tx_bytes": 5678000000, "wired-tx_bytes": 5678000000,
@ -67,7 +67,7 @@ CLIENT_2 = {
"mac": "00:00:00:00:00:02", "mac": "00:00:00:00:00:02",
"name": "POE Client 2", "name": "POE Client 2",
"oui": "Producer", "oui": "Producer",
"sw_mac": "00:00:00:00:01:01", "sw_mac": "10:00:00:00:01:01",
"sw_port": 2, "sw_port": 2,
"wired-rx_bytes": 1234000000, "wired-rx_bytes": 1234000000,
"wired-tx_bytes": 5678000000, "wired-tx_bytes": 5678000000,
@ -80,7 +80,7 @@ CLIENT_3 = {
"mac": "00:00:00:00:00:03", "mac": "00:00:00:00:00:03",
"name": "Non-POE Client 3", "name": "Non-POE Client 3",
"oui": "Producer", "oui": "Producer",
"sw_mac": "00:00:00:00:01:01", "sw_mac": "10:00:00:00:01:01",
"sw_port": 3, "sw_port": 3,
"wired-rx_bytes": 1234000000, "wired-rx_bytes": 1234000000,
"wired-tx_bytes": 5678000000, "wired-tx_bytes": 5678000000,
@ -93,7 +93,7 @@ CLIENT_4 = {
"mac": "00:00:00:00:00:04", "mac": "00:00:00:00:00:04",
"name": "Non-POE Client 4", "name": "Non-POE Client 4",
"oui": "Producer", "oui": "Producer",
"sw_mac": "00:00:00:00:01:01", "sw_mac": "10:00:00:00:01:01",
"sw_port": 4, "sw_port": 4,
"wired-rx_bytes": 1234000000, "wired-rx_bytes": 1234000000,
"wired-tx_bytes": 5678000000, "wired-tx_bytes": 5678000000,
@ -107,7 +107,7 @@ POE_SWITCH_CLIENTS = [
"mac": "00:00:00:00:00:01", "mac": "00:00:00:00:00:01",
"name": "POE Client 1", "name": "POE Client 1",
"oui": "Producer", "oui": "Producer",
"sw_mac": "00:00:00:00:01:01", "sw_mac": "10:00:00:00:01:01",
"sw_port": 1, "sw_port": 1,
"wired-rx_bytes": 1234000000, "wired-rx_bytes": 1234000000,
"wired-tx_bytes": 5678000000, "wired-tx_bytes": 5678000000,
@ -120,7 +120,7 @@ POE_SWITCH_CLIENTS = [
"mac": "00:00:00:00:00:02", "mac": "00:00:00:00:00:02",
"name": "POE Client 2", "name": "POE Client 2",
"oui": "Producer", "oui": "Producer",
"sw_mac": "00:00:00:00:01:01", "sw_mac": "10:00:00:00:01:01",
"sw_port": 1, "sw_port": 1,
"wired-rx_bytes": 1234000000, "wired-rx_bytes": 1234000000,
"wired-tx_bytes": 5678000000, "wired-tx_bytes": 5678000000,
@ -131,7 +131,7 @@ DEVICE_1 = {
"board_rev": 2, "board_rev": 2,
"device_id": "mock-id", "device_id": "mock-id",
"ip": "10.0.1.1", "ip": "10.0.1.1",
"mac": "00:00:00:00:01:01", "mac": "10:00:00:00:01:01",
"last_seen": 1562600145, "last_seen": 1562600145,
"model": "US16P150", "model": "US16P150",
"name": "mock-name", "name": "mock-name",
@ -650,7 +650,7 @@ async def test_switches(hass, aioclient_mock):
assert switch_1 is not None assert switch_1 is not None
assert switch_1.state == "on" assert switch_1.state == "on"
assert switch_1.attributes["power"] == "2.56" assert switch_1.attributes["power"] == "2.56"
assert switch_1.attributes[SWITCH_DOMAIN] == "00:00:00:00:01:01" assert switch_1.attributes[SWITCH_DOMAIN] == "10:00:00:00:01:01"
assert switch_1.attributes["port"] == 1 assert switch_1.attributes["port"] == 1
assert switch_1.attributes["poe_mode"] == "auto" assert switch_1.attributes["poe_mode"] == "auto"
@ -1027,21 +1027,13 @@ async def test_new_client_discovered_on_block_control(
) )
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0
assert hass.states.get("switch.block_client_1") is None
blocked = hass.states.get("switch.block_client_1")
assert blocked is None
mock_unifi_websocket(message=MessageKey.CLIENT, data=BLOCKED) mock_unifi_websocket(message=MessageKey.CLIENT, data=BLOCKED)
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0
mock_unifi_websocket(message=MessageKey.EVENT, data=EVENT_BLOCKED_CLIENT_CONNECTED)
await hass.async_block_till_done()
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 1 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 1
blocked = hass.states.get("switch.block_client_1") assert hass.states.get("switch.block_client_1") is not None
assert blocked is not None
async def test_option_block_clients(hass, aioclient_mock): async def test_option_block_clients(hass, aioclient_mock):