From d7910841ef8b33546e026c6007efb95bbf15aaff Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 13 Jan 2024 09:49:41 -1000 Subject: [PATCH] Add an index for devices and config entries to the entity registry (#107516) * Add an index for devices and config entries to the entity registry * fixes * tweak * use a list for now since the tests check order --- homeassistant/helpers/entity_registry.py | 58 ++++++++++++++++++------ 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 65ae1a8e9e5..1f9da1969f2 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -436,9 +436,11 @@ class EntityRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]): class EntityRegistryItems(UserDict[str, RegistryEntry]): """Container for entity registry items, maps entity_id -> entry. - Maintains two additional indexes: + Maintains four additional indexes: - id -> entry - (domain, platform, unique_id) -> entity_id + - config_entry_id -> list[key] + - device_id -> list[key] """ def __init__(self) -> None: @@ -446,6 +448,8 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]): super().__init__() self._entry_ids: dict[str, RegistryEntry] = {} self._index: dict[tuple[str, str, str], str] = {} + self._config_entry_id_index: dict[str, list[str]] = {} + self._device_id_index: dict[str, list[str]] = {} def values(self) -> ValuesView[RegistryEntry]: """Return the underlying values to avoid __iter__ overhead.""" @@ -455,18 +459,34 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]): """Add an item.""" data = self.data if key in data: - old_entry = data[key] - del self._entry_ids[old_entry.id] - del self._index[(old_entry.domain, old_entry.platform, old_entry.unique_id)] + self._unindex_entry(key) data[key] = entry self._entry_ids[entry.id] = entry self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id + if (config_entry_id := entry.config_entry_id) is not None: + self._config_entry_id_index.setdefault(config_entry_id, []).append(key) + if (device_id := entry.device_id) is not None: + self._device_id_index.setdefault(device_id, []).append(key) + + def _unindex_entry(self, key: str) -> None: + """Unindex an entry.""" + entry = self.data[key] + del self._entry_ids[entry.id] + del self._index[(entry.domain, entry.platform, entry.unique_id)] + if (config_entry_id := entry.config_entry_id) is not None: + entries = self._config_entry_id_index[config_entry_id] + entries.remove(key) + if not entries: + del self._config_entry_id_index[config_entry_id] + if (device_id := entry.device_id) is not None: + entries = self._device_id_index[device_id] + entries.remove(key) + if not entries: + del self._device_id_index[device_id] def __delitem__(self, key: str) -> None: """Remove an item.""" - entry = self[key] - del self._entry_ids[entry.id] - del self._index[(entry.domain, entry.platform, entry.unique_id)] + self._unindex_entry(key) super().__delitem__(key) def get_entity_id(self, key: tuple[str, str, str]) -> str | None: @@ -477,6 +497,19 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]): """Get entry from id.""" return self._entry_ids.get(key) + def get_entries_for_device_id(self, device_id: str) -> list[RegistryEntry]: + """Get entries for device.""" + return [self.data[key] for key in self._device_id_index.get(device_id, ())] + + def get_entries_for_config_entry_id( + self, config_entry_id: str + ) -> list[RegistryEntry]: + """Get entries for config entry.""" + return [ + self.data[key] + for key in self._config_entry_id_index.get(config_entry_id, ()) + ] + class EntityRegistry: """Class to hold a registry of entities.""" @@ -1217,9 +1250,8 @@ def async_entries_for_device( """Return entries that match a device.""" return [ entry - for entry in registry.entities.values() - if entry.device_id == device_id - and (not entry.disabled_by or include_disabled_entities) + for entry in registry.entities.get_entries_for_device_id(device_id) + if (not entry.disabled_by or include_disabled_entities) ] @@ -1236,11 +1268,7 @@ def async_entries_for_config_entry( registry: EntityRegistry, config_entry_id: str ) -> list[RegistryEntry]: """Return entries that match a config entry.""" - return [ - entry - for entry in registry.entities.values() - if entry.config_entry_id == config_entry_id - ] + return registry.entities.get_entries_for_config_entry_id(config_entry_id) @callback