diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 3f2e8a94b7c..5eb8a37176a 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -451,6 +451,7 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]): self._index: dict[tuple[str, str, str], str] = {} self._config_entry_id_index: dict[str, list[str]] = {} self._device_id_index: dict[str, list[str]] = {} + self._area_id_index: dict[str, list[str]] = {} def values(self) -> ValuesView[RegistryEntry]: """Return the underlying values to avoid __iter__ overhead.""" @@ -468,22 +469,34 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]): self._config_entry_id_index.setdefault(config_entry_id, []).append(key) if (device_id := entry.device_id) is not None: self._device_id_index.setdefault(device_id, []).append(key) + if (area_id := entry.area_id) is not None: + self._area_id_index.setdefault(area_id, []).append(key) + + def _unindex_entry_value( + self, key: str, value: str, index: dict[str, list[str]] + ) -> None: + """Unindex an entry value. + + key is the entry key + value is the value to unindex such as config_entry_id or device_id. + index is the index to unindex from. + """ + entries = index[value] + entries.remove(key) + if not entries: + del index[value] def _unindex_entry(self, key: str) -> None: """Unindex an entry.""" entry = self.data[key] del self._entry_ids[entry.id] del self._index[(entry.domain, entry.platform, entry.unique_id)] - if (config_entry_id := entry.config_entry_id) is not None: - entries = self._config_entry_id_index[config_entry_id] - entries.remove(key) - if not entries: - del self._config_entry_id_index[config_entry_id] - if (device_id := entry.device_id) is not None: - entries = self._device_id_index[device_id] - entries.remove(key) - if not entries: - del self._device_id_index[device_id] + if config_entry_id := entry.config_entry_id: + self._unindex_entry_value(key, config_entry_id, self._config_entry_id_index) + if device_id := entry.device_id: + self._unindex_entry_value(key, device_id, self._device_id_index) + if area_id := entry.area_id: + self._unindex_entry_value(key, area_id, self._area_id_index) def __delitem__(self, key: str) -> None: """Remove an item.""" @@ -518,6 +531,11 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]): data[key] for key in self._config_entry_id_index.get(config_entry_id, ()) ] + def get_entries_for_area_id(self, area_id: str) -> list[RegistryEntry]: + """Get entries for area.""" + data = self.data + return [data[key] for key in self._area_id_index.get(area_id, ())] + class EntityRegistry: """Class to hold a registry of entities.""" @@ -1266,7 +1284,7 @@ def async_entries_for_area( registry: EntityRegistry, area_id: str ) -> list[RegistryEntry]: """Return entries that match an area.""" - return [entry for entry in registry.entities.values() if entry.area_id == area_id] + return registry.entities.get_entries_for_area_id(area_id) @callback