Simplify device registry (#77715)

* Simplify device registry

* Fix test fixture

* Update homeassistant/helpers/device_registry.py

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* Update device_registry.py

* Remove dead code

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
This commit is contained in:
Erik Montnemery 2022-09-03 12:50:55 +02:00 committed by GitHub
parent 7e100b64ea
commit 56278a4421
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 84 additions and 129 deletions

View File

@ -1,11 +1,11 @@
"""Provide a way to connect entities belonging to one device.""" """Provide a way to connect entities belonging to one device."""
from __future__ import annotations from __future__ import annotations
from collections import OrderedDict from collections import UserDict
from collections.abc import Coroutine from collections.abc import Coroutine
import logging import logging
import time import time
from typing import TYPE_CHECKING, Any, NamedTuple, cast from typing import TYPE_CHECKING, Any, TypeVar, cast
import attr import attr
@ -48,11 +48,6 @@ ORPHANED_DEVICE_KEEP_SECONDS = 86400 * 30
RUNTIME_ONLY_ATTRS = {"suggested_area"} RUNTIME_ONLY_ATTRS = {"suggested_area"}
class _DeviceIndex(NamedTuple):
identifiers: dict[tuple[str, str], str]
connections: dict[tuple[str, str], str]
class DeviceEntryDisabler(StrEnum): class DeviceEntryDisabler(StrEnum):
"""What disabled a device entry.""" """What disabled a device entry."""
@ -149,23 +144,6 @@ def format_mac(mac: str) -> str:
return mac return mac
def _async_get_device_id_from_index(
devices_index: _DeviceIndex,
identifiers: set[tuple[str, str]],
connections: set[tuple[str, str]] | None,
) -> str | None:
"""Check if device has previously been registered."""
for identifier in identifiers:
if identifier in devices_index.identifiers:
return devices_index.identifiers[identifier]
if not connections:
return None
for connection in _normalize_connections(connections):
if connection in devices_index.connections:
return devices_index.connections[connection]
return None
class DeviceRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]): class DeviceRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
"""Store entity registry data.""" """Store entity registry data."""
@ -210,13 +188,69 @@ class DeviceRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
return old_data return old_data
_EntryTypeT = TypeVar("_EntryTypeT", DeviceEntry, DeletedDeviceEntry)
class DeviceRegistryItems(UserDict[str, _EntryTypeT]):
"""Container for device registry items, maps device id -> entry.
Maintains two additional indexes:
- (connection_type, connection identifier) -> entry
- (DOMAIN, identifier) -> entry
"""
def __init__(self) -> None:
"""Initialize the container."""
super().__init__()
self._connections: dict[tuple[str, str], _EntryTypeT] = {}
self._identifiers: dict[tuple[str, str], _EntryTypeT] = {}
def __setitem__(self, key: str, entry: _EntryTypeT) -> None:
"""Add an item."""
if key in self:
old_entry = self[key]
for connection in old_entry.connections:
del self._connections[connection]
for identifier in old_entry.identifiers:
del self._identifiers[identifier]
# type ignore linked to mypy issue: https://github.com/python/mypy/issues/13596
super().__setitem__(key, entry) # type: ignore[assignment]
for connection in entry.connections:
self._connections[connection] = entry
for identifier in entry.identifiers:
self._identifiers[identifier] = entry
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
for connection in entry.connections:
del self._connections[connection]
for identifier in entry.identifiers:
del self._identifiers[identifier]
super().__delitem__(key)
def get_entry(
self,
identifiers: set[tuple[str, str]],
connections: set[tuple[str, str]] | None,
) -> _EntryTypeT | None:
"""Get entry from identifiers or connections."""
for identifier in identifiers:
if identifier in self._identifiers:
return self._identifiers[identifier]
if not connections:
return None
for connection in _normalize_connections(connections):
if connection in self._connections:
return self._connections[connection]
return None
class DeviceRegistry: class DeviceRegistry:
"""Class to hold a registry of devices.""" """Class to hold a registry of devices."""
devices: dict[str, DeviceEntry] devices: DeviceRegistryItems[DeviceEntry]
deleted_devices: dict[str, DeletedDeviceEntry] deleted_devices: DeviceRegistryItems[DeletedDeviceEntry]
_registered_index: _DeviceIndex
_deleted_index: _DeviceIndex
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the device registry.""" """Initialize the device registry."""
@ -228,7 +262,6 @@ class DeviceRegistry:
atomic_writes=True, atomic_writes=True,
minor_version=STORAGE_VERSION_MINOR, minor_version=STORAGE_VERSION_MINOR,
) )
self._clear_index()
@callback @callback
def async_get(self, device_id: str) -> DeviceEntry | None: def async_get(self, device_id: str) -> DeviceEntry | None:
@ -242,12 +275,7 @@ class DeviceRegistry:
connections: set[tuple[str, str]] | None = None, connections: set[tuple[str, str]] | None = None,
) -> DeviceEntry | None: ) -> DeviceEntry | None:
"""Check if device is registered.""" """Check if device is registered."""
device_id = _async_get_device_id_from_index( return self.devices.get_entry(identifiers, connections)
self._registered_index, identifiers, connections
)
if device_id is None:
return None
return self.devices[device_id]
def _async_get_deleted_device( def _async_get_deleted_device(
self, self,
@ -255,55 +283,7 @@ class DeviceRegistry:
connections: set[tuple[str, str]] | None, connections: set[tuple[str, str]] | None,
) -> DeletedDeviceEntry | None: ) -> DeletedDeviceEntry | None:
"""Check if device is deleted.""" """Check if device is deleted."""
device_id = _async_get_device_id_from_index( return self.deleted_devices.get_entry(identifiers, connections)
self._deleted_index, identifiers, connections
)
if device_id is None:
return None
return self.deleted_devices[device_id]
def _add_device(self, device: DeviceEntry | DeletedDeviceEntry) -> None:
"""Add a device and index it."""
if isinstance(device, DeletedDeviceEntry):
devices_index = self._deleted_index
self.deleted_devices[device.id] = device
else:
devices_index = self._registered_index
self.devices[device.id] = device
_add_device_to_index(devices_index, device)
def _remove_device(self, device: DeviceEntry | DeletedDeviceEntry) -> None:
"""Remove a device and remove it from the index."""
if isinstance(device, DeletedDeviceEntry):
devices_index = self._deleted_index
self.deleted_devices.pop(device.id)
else:
devices_index = self._registered_index
self.devices.pop(device.id)
_remove_device_from_index(devices_index, device)
def _update_device(self, old_device: DeviceEntry, new_device: DeviceEntry) -> None:
"""Update a device and the index."""
self.devices[new_device.id] = new_device
devices_index = self._registered_index
_remove_device_from_index(devices_index, old_device)
_add_device_to_index(devices_index, new_device)
def _clear_index(self) -> None:
"""Clear the index."""
self._registered_index = _DeviceIndex(identifiers={}, connections={})
self._deleted_index = _DeviceIndex(identifiers={}, connections={})
def _rebuild_index(self) -> None:
"""Create the index after loading devices."""
self._clear_index()
for device in self.devices.values():
_add_device_to_index(self._registered_index, device)
for deleted_device in self.deleted_devices.values():
_add_device_to_index(self._deleted_index, deleted_device)
@callback @callback
def async_get_or_create( def async_get_or_create(
@ -346,11 +326,11 @@ class DeviceRegistry:
if deleted_device is None: if deleted_device is None:
device = DeviceEntry(is_new=True) device = DeviceEntry(is_new=True)
else: else:
self._remove_device(deleted_device) self.deleted_devices.pop(deleted_device.id)
device = deleted_device.to_device_entry( device = deleted_device.to_device_entry(
config_entry_id, connections, identifiers config_entry_id, connections, identifiers
) )
self._add_device(device) self.devices[device.id] = device
if default_manufacturer is not UNDEFINED and device.manufacturer is None: if default_manufacturer is not UNDEFINED and device.manufacturer is None:
manufacturer = default_manufacturer manufacturer = default_manufacturer
@ -516,7 +496,7 @@ class DeviceRegistry:
return old return old
new = attr.evolve(old, **new_values) new = attr.evolve(old, **new_values)
self._update_device(old, new) self.devices[device_id] = new
# If its only run time attributes (suggested_area) # If its only run time attributes (suggested_area)
# that do not get saved we do not want to write # that do not get saved we do not want to write
@ -542,16 +522,13 @@ class DeviceRegistry:
@callback @callback
def async_remove_device(self, device_id: str) -> None: def async_remove_device(self, device_id: str) -> None:
"""Remove a device from the device registry.""" """Remove a device from the device registry."""
device = self.devices[device_id] device = self.devices.pop(device_id)
self._remove_device(device) self.deleted_devices[device_id] = DeletedDeviceEntry(
self._add_device( config_entries=device.config_entries,
DeletedDeviceEntry( connections=device.connections,
config_entries=device.config_entries, identifiers=device.identifiers,
connections=device.connections, id=device.id,
identifiers=device.identifiers, orphaned_timestamp=None,
id=device.id,
orphaned_timestamp=None,
)
) )
for other_device in list(self.devices.values()): for other_device in list(self.devices.values()):
if other_device.via_device_id == device_id: if other_device.via_device_id == device_id:
@ -567,8 +544,8 @@ class DeviceRegistry:
data = await self._store.async_load() data = await self._store.async_load()
devices = OrderedDict() devices: DeviceRegistryItems[DeviceEntry] = DeviceRegistryItems()
deleted_devices = OrderedDict() deleted_devices: DeviceRegistryItems[DeletedDeviceEntry] = DeviceRegistryItems()
if data is not None: if data is not None:
for device in data["devices"]: for device in data["devices"]:
@ -607,7 +584,6 @@ class DeviceRegistry:
self.devices = devices self.devices = devices
self.deleted_devices = deleted_devices self.deleted_devices = deleted_devices
self._rebuild_index()
@callback @callback
def async_schedule_save(self) -> None: def async_schedule_save(self) -> None:
@ -692,7 +668,7 @@ class DeviceRegistry:
deleted_device.orphaned_timestamp + ORPHANED_DEVICE_KEEP_SECONDS deleted_device.orphaned_timestamp + ORPHANED_DEVICE_KEEP_SECONDS
< now_time < now_time
): ):
self._remove_device(deleted_device) del self.deleted_devices[deleted_device.id]
@callback @callback
def async_clear_area_id(self, area_id: str) -> None: def async_clear_area_id(self, area_id: str) -> None:
@ -879,27 +855,3 @@ def _normalize_connections(connections: set[tuple[str, str]]) -> set[tuple[str,
(key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value) (key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value)
for key, value in connections for key, value in connections
} }
def _add_device_to_index(
devices_index: _DeviceIndex,
device: DeviceEntry | DeletedDeviceEntry,
) -> None:
"""Add a device to the index."""
for identifier in device.identifiers:
devices_index.identifiers[identifier] = device.id
for connection in device.connections:
devices_index.connections[connection] = device.id
def _remove_device_from_index(
devices_index: _DeviceIndex,
device: DeviceEntry | DeletedDeviceEntry,
) -> None:
"""Remove a device from the index."""
for identifier in device.identifiers:
if identifier in devices_index.identifiers:
del devices_index.identifiers[identifier]
for connection in device.connections:
if connection in devices_index.connections:
del devices_index.connections[connection]

View File

@ -469,12 +469,15 @@ def mock_area_registry(hass, mock_entries=None):
return registry return registry
def mock_device_registry(hass, mock_entries=None, mock_deleted_entries=None): def mock_device_registry(hass, mock_entries=None):
"""Mock the Device Registry.""" """Mock the Device Registry."""
registry = device_registry.DeviceRegistry(hass) registry = device_registry.DeviceRegistry(hass)
registry.devices = mock_entries or OrderedDict() registry.devices = device_registry.DeviceRegistryItems()
registry.deleted_devices = mock_deleted_entries or OrderedDict() if mock_entries is None:
registry._rebuild_index() mock_entries = {}
for key, entry in mock_entries.items():
registry.devices[key] = entry
registry.deleted_devices = device_registry.DeviceRegistryItems()
hass.data[device_registry.DATA_REGISTRY] = registry hass.data[device_registry.DATA_REGISTRY] = registry
return registry return registry