Add support for async_remove_config_entry_device to unifiprotect (#72742)

* Add support for async_remove_config_entry_device to unifiprotect

* tweaks

* tweaks

* more cleanups

* more cleanups

* fix unhelpful auto import

* add coverage

* fix mac formatting

* collapse logic
This commit is contained in:
J. Nick Koston 2022-06-02 05:26:08 -10:00 committed by GitHub
parent 9192d0e972
commit f1a31d8d33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 152 additions and 32 deletions

View File

@ -21,7 +21,7 @@ from homeassistant.const import (
)
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.aiohttp_client import async_create_clientsession
from .const import (
@ -35,9 +35,10 @@ from .const import (
OUTDATED_LOG_MESSAGE,
PLATFORMS,
)
from .data import ProtectData
from .data import ProtectData, async_ufp_instance_for_config_entry_ids
from .discovery import async_start_discovery
from .services import async_cleanup_services, async_setup_services
from .utils import _async_unifi_mac_from_hass, async_get_devices
_LOGGER = logging.getLogger(__name__)
@ -166,3 +167,20 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async_cleanup_services(hass)
return bool(unload_ok)
async def async_remove_config_entry_device(
hass: HomeAssistant, config_entry: ConfigEntry, device_entry: dr.DeviceEntry
) -> bool:
"""Remove ufp config entry from a device."""
unifi_macs = {
_async_unifi_mac_from_hass(connection[1])
for connection in device_entry.connections
if connection[0] == dr.CONNECTION_NETWORK_MAC
}
api = async_ufp_instance_for_config_entry_ids(hass, {config_entry.entry_id})
assert api is not None
return api.bootstrap.nvr.mac not in unifi_macs and not any(
device.mac in unifi_macs
for device in async_get_devices(api, DEVICES_THAT_ADOPT)
)

View File

@ -14,13 +14,14 @@ from pyunifiprotect.data import (
ModelType,
WSSubscriptionMessage,
)
from pyunifiprotect.data.base import ProtectAdoptableDeviceModel, ProtectDeviceModel
from pyunifiprotect.data.base import ProtectAdoptableDeviceModel
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers.event import async_track_time_interval
from .const import CONF_DISABLE_RTSP, DEVICES_THAT_ADOPT, DEVICES_WITH_ENTITIES
from .const import CONF_DISABLE_RTSP, DEVICES_THAT_ADOPT, DEVICES_WITH_ENTITIES, DOMAIN
from .utils import async_get_adoptable_devices_by_type, async_get_devices
_LOGGER = logging.getLogger(__name__)
@ -58,13 +59,10 @@ class ProtectData:
self, device_types: Iterable[ModelType]
) -> Generator[ProtectAdoptableDeviceModel, None, None]:
"""Get all devices matching types."""
for device_type in device_types:
attr = f"{device_type.value}s"
devices: dict[str, ProtectAdoptableDeviceModel] = getattr(
self.api.bootstrap, attr
)
yield from devices.values()
yield from async_get_adoptable_devices_by_type(
self.api, device_type
).values()
async def async_setup(self) -> None:
"""Subscribe and do the refresh."""
@ -145,11 +143,8 @@ class ProtectData:
return
self.async_signal_device_id_update(self.api.bootstrap.nvr.id)
for device_type in DEVICES_THAT_ADOPT:
attr = f"{device_type.value}s"
devices: dict[str, ProtectDeviceModel] = getattr(self.api.bootstrap, attr)
for device_id in devices.keys():
self.async_signal_device_id_update(device_id)
for device in async_get_devices(self.api, DEVICES_THAT_ADOPT):
self.async_signal_device_id_update(device.id)
@callback
def async_subscribe_device_id(
@ -188,3 +183,16 @@ class ProtectData:
_LOGGER.debug("Updating device: %s", device_id)
for update_callback in self._subscriptions[device_id]:
update_callback()
@callback
def async_ufp_instance_for_config_entry_ids(
hass: HomeAssistant, config_entry_ids: set[str]
) -> ProtectApiClient | None:
"""Find the UFP instance for the config entry ids."""
domain_data = hass.data[DOMAIN]
for config_entry_id in config_entry_ids:
if config_entry_id in domain_data:
protect_data: ProtectData = domain_data[config_entry_id]
return protect_data.api
return None

View File

@ -26,7 +26,7 @@ from homeassistant.helpers.entity import DeviceInfo, Entity, EntityDescription
from .const import ATTR_EVENT_SCORE, DEFAULT_ATTRIBUTION, DEFAULT_BRAND, DOMAIN
from .data import ProtectData
from .models import ProtectRequiredKeysMixin
from .utils import get_nested_attr
from .utils import async_get_adoptable_devices_by_type, get_nested_attr
_LOGGER = logging.getLogger(__name__)
@ -153,7 +153,9 @@ class ProtectDeviceEntity(Entity):
"""Update Entity object from Protect device."""
if self.data.last_update_success:
assert self.device.model
devices = getattr(self.data.api.bootstrap, f"{self.device.model.value}s")
devices = async_get_adoptable_devices_by_type(
self.data.api, self.device.model
)
self.device = devices[self.device.id]
is_connected = (

View File

@ -24,7 +24,7 @@ from homeassistant.helpers.service import async_extract_referenced_entity_ids
from homeassistant.util.read_only_dict import ReadOnlyDict
from .const import ATTR_MESSAGE, DOMAIN
from .data import ProtectData
from .data import async_ufp_instance_for_config_entry_ids
SERVICE_ADD_DOORBELL_TEXT = "add_doorbell_text"
SERVICE_REMOVE_DOORBELL_TEXT = "remove_doorbell_text"
@ -59,18 +59,6 @@ CHIME_PAIRED_SCHEMA = vol.All(
)
def _async_ufp_instance_for_config_entry_ids(
hass: HomeAssistant, config_entry_ids: set[str]
) -> ProtectApiClient | None:
"""Find the UFP instance for the config entry ids."""
domain_data = hass.data[DOMAIN]
for config_entry_id in config_entry_ids:
if config_entry_id in domain_data:
protect_data: ProtectData = domain_data[config_entry_id]
return protect_data.api
return None
@callback
def _async_get_ufp_instance(hass: HomeAssistant, device_id: str) -> ProtectApiClient:
device_registry = dr.async_get(hass)
@ -81,7 +69,7 @@ def _async_get_ufp_instance(hass: HomeAssistant, device_id: str) -> ProtectApiCl
return _async_get_ufp_instance(hass, device_entry.via_device_id)
config_entry_ids = device_entry.config_entries
if ufp_instance := _async_ufp_instance_for_config_entry_ids(hass, config_entry_ids):
if ufp_instance := async_ufp_instance_for_config_entry_ids(hass, config_entry_ids):
return ufp_instance
raise HomeAssistantError(f"No device found for device id: {device_id}")

View File

@ -1,13 +1,19 @@
"""UniFi Protect Integration utils."""
from __future__ import annotations
from collections.abc import Generator, Iterable
import contextlib
from enum import Enum
import socket
from typing import Any
from pyunifiprotect import ProtectApiClient
from pyunifiprotect.data.base import ProtectAdoptableDeviceModel, ProtectDeviceModel
from homeassistant.core import HomeAssistant, callback
from .const import ModelType
def get_nested_attr(obj: Any, attr: str) -> Any:
"""Fetch a nested attribute."""
@ -51,3 +57,35 @@ async def _async_resolve(hass: HomeAssistant, host: str) -> str | None:
None,
)
return None
def async_get_devices_by_type(
api: ProtectApiClient, device_type: ModelType
) -> dict[str, ProtectDeviceModel]:
"""Get devices by type."""
devices: dict[str, ProtectDeviceModel] = getattr(
api.bootstrap, f"{device_type.value}s"
)
return devices
def async_get_adoptable_devices_by_type(
api: ProtectApiClient, device_type: ModelType
) -> dict[str, ProtectAdoptableDeviceModel]:
"""Get adoptable devices by type."""
devices: dict[str, ProtectAdoptableDeviceModel] = getattr(
api.bootstrap, f"{device_type.value}s"
)
return devices
@callback
def async_get_devices(
api: ProtectApiClient, model_type: Iterable[ModelType]
) -> Generator[ProtectDeviceModel, None, None]:
"""Return all device by type."""
return (
device
for device_type in model_type
for device in async_get_devices_by_type(api, device_type).values()
)

View File

@ -2,8 +2,10 @@
# pylint: disable=protected-access
from __future__ import annotations
from collections.abc import Awaitable, Callable
from unittest.mock import AsyncMock, patch
import aiohttp
from pyunifiprotect import NotAuthorized, NvrError
from pyunifiprotect.data import NVR, Light
@ -11,7 +13,8 @@ from homeassistant.components.unifiprotect.const import CONF_DISABLE_RTSP, DOMAI
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.setup import async_setup_component
from . import _patch_discovery
from .conftest import MockBootstrap, MockEntityFixture
@ -19,6 +22,22 @@ from .conftest import MockBootstrap, MockEntityFixture
from tests.common import MockConfigEntry
async def remove_device(
ws_client: aiohttp.ClientWebSocketResponse, device_id: str, config_entry_id: str
) -> bool:
"""Remove config entry from a device."""
await ws_client.send_json(
{
"id": 5,
"type": "config/device_registry/remove_config_entry",
"config_entry_id": config_entry_id,
"device_id": device_id,
}
)
response = await ws_client.receive_json()
return response["success"]
async def test_setup(hass: HomeAssistant, mock_entry: MockEntityFixture):
"""Test working setup of unifiprotect entry."""
@ -321,3 +340,50 @@ async def test_migrate_reboot_button_fail(
light = registry.async_get(f"{Platform.BUTTON}.test_light_1")
assert light is not None
assert light.unique_id == f"{light1.id}"
async def test_device_remove_devices(
hass: HomeAssistant,
mock_entry: MockEntityFixture,
mock_light: Light,
hass_ws_client: Callable[
[HomeAssistant], Awaitable[aiohttp.ClientWebSocketResponse]
],
) -> None:
"""Test we can only remove a device that no longer exists."""
assert await async_setup_component(hass, "config", {})
light1 = mock_light.copy()
light1._api = mock_entry.api
light1.name = "Test Light 1"
light1.id = "lightid1"
light1.mac = "AABBCCDDEEFF"
mock_entry.api.bootstrap.lights = {
light1.id: light1,
}
mock_entry.api.get_bootstrap = AsyncMock(return_value=mock_entry.api.bootstrap)
light_entity_id = "light.test_light_1"
await hass.config_entries.async_setup(mock_entry.entry.entry_id)
await hass.async_block_till_done()
entry_id = mock_entry.entry.entry_id
registry: er.EntityRegistry = er.async_get(hass)
entity = registry.entities[light_entity_id]
device_registry = dr.async_get(hass)
live_device_entry = device_registry.async_get(entity.device_id)
assert (
await remove_device(await hass_ws_client(hass), live_device_entry.id, entry_id)
is False
)
dead_device_entry = device_registry.async_get_or_create(
config_entry_id=entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "e9:88:e7:b8:b4:40")},
)
assert (
await remove_device(await hass_ws_client(hass), dead_device_entry.id, entry_id)
is True
)