diff --git a/homeassistant/helpers/area_registry.py b/homeassistant/helpers/area_registry.py index 95f889281fc..b3da01114d3 100644 --- a/homeassistant/helpers/area_registry.py +++ b/homeassistant/helpers/area_registry.py @@ -1,8 +1,8 @@ """Provide a way to connect devices to one physical location.""" from __future__ import annotations -from collections import OrderedDict -from collections.abc import Iterable, MutableMapping +from collections import UserDict +from collections.abc import Iterable, ValuesView import dataclasses from typing import Any, Literal, TypedDict, cast @@ -39,6 +39,52 @@ class AreaEntry: picture: str | None +class AreaRegistryItems(UserDict[str, AreaEntry]): + """Container for area registry items, maps area id -> entry. + + Maintains an additional index: + - normalized name -> entry + """ + + def __init__(self) -> None: + """Initialize the container.""" + super().__init__() + self._normalized_names: dict[str, AreaEntry] = {} + + def values(self) -> ValuesView[AreaEntry]: + """Return the underlying values to avoid __iter__ overhead.""" + return self.data.values() + + def __setitem__(self, key: str, entry: AreaEntry) -> None: + """Add an item.""" + data = self.data + normalized_name = normalize_area_name(entry.name) + + if key in data: + old_entry = data[key] + if ( + normalized_name != old_entry.normalized_name + and normalized_name in self._normalized_names + ): + raise ValueError( + f"The name {entry.name} ({normalized_name}) is already in use" + ) + del self._normalized_names[old_entry.normalized_name] + data[key] = entry + self._normalized_names[normalized_name] = entry + + def __delitem__(self, key: str) -> None: + """Remove an item.""" + entry = self[key] + normalized_name = normalize_area_name(entry.name) + del self._normalized_names[normalized_name] + super().__delitem__(key) + + def get_area_by_name(self, name: str) -> AreaEntry | None: + """Get area by name.""" + return self._normalized_names.get(normalize_area_name(name)) + + class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]): """Store area registry data.""" @@ -69,10 +115,12 @@ class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]): class AreaRegistry: """Class to hold a registry of areas.""" + areas: AreaRegistryItems + _area_data: dict[str, AreaEntry] + def __init__(self, hass: HomeAssistant) -> None: """Initialize the area registry.""" self.hass = hass - self.areas: MutableMapping[str, AreaEntry] = {} self._store = AreaRegistryStore( hass, STORAGE_VERSION_MAJOR, @@ -80,20 +128,20 @@ class AreaRegistry: atomic_writes=True, minor_version=STORAGE_VERSION_MINOR, ) - self._normalized_name_area_idx: dict[str, str] = {} @callback def async_get_area(self, area_id: str) -> AreaEntry | None: - """Get area by id.""" - return self.areas.get(area_id) + """Get area by id. + + We retrieve the DeviceEntry from the underlying dict to avoid + the overhead of the UserDict __getitem__. + """ + return self._area_data.get(area_id) @callback def async_get_area_by_name(self, name: str) -> AreaEntry | None: """Get area by name.""" - normalized_name = normalize_area_name(name) - if normalized_name not in self._normalized_name_area_idx: - return None - return self.areas[self._normalized_name_area_idx[normalized_name]] + return self.areas.get_area_by_name(name) @callback def async_list_areas(self) -> Iterable[AreaEntry]: @@ -131,7 +179,6 @@ class AreaRegistry: ) assert area.id is not None self.areas[area.id] = area - self._normalized_name_area_idx[normalized_name] = area.id self.async_schedule_save() self.hass.bus.async_fire( EVENT_AREA_REGISTRY_UPDATED, {"action": "create", "area_id": area.id} @@ -141,14 +188,12 @@ class AreaRegistry: @callback def async_delete(self, area_id: str) -> None: """Delete area.""" - area = self.areas[area_id] device_registry = dr.async_get(self.hass) entity_registry = er.async_get(self.hass) device_registry.async_clear_area_id(area_id) entity_registry.async_clear_area_id(area_id) del self.areas[area_id] - del self._normalized_name_area_idx[area.normalized_name] self.hass.bus.async_fire( EVENT_AREA_REGISTRY_UPDATED, {"action": "remove", "area_id": area_id} @@ -195,29 +240,14 @@ class AreaRegistry: if value is not UNDEFINED and value != getattr(old, attr_name): new_values[attr_name] = value - normalized_name = None - if name is not UNDEFINED and name != old.name: - normalized_name = normalize_area_name(name) - - if normalized_name != old.normalized_name and self.async_get_area_by_name( - name - ): - raise ValueError( - f"The name {name} ({normalized_name}) is already in use" - ) - new_values["name"] = name - new_values["normalized_name"] = normalized_name + new_values["normalized_name"] = normalize_area_name(name) if not new_values: return old new = self.areas[area_id] = dataclasses.replace(old, **new_values) # type: ignore[arg-type] - if normalized_name is not None: - self._normalized_name_area_idx[ - normalized_name - ] = self._normalized_name_area_idx.pop(old.normalized_name) self.async_schedule_save() return new @@ -226,7 +256,7 @@ class AreaRegistry: """Load the area registry.""" data = await self._store.async_load() - areas: MutableMapping[str, AreaEntry] = OrderedDict() + areas = AreaRegistryItems() if data is not None: for area in data["areas"]: @@ -239,9 +269,9 @@ class AreaRegistry: normalized_name=normalized_name, picture=area["picture"], ) - self._normalized_name_area_idx[normalized_name] = area["id"] self.areas = areas + self._area_data = areas.data @callback def async_schedule_save(self) -> None: