From f1a31d8d333f6c88b7e61719284accf38bea209e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 2 Jun 2022 05:26:08 -1000 Subject: [PATCH] Add support for async_remove_config_entry_device to unifiprotect (#72742) * Add support for async_remove_config_entry_device to unifiprotect * tweaks * tweaks * more cleanups * more cleanups * fix unhelpful auto import * add coverage * fix mac formatting * collapse logic --- .../components/unifiprotect/__init__.py | 22 +++++- homeassistant/components/unifiprotect/data.py | 34 ++++++---- .../components/unifiprotect/entity.py | 6 +- .../components/unifiprotect/services.py | 16 +---- .../components/unifiprotect/utils.py | 38 +++++++++++ tests/components/unifiprotect/test_init.py | 68 ++++++++++++++++++- 6 files changed, 152 insertions(+), 32 deletions(-) diff --git a/homeassistant/components/unifiprotect/__init__.py b/homeassistant/components/unifiprotect/__init__.py index c28f2639e00..4ec11a899e3 100644 --- a/homeassistant/components/unifiprotect/__init__.py +++ b/homeassistant/components/unifiprotect/__init__.py @@ -21,7 +21,7 @@ from homeassistant.const import ( ) from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady -from homeassistant.helpers import entity_registry as er +from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers.aiohttp_client import async_create_clientsession from .const import ( @@ -35,9 +35,10 @@ from .const import ( OUTDATED_LOG_MESSAGE, PLATFORMS, ) -from .data import ProtectData +from .data import ProtectData, async_ufp_instance_for_config_entry_ids from .discovery import async_start_discovery from .services import async_cleanup_services, async_setup_services +from .utils import _async_unifi_mac_from_hass, async_get_devices _LOGGER = logging.getLogger(__name__) @@ -166,3 +167,20 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async_cleanup_services(hass) return bool(unload_ok) + + +async def async_remove_config_entry_device( + hass: HomeAssistant, config_entry: ConfigEntry, device_entry: dr.DeviceEntry +) -> bool: + """Remove ufp config entry from a device.""" + unifi_macs = { + _async_unifi_mac_from_hass(connection[1]) + for connection in device_entry.connections + if connection[0] == dr.CONNECTION_NETWORK_MAC + } + api = async_ufp_instance_for_config_entry_ids(hass, {config_entry.entry_id}) + assert api is not None + return api.bootstrap.nvr.mac not in unifi_macs and not any( + device.mac in unifi_macs + for device in async_get_devices(api, DEVICES_THAT_ADOPT) + ) diff --git a/homeassistant/components/unifiprotect/data.py b/homeassistant/components/unifiprotect/data.py index 371c1c7831b..68c8873c17e 100644 --- a/homeassistant/components/unifiprotect/data.py +++ b/homeassistant/components/unifiprotect/data.py @@ -14,13 +14,14 @@ from pyunifiprotect.data import ( ModelType, WSSubscriptionMessage, ) -from pyunifiprotect.data.base import ProtectAdoptableDeviceModel, ProtectDeviceModel +from pyunifiprotect.data.base import ProtectAdoptableDeviceModel from homeassistant.config_entries import ConfigEntry from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.helpers.event import async_track_time_interval -from .const import CONF_DISABLE_RTSP, DEVICES_THAT_ADOPT, DEVICES_WITH_ENTITIES +from .const import CONF_DISABLE_RTSP, DEVICES_THAT_ADOPT, DEVICES_WITH_ENTITIES, DOMAIN +from .utils import async_get_adoptable_devices_by_type, async_get_devices _LOGGER = logging.getLogger(__name__) @@ -58,13 +59,10 @@ class ProtectData: self, device_types: Iterable[ModelType] ) -> Generator[ProtectAdoptableDeviceModel, None, None]: """Get all devices matching types.""" - for device_type in device_types: - attr = f"{device_type.value}s" - devices: dict[str, ProtectAdoptableDeviceModel] = getattr( - self.api.bootstrap, attr - ) - yield from devices.values() + yield from async_get_adoptable_devices_by_type( + self.api, device_type + ).values() async def async_setup(self) -> None: """Subscribe and do the refresh.""" @@ -145,11 +143,8 @@ class ProtectData: return self.async_signal_device_id_update(self.api.bootstrap.nvr.id) - for device_type in DEVICES_THAT_ADOPT: - attr = f"{device_type.value}s" - devices: dict[str, ProtectDeviceModel] = getattr(self.api.bootstrap, attr) - for device_id in devices.keys(): - self.async_signal_device_id_update(device_id) + for device in async_get_devices(self.api, DEVICES_THAT_ADOPT): + self.async_signal_device_id_update(device.id) @callback def async_subscribe_device_id( @@ -188,3 +183,16 @@ class ProtectData: _LOGGER.debug("Updating device: %s", device_id) for update_callback in self._subscriptions[device_id]: update_callback() + + +@callback +def async_ufp_instance_for_config_entry_ids( + hass: HomeAssistant, config_entry_ids: set[str] +) -> ProtectApiClient | None: + """Find the UFP instance for the config entry ids.""" + domain_data = hass.data[DOMAIN] + for config_entry_id in config_entry_ids: + if config_entry_id in domain_data: + protect_data: ProtectData = domain_data[config_entry_id] + return protect_data.api + return None diff --git a/homeassistant/components/unifiprotect/entity.py b/homeassistant/components/unifiprotect/entity.py index f8ceaeec9e6..2911a861535 100644 --- a/homeassistant/components/unifiprotect/entity.py +++ b/homeassistant/components/unifiprotect/entity.py @@ -26,7 +26,7 @@ from homeassistant.helpers.entity import DeviceInfo, Entity, EntityDescription from .const import ATTR_EVENT_SCORE, DEFAULT_ATTRIBUTION, DEFAULT_BRAND, DOMAIN from .data import ProtectData from .models import ProtectRequiredKeysMixin -from .utils import get_nested_attr +from .utils import async_get_adoptable_devices_by_type, get_nested_attr _LOGGER = logging.getLogger(__name__) @@ -153,7 +153,9 @@ class ProtectDeviceEntity(Entity): """Update Entity object from Protect device.""" if self.data.last_update_success: assert self.device.model - devices = getattr(self.data.api.bootstrap, f"{self.device.model.value}s") + devices = async_get_adoptable_devices_by_type( + self.data.api, self.device.model + ) self.device = devices[self.device.id] is_connected = ( diff --git a/homeassistant/components/unifiprotect/services.py b/homeassistant/components/unifiprotect/services.py index f8aa446f857..828aa9ecfd7 100644 --- a/homeassistant/components/unifiprotect/services.py +++ b/homeassistant/components/unifiprotect/services.py @@ -24,7 +24,7 @@ from homeassistant.helpers.service import async_extract_referenced_entity_ids from homeassistant.util.read_only_dict import ReadOnlyDict from .const import ATTR_MESSAGE, DOMAIN -from .data import ProtectData +from .data import async_ufp_instance_for_config_entry_ids SERVICE_ADD_DOORBELL_TEXT = "add_doorbell_text" SERVICE_REMOVE_DOORBELL_TEXT = "remove_doorbell_text" @@ -59,18 +59,6 @@ CHIME_PAIRED_SCHEMA = vol.All( ) -def _async_ufp_instance_for_config_entry_ids( - hass: HomeAssistant, config_entry_ids: set[str] -) -> ProtectApiClient | None: - """Find the UFP instance for the config entry ids.""" - domain_data = hass.data[DOMAIN] - for config_entry_id in config_entry_ids: - if config_entry_id in domain_data: - protect_data: ProtectData = domain_data[config_entry_id] - return protect_data.api - return None - - @callback def _async_get_ufp_instance(hass: HomeAssistant, device_id: str) -> ProtectApiClient: device_registry = dr.async_get(hass) @@ -81,7 +69,7 @@ def _async_get_ufp_instance(hass: HomeAssistant, device_id: str) -> ProtectApiCl return _async_get_ufp_instance(hass, device_entry.via_device_id) config_entry_ids = device_entry.config_entries - if ufp_instance := _async_ufp_instance_for_config_entry_ids(hass, config_entry_ids): + if ufp_instance := async_ufp_instance_for_config_entry_ids(hass, config_entry_ids): return ufp_instance raise HomeAssistantError(f"No device found for device id: {device_id}") diff --git a/homeassistant/components/unifiprotect/utils.py b/homeassistant/components/unifiprotect/utils.py index 559cfd37660..fffe987db0f 100644 --- a/homeassistant/components/unifiprotect/utils.py +++ b/homeassistant/components/unifiprotect/utils.py @@ -1,13 +1,19 @@ """UniFi Protect Integration utils.""" from __future__ import annotations +from collections.abc import Generator, Iterable import contextlib from enum import Enum import socket from typing import Any +from pyunifiprotect import ProtectApiClient +from pyunifiprotect.data.base import ProtectAdoptableDeviceModel, ProtectDeviceModel + from homeassistant.core import HomeAssistant, callback +from .const import ModelType + def get_nested_attr(obj: Any, attr: str) -> Any: """Fetch a nested attribute.""" @@ -51,3 +57,35 @@ async def _async_resolve(hass: HomeAssistant, host: str) -> str | None: None, ) return None + + +def async_get_devices_by_type( + api: ProtectApiClient, device_type: ModelType +) -> dict[str, ProtectDeviceModel]: + """Get devices by type.""" + devices: dict[str, ProtectDeviceModel] = getattr( + api.bootstrap, f"{device_type.value}s" + ) + return devices + + +def async_get_adoptable_devices_by_type( + api: ProtectApiClient, device_type: ModelType +) -> dict[str, ProtectAdoptableDeviceModel]: + """Get adoptable devices by type.""" + devices: dict[str, ProtectAdoptableDeviceModel] = getattr( + api.bootstrap, f"{device_type.value}s" + ) + return devices + + +@callback +def async_get_devices( + api: ProtectApiClient, model_type: Iterable[ModelType] +) -> Generator[ProtectDeviceModel, None, None]: + """Return all device by type.""" + return ( + device + for device_type in model_type + for device in async_get_devices_by_type(api, device_type).values() + ) diff --git a/tests/components/unifiprotect/test_init.py b/tests/components/unifiprotect/test_init.py index 95c2ee0b511..cf899d854fd 100644 --- a/tests/components/unifiprotect/test_init.py +++ b/tests/components/unifiprotect/test_init.py @@ -2,8 +2,10 @@ # pylint: disable=protected-access from __future__ import annotations +from collections.abc import Awaitable, Callable from unittest.mock import AsyncMock, patch +import aiohttp from pyunifiprotect import NotAuthorized, NvrError from pyunifiprotect.data import NVR, Light @@ -11,7 +13,8 @@ from homeassistant.components.unifiprotect.const import CONF_DISABLE_RTSP, DOMAI from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.const import Platform from homeassistant.core import HomeAssistant -from homeassistant.helpers import entity_registry as er +from homeassistant.helpers import device_registry as dr, entity_registry as er +from homeassistant.setup import async_setup_component from . import _patch_discovery from .conftest import MockBootstrap, MockEntityFixture @@ -19,6 +22,22 @@ from .conftest import MockBootstrap, MockEntityFixture from tests.common import MockConfigEntry +async def remove_device( + ws_client: aiohttp.ClientWebSocketResponse, device_id: str, config_entry_id: str +) -> bool: + """Remove config entry from a device.""" + await ws_client.send_json( + { + "id": 5, + "type": "config/device_registry/remove_config_entry", + "config_entry_id": config_entry_id, + "device_id": device_id, + } + ) + response = await ws_client.receive_json() + return response["success"] + + async def test_setup(hass: HomeAssistant, mock_entry: MockEntityFixture): """Test working setup of unifiprotect entry.""" @@ -321,3 +340,50 @@ async def test_migrate_reboot_button_fail( light = registry.async_get(f"{Platform.BUTTON}.test_light_1") assert light is not None assert light.unique_id == f"{light1.id}" + + +async def test_device_remove_devices( + hass: HomeAssistant, + mock_entry: MockEntityFixture, + mock_light: Light, + hass_ws_client: Callable[ + [HomeAssistant], Awaitable[aiohttp.ClientWebSocketResponse] + ], +) -> None: + """Test we can only remove a device that no longer exists.""" + assert await async_setup_component(hass, "config", {}) + + light1 = mock_light.copy() + light1._api = mock_entry.api + light1.name = "Test Light 1" + light1.id = "lightid1" + light1.mac = "AABBCCDDEEFF" + + mock_entry.api.bootstrap.lights = { + light1.id: light1, + } + + mock_entry.api.get_bootstrap = AsyncMock(return_value=mock_entry.api.bootstrap) + light_entity_id = "light.test_light_1" + await hass.config_entries.async_setup(mock_entry.entry.entry_id) + await hass.async_block_till_done() + entry_id = mock_entry.entry.entry_id + + registry: er.EntityRegistry = er.async_get(hass) + entity = registry.entities[light_entity_id] + device_registry = dr.async_get(hass) + + live_device_entry = device_registry.async_get(entity.device_id) + assert ( + await remove_device(await hass_ws_client(hass), live_device_entry.id, entry_id) + is False + ) + + dead_device_entry = device_registry.async_get_or_create( + config_entry_id=entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "e9:88:e7:b8:b4:40")}, + ) + assert ( + await remove_device(await hass_ws_client(hass), dead_device_entry.id, entry_id) + is True + )