diff --git a/homeassistant/components/unifi/button.py b/homeassistant/components/unifi/button.py index 0235f6156cc..7471675123a 100644 --- a/homeassistant/components/unifi/button.py +++ b/homeassistant/components/unifi/button.py @@ -24,7 +24,6 @@ from homeassistant.const import EntityCategory from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_platform import AddEntitiesCallback -from .const import DOMAIN as UNIFI_DOMAIN from .controller import UniFiController from .entity import ( HandlerT, @@ -87,13 +86,13 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up button platform for UniFi Network integration.""" - controller: UniFiController = hass.data[UNIFI_DOMAIN][config_entry.entry_id] - - if not controller.is_admin: - return - - controller.register_platform_add_entities( - UnifiButtonEntity, ENTITY_DESCRIPTIONS, async_add_entities + UniFiController.register_platform( + hass, + config_entry, + async_add_entities, + UnifiButtonEntity, + ENTITY_DESCRIPTIONS, + requires_admin=True, ) diff --git a/homeassistant/components/unifi/controller.py b/homeassistant/components/unifi/controller.py index ba188f80135..9f965b424ff 100644 --- a/homeassistant/components/unifi/controller.py +++ b/homeassistant/components/unifi/controller.py @@ -21,14 +21,9 @@ from homeassistant.const import ( CONF_PORT, CONF_USERNAME, CONF_VERIFY_SSL, - Platform, ) from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback -from homeassistant.helpers import ( - aiohttp_client, - device_registry as dr, - entity_registry as er, -) +from homeassistant.helpers import aiohttp_client, device_registry as dr from homeassistant.helpers.device_registry import ( DeviceEntry, DeviceEntryType, @@ -39,13 +34,11 @@ from homeassistant.helpers.dispatcher import ( async_dispatcher_send, ) from homeassistant.helpers.entity_platform import AddEntitiesCallback -from homeassistant.helpers.entity_registry import async_entries_for_config_entry from homeassistant.helpers.event import async_call_later, async_track_time_interval import homeassistant.util.dt as dt_util from .const import ( ATTR_MANUFACTURER, - BLOCK_SWITCH, CONF_ALLOW_BANDWIDTH_SENSORS, CONF_ALLOW_UPTIME_SENSORS, CONF_BLOCK_CLIENT, @@ -162,6 +155,24 @@ class UniFiController: host: str = self.config_entry.data[CONF_HOST] return host + @callback + @staticmethod + def register_platform( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, + entity_class: type[UnifiEntity], + descriptions: tuple[UnifiEntityDescription, ...], + requires_admin: bool = False, + ) -> None: + """Register platform for UniFi entity management.""" + controller: UniFiController = hass.data[UNIFI_DOMAIN][config_entry.entry_id] + if requires_admin and not controller.is_admin: + return + controller.register_platform_add_entities( + entity_class, descriptions, async_add_entities + ) + @callback def register_platform_add_entities( self, @@ -251,30 +262,9 @@ class UniFiController: assert self.config_entry.unique_id is not None self.is_admin = self.api.sites[self.config_entry.unique_id].role == "admin" - # Restore clients that are not a part of active clients list. - entity_registry = er.async_get(self.hass) - for entry in async_entries_for_config_entry( - entity_registry, self.config_entry.entry_id - ): - if entry.domain == Platform.DEVICE_TRACKER: - mac = entry.unique_id.split("-", 1)[0] - elif entry.domain == Platform.SWITCH and entry.unique_id.startswith( - BLOCK_SWITCH - ): - mac = entry.unique_id.split("-", 1)[1] - else: - continue - - if mac in self.api.clients or mac not in self.api.clients_all: - continue - - client = self.api.clients_all[mac] - self.api.clients.process_raw([dict(client.raw)]) - LOGGER.debug( - "Restore disconnected client %s (%s)", - entry.entity_id, - client.mac, - ) + for mac in self.option_block_clients: + if mac not in self.api.clients and mac in self.api.clients_all: + self.api.clients.process_raw([dict(self.api.clients_all[mac].raw)]) self.wireless_clients.update_clients(set(self.api.clients.values())) diff --git a/homeassistant/components/unifi/device_tracker.py b/homeassistant/components/unifi/device_tracker.py index fcfe71a2858..2b7ac04cc0d 100644 --- a/homeassistant/components/unifi/device_tracker.py +++ b/homeassistant/components/unifi/device_tracker.py @@ -24,7 +24,6 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity_platform import AddEntitiesCallback import homeassistant.util.dt as dt_util -from .const import DOMAIN as UNIFI_DOMAIN from .controller import UniFiController from .entity import ( HandlerT, @@ -206,9 +205,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up device tracker for UniFi Network integration.""" - controller: UniFiController = hass.data[UNIFI_DOMAIN][config_entry.entry_id] - controller.register_platform_add_entities( - UnifiScannerEntity, ENTITY_DESCRIPTIONS, async_add_entities + UniFiController.register_platform( + hass, config_entry, async_add_entities, UnifiScannerEntity, ENTITY_DESCRIPTIONS ) diff --git a/homeassistant/components/unifi/image.py b/homeassistant/components/unifi/image.py index 8231b87ee85..2318702f0d1 100644 --- a/homeassistant/components/unifi/image.py +++ b/homeassistant/components/unifi/image.py @@ -20,7 +20,6 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_platform import AddEntitiesCallback import homeassistant.util.dt as dt_util -from .const import DOMAIN as UNIFI_DOMAIN from .controller import UniFiController from .entity import ( HandlerT, @@ -83,13 +82,13 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up image platform for UniFi Network integration.""" - controller: UniFiController = hass.data[UNIFI_DOMAIN][config_entry.entry_id] - - if not controller.is_admin: - return - - controller.register_platform_add_entities( - UnifiImageEntity, ENTITY_DESCRIPTIONS, async_add_entities + UniFiController.register_platform( + hass, + config_entry, + async_add_entities, + UnifiImageEntity, + ENTITY_DESCRIPTIONS, + requires_admin=True, ) diff --git a/homeassistant/components/unifi/sensor.py b/homeassistant/components/unifi/sensor.py index 7cb0b2bbfe3..86c6b0d6352 100644 --- a/homeassistant/components/unifi/sensor.py +++ b/homeassistant/components/unifi/sensor.py @@ -35,7 +35,6 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_platform import AddEntitiesCallback import homeassistant.util.dt as dt_util -from .const import DOMAIN as UNIFI_DOMAIN from .controller import UniFiController from .entity import ( HandlerT, @@ -329,9 +328,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up sensors for UniFi Network integration.""" - controller: UniFiController = hass.data[UNIFI_DOMAIN][config_entry.entry_id] - controller.register_platform_add_entities( - UnifiSensorEntity, ENTITY_DESCRIPTIONS, async_add_entities + UniFiController.register_platform( + hass, config_entry, async_add_entities, UnifiSensorEntity, ENTITY_DESCRIPTIONS ) diff --git a/homeassistant/components/unifi/switch.py b/homeassistant/components/unifi/switch.py index 560e150e63c..0aa39914686 100644 --- a/homeassistant/components/unifi/switch.py +++ b/homeassistant/components/unifi/switch.py @@ -43,7 +43,7 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback -from .const import ATTR_MANUFACTURER, DOMAIN as UNIFI_DOMAIN +from .const import ATTR_MANUFACTURER from .controller import UniFiController from .entity import ( HandlerT, @@ -320,19 +320,13 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up switches for UniFi Network integration.""" - controller: UniFiController = hass.data[UNIFI_DOMAIN][config_entry.entry_id] - - if not controller.is_admin: - return - - for mac in controller.option_block_clients: - if mac not in controller.api.clients and mac in controller.api.clients_all: - controller.api.clients.process_raw( - [dict(controller.api.clients_all[mac].raw)] - ) - - controller.register_platform_add_entities( - UnifiSwitchEntity, ENTITY_DESCRIPTIONS, async_add_entities + UniFiController.register_platform( + hass, + config_entry, + async_add_entities, + UnifiSwitchEntity, + ENTITY_DESCRIPTIONS, + requires_admin=True, ) diff --git a/homeassistant/components/unifi/update.py b/homeassistant/components/unifi/update.py index 6526a02da83..65b26736cf1 100644 --- a/homeassistant/components/unifi/update.py +++ b/homeassistant/components/unifi/update.py @@ -4,7 +4,7 @@ from __future__ import annotations from collections.abc import Callable, Coroutine from dataclasses import dataclass import logging -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from typing import Any, Generic, TypeVar import aiounifi from aiounifi.interfaces.api_handlers import ItemEvent @@ -21,7 +21,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_platform import AddEntitiesCallback -from .const import DOMAIN as UNIFI_DOMAIN +from .controller import UniFiController from .entity import ( UnifiEntity, UnifiEntityDescription, @@ -29,9 +29,6 @@ from .entity import ( async_device_device_info_fn, ) -if TYPE_CHECKING: - from .controller import UniFiController - LOGGER = logging.getLogger(__name__) _DataT = TypeVar("_DataT", bound=Device) @@ -88,9 +85,12 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up update entities for UniFi Network integration.""" - controller: UniFiController = hass.data[UNIFI_DOMAIN][config_entry.entry_id] - controller.register_platform_add_entities( - UnifiDeviceUpdateEntity, ENTITY_DESCRIPTIONS, async_add_entities + UniFiController.register_platform( + hass, + config_entry, + async_add_entities, + UnifiDeviceUpdateEntity, + ENTITY_DESCRIPTIONS, ) diff --git a/tests/components/unifi/test_device_tracker.py b/tests/components/unifi/test_device_tracker.py index 7b939077e48..99874b3a949 100644 --- a/tests/components/unifi/test_device_tracker.py +++ b/tests/components/unifi/test_device_tracker.py @@ -946,7 +946,7 @@ async def test_restoring_client( await setup_unifi_integration( hass, aioclient_mock, - options={CONF_BLOCK_CLIENT: True}, + options={CONF_BLOCK_CLIENT: [restored["mac"]]}, clients_response=[client], clients_all_response=[restored, not_restored], )