Add an index for devices and config entries to the entity registry (#107516)

* Add an index for devices and config entries to the entity registry

* fixes

* tweak

* use a list for now since the tests check order
This commit is contained in:
J. Nick Koston 2024-01-13 09:49:41 -10:00 committed by GitHub
parent 5d3e069655
commit d7910841ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -436,9 +436,11 @@ class EntityRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
class EntityRegistryItems(UserDict[str, RegistryEntry]): class EntityRegistryItems(UserDict[str, RegistryEntry]):
"""Container for entity registry items, maps entity_id -> entry. """Container for entity registry items, maps entity_id -> entry.
Maintains two additional indexes: Maintains four additional indexes:
- id -> entry - id -> entry
- (domain, platform, unique_id) -> entity_id - (domain, platform, unique_id) -> entity_id
- config_entry_id -> list[key]
- device_id -> list[key]
""" """
def __init__(self) -> None: def __init__(self) -> None:
@ -446,6 +448,8 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]):
super().__init__() super().__init__()
self._entry_ids: dict[str, RegistryEntry] = {} self._entry_ids: dict[str, RegistryEntry] = {}
self._index: dict[tuple[str, str, str], str] = {} self._index: dict[tuple[str, str, str], str] = {}
self._config_entry_id_index: dict[str, list[str]] = {}
self._device_id_index: dict[str, list[str]] = {}
def values(self) -> ValuesView[RegistryEntry]: def values(self) -> ValuesView[RegistryEntry]:
"""Return the underlying values to avoid __iter__ overhead.""" """Return the underlying values to avoid __iter__ overhead."""
@ -455,18 +459,34 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]):
"""Add an item.""" """Add an item."""
data = self.data data = self.data
if key in data: if key in data:
old_entry = data[key] self._unindex_entry(key)
del self._entry_ids[old_entry.id]
del self._index[(old_entry.domain, old_entry.platform, old_entry.unique_id)]
data[key] = entry data[key] = entry
self._entry_ids[entry.id] = entry self._entry_ids[entry.id] = entry
self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id
if (config_entry_id := entry.config_entry_id) is not None:
self._config_entry_id_index.setdefault(config_entry_id, []).append(key)
if (device_id := entry.device_id) is not None:
self._device_id_index.setdefault(device_id, []).append(key)
def _unindex_entry(self, key: str) -> None:
"""Unindex an entry."""
entry = self.data[key]
del self._entry_ids[entry.id]
del self._index[(entry.domain, entry.platform, entry.unique_id)]
if (config_entry_id := entry.config_entry_id) is not None:
entries = self._config_entry_id_index[config_entry_id]
entries.remove(key)
if not entries:
del self._config_entry_id_index[config_entry_id]
if (device_id := entry.device_id) is not None:
entries = self._device_id_index[device_id]
entries.remove(key)
if not entries:
del self._device_id_index[device_id]
def __delitem__(self, key: str) -> None: def __delitem__(self, key: str) -> None:
"""Remove an item.""" """Remove an item."""
entry = self[key] self._unindex_entry(key)
del self._entry_ids[entry.id]
del self._index[(entry.domain, entry.platform, entry.unique_id)]
super().__delitem__(key) super().__delitem__(key)
def get_entity_id(self, key: tuple[str, str, str]) -> str | None: def get_entity_id(self, key: tuple[str, str, str]) -> str | None:
@ -477,6 +497,19 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]):
"""Get entry from id.""" """Get entry from id."""
return self._entry_ids.get(key) return self._entry_ids.get(key)
def get_entries_for_device_id(self, device_id: str) -> list[RegistryEntry]:
"""Get entries for device."""
return [self.data[key] for key in self._device_id_index.get(device_id, ())]
def get_entries_for_config_entry_id(
self, config_entry_id: str
) -> list[RegistryEntry]:
"""Get entries for config entry."""
return [
self.data[key]
for key in self._config_entry_id_index.get(config_entry_id, ())
]
class EntityRegistry: class EntityRegistry:
"""Class to hold a registry of entities.""" """Class to hold a registry of entities."""
@ -1217,9 +1250,8 @@ def async_entries_for_device(
"""Return entries that match a device.""" """Return entries that match a device."""
return [ return [
entry entry
for entry in registry.entities.values() for entry in registry.entities.get_entries_for_device_id(device_id)
if entry.device_id == device_id if (not entry.disabled_by or include_disabled_entities)
and (not entry.disabled_by or include_disabled_entities)
] ]
@ -1236,11 +1268,7 @@ def async_entries_for_config_entry(
registry: EntityRegistry, config_entry_id: str registry: EntityRegistry, config_entry_id: str
) -> list[RegistryEntry]: ) -> list[RegistryEntry]:
"""Return entries that match a config entry.""" """Return entries that match a config entry."""
return [ return registry.entities.get_entries_for_config_entry_id(config_entry_id)
entry
for entry in registry.entities.values()
if entry.config_entry_id == config_entry_id
]
@callback @callback