From ddd7e79ee9a98cbfb7faac42c5ff7dbba86c7e0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Skytt=C3=A4?= Date: Sun, 2 May 2021 01:33:31 +0300 Subject: [PATCH] Improve device registry internal typing (#49924) --- homeassistant/helpers/device_registry.py | 94 ++++++++++++------------ 1 file changed, 46 insertions(+), 48 deletions(-) diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 024b11476e7..a448fd1c198 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -4,7 +4,7 @@ from __future__ import annotations from collections import OrderedDict import logging import time -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, NamedTuple, cast import attr @@ -37,11 +37,6 @@ CONNECTION_NETWORK_MAC = "mac" CONNECTION_UPNP = "upnp" CONNECTION_ZIGBEE = "zigbee" -IDX_CONNECTIONS = "connections" -IDX_IDENTIFIERS = "identifiers" -REGISTERED_DEVICE = "registered" -DELETED_DEVICE = "deleted" - DISABLED_CONFIG_ENTRY = "config_entry" DISABLED_INTEGRATION = "integration" DISABLED_USER = "user" @@ -49,6 +44,11 @@ DISABLED_USER = "user" ORPHANED_DEVICE_KEEP_SECONDS = 86400 * 30 +class _DeviceIndex(NamedTuple): + identifiers: dict[tuple[str, ...], str] + connections: dict[tuple[str, str], str] + + @attr.s(slots=True, frozen=True) class DeviceEntry: """Device Registry Entry.""" @@ -133,12 +133,30 @@ def format_mac(mac: str) -> str: return mac +def _async_get_device_id_from_index( + devices_index: _DeviceIndex, + identifiers: set[tuple[str, ...]], + connections: set[tuple[str, str]] | None, +) -> str | None: + """Check if device has previously been registered.""" + for identifier in identifiers: + if identifier in devices_index.identifiers: + return devices_index.identifiers[identifier] + if not connections: + return None + for connection in _normalize_connections(connections): + if connection in devices_index.connections: + return devices_index.connections[connection] + return None + + class DeviceRegistry: """Class to hold a registry of devices.""" devices: dict[str, DeviceEntry] deleted_devices: dict[str, DeletedDeviceEntry] - _devices_index: dict[str, dict[str, dict[tuple[str, ...], str]]] + _registered_index: _DeviceIndex + _deleted_index: _DeviceIndex def __init__(self, hass: HomeAssistant) -> None: """Initialize the device registry.""" @@ -158,8 +176,8 @@ class DeviceRegistry: connections: set[tuple[str, str]] | None = None, ) -> DeviceEntry | None: """Check if device is registered.""" - device_id = self._async_get_device_id_from_index( - REGISTERED_DEVICE, identifiers, connections + device_id = _async_get_device_id_from_index( + self._registered_index, identifiers, connections ) if device_id is None: return None @@ -171,38 +189,20 @@ class DeviceRegistry: connections: set[tuple[str, str]] | None, ) -> DeletedDeviceEntry | None: """Check if device is deleted.""" - device_id = self._async_get_device_id_from_index( - DELETED_DEVICE, identifiers, connections + device_id = _async_get_device_id_from_index( + self._deleted_index, identifiers, connections ) if device_id is None: return None return self.deleted_devices[device_id] - def _async_get_device_id_from_index( - self, - index: str, - identifiers: set[tuple[str, ...]], - connections: set[tuple[str, str]] | None, - ) -> str | None: - """Check if device has previously been registered.""" - devices_index = self._devices_index[index] - for identifier in identifiers: - if identifier in devices_index[IDX_IDENTIFIERS]: - return devices_index[IDX_IDENTIFIERS][identifier] - if not connections: - return None - for connection in _normalize_connections(connections): - if connection in devices_index[IDX_CONNECTIONS]: - return devices_index[IDX_CONNECTIONS][connection] - return None - def _add_device(self, device: DeviceEntry | DeletedDeviceEntry) -> None: """Add a device and index it.""" if isinstance(device, DeletedDeviceEntry): - devices_index = self._devices_index[DELETED_DEVICE] + devices_index = self._deleted_index self.deleted_devices[device.id] = device else: - devices_index = self._devices_index[REGISTERED_DEVICE] + devices_index = self._registered_index self.devices[device.id] = device _add_device_to_index(devices_index, device) @@ -210,10 +210,10 @@ class DeviceRegistry: def _remove_device(self, device: DeviceEntry | DeletedDeviceEntry) -> None: """Remove a device and remove it from the index.""" if isinstance(device, DeletedDeviceEntry): - devices_index = self._devices_index[DELETED_DEVICE] + devices_index = self._deleted_index self.deleted_devices.pop(device.id) else: - devices_index = self._devices_index[REGISTERED_DEVICE] + devices_index = self._registered_index self.devices.pop(device.id) _remove_device_from_index(devices_index, device) @@ -222,24 +222,22 @@ class DeviceRegistry: """Update a device and the index.""" self.devices[new_device.id] = new_device - devices_index = self._devices_index[REGISTERED_DEVICE] + devices_index = self._registered_index _remove_device_from_index(devices_index, old_device) _add_device_to_index(devices_index, new_device) def _clear_index(self) -> None: """Clear the index.""" - self._devices_index = { - REGISTERED_DEVICE: {IDX_IDENTIFIERS: {}, IDX_CONNECTIONS: {}}, - DELETED_DEVICE: {IDX_IDENTIFIERS: {}, IDX_CONNECTIONS: {}}, - } + self._registered_index = _DeviceIndex(identifiers={}, connections={}) + self._deleted_index = _DeviceIndex(identifiers={}, connections={}) def _rebuild_index(self) -> None: """Create the index after loading devices.""" self._clear_index() for device in self.devices.values(): - _add_device_to_index(self._devices_index[REGISTERED_DEVICE], device) + _add_device_to_index(self._registered_index, device) for deleted_device in self.deleted_devices.values(): - _add_device_to_index(self._devices_index[DELETED_DEVICE], deleted_device) + _add_device_to_index(self._deleted_index, deleted_device) @callback def async_get_or_create( @@ -786,24 +784,24 @@ def _normalize_connections(connections: set[tuple[str, str]]) -> set[tuple[str, def _add_device_to_index( - devices_index: dict[str, dict[tuple[str, ...], str]], + devices_index: _DeviceIndex, device: DeviceEntry | DeletedDeviceEntry, ) -> None: """Add a device to the index.""" for identifier in device.identifiers: - devices_index[IDX_IDENTIFIERS][identifier] = device.id + devices_index.identifiers[identifier] = device.id for connection in device.connections: - devices_index[IDX_CONNECTIONS][connection] = device.id + devices_index.connections[connection] = device.id def _remove_device_from_index( - devices_index: dict[str, dict[tuple[str, ...], str]], + devices_index: _DeviceIndex, device: DeviceEntry | DeletedDeviceEntry, ) -> None: """Remove a device from the index.""" for identifier in device.identifiers: - if identifier in devices_index[IDX_IDENTIFIERS]: - del devices_index[IDX_IDENTIFIERS][identifier] + if identifier in devices_index.identifiers: + del devices_index.identifiers[identifier] for connection in device.connections: - if connection in devices_index[IDX_CONNECTIONS]: - del devices_index[IDX_CONNECTIONS][connection] + if connection in devices_index.connections: + del devices_index.connections[connection]