Store area registry entries in a UserDict (#108656)

* Store area registry entries in a UserDict

* Address review comments
This commit is contained in:
Erik Montnemery 2024-01-23 08:14:28 +01:00 committed by GitHub
parent 2eea658fd8
commit 329eca4918
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,8 +1,8 @@
"""Provide a way to connect devices to one physical location."""
from __future__ import annotations
from collections import OrderedDict
from collections.abc import Iterable, MutableMapping
from collections import UserDict
from collections.abc import Iterable, ValuesView
import dataclasses
from typing import Any, Literal, TypedDict, cast
@ -39,6 +39,52 @@ class AreaEntry:
picture: str | None
class AreaRegistryItems(UserDict[str, AreaEntry]):
"""Container for area registry items, maps area id -> entry.
Maintains an additional index:
- normalized name -> entry
"""
def __init__(self) -> None:
"""Initialize the container."""
super().__init__()
self._normalized_names: dict[str, AreaEntry] = {}
def values(self) -> ValuesView[AreaEntry]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()
def __setitem__(self, key: str, entry: AreaEntry) -> None:
"""Add an item."""
data = self.data
normalized_name = normalize_area_name(entry.name)
if key in data:
old_entry = data[key]
if (
normalized_name != old_entry.normalized_name
and normalized_name in self._normalized_names
):
raise ValueError(
f"The name {entry.name} ({normalized_name}) is already in use"
)
del self._normalized_names[old_entry.normalized_name]
data[key] = entry
self._normalized_names[normalized_name] = entry
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
normalized_name = normalize_area_name(entry.name)
del self._normalized_names[normalized_name]
super().__delitem__(key)
def get_area_by_name(self, name: str) -> AreaEntry | None:
"""Get area by name."""
return self._normalized_names.get(normalize_area_name(name))
class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
"""Store area registry data."""
@ -69,10 +115,12 @@ class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
class AreaRegistry:
"""Class to hold a registry of areas."""
areas: AreaRegistryItems
_area_data: dict[str, AreaEntry]
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the area registry."""
self.hass = hass
self.areas: MutableMapping[str, AreaEntry] = {}
self._store = AreaRegistryStore(
hass,
STORAGE_VERSION_MAJOR,
@ -80,20 +128,20 @@ class AreaRegistry:
atomic_writes=True,
minor_version=STORAGE_VERSION_MINOR,
)
self._normalized_name_area_idx: dict[str, str] = {}
@callback
def async_get_area(self, area_id: str) -> AreaEntry | None:
"""Get area by id."""
return self.areas.get(area_id)
"""Get area by id.
We retrieve the DeviceEntry from the underlying dict to avoid
the overhead of the UserDict __getitem__.
"""
return self._area_data.get(area_id)
@callback
def async_get_area_by_name(self, name: str) -> AreaEntry | None:
"""Get area by name."""
normalized_name = normalize_area_name(name)
if normalized_name not in self._normalized_name_area_idx:
return None
return self.areas[self._normalized_name_area_idx[normalized_name]]
return self.areas.get_area_by_name(name)
@callback
def async_list_areas(self) -> Iterable[AreaEntry]:
@ -131,7 +179,6 @@ class AreaRegistry:
)
assert area.id is not None
self.areas[area.id] = area
self._normalized_name_area_idx[normalized_name] = area.id
self.async_schedule_save()
self.hass.bus.async_fire(
EVENT_AREA_REGISTRY_UPDATED, {"action": "create", "area_id": area.id}
@ -141,14 +188,12 @@ class AreaRegistry:
@callback
def async_delete(self, area_id: str) -> None:
"""Delete area."""
area = self.areas[area_id]
device_registry = dr.async_get(self.hass)
entity_registry = er.async_get(self.hass)
device_registry.async_clear_area_id(area_id)
entity_registry.async_clear_area_id(area_id)
del self.areas[area_id]
del self._normalized_name_area_idx[area.normalized_name]
self.hass.bus.async_fire(
EVENT_AREA_REGISTRY_UPDATED, {"action": "remove", "area_id": area_id}
@ -195,29 +240,14 @@ class AreaRegistry:
if value is not UNDEFINED and value != getattr(old, attr_name):
new_values[attr_name] = value
normalized_name = None
if name is not UNDEFINED and name != old.name:
normalized_name = normalize_area_name(name)
if normalized_name != old.normalized_name and self.async_get_area_by_name(
name
):
raise ValueError(
f"The name {name} ({normalized_name}) is already in use"
)
new_values["name"] = name
new_values["normalized_name"] = normalized_name
new_values["normalized_name"] = normalize_area_name(name)
if not new_values:
return old
new = self.areas[area_id] = dataclasses.replace(old, **new_values) # type: ignore[arg-type]
if normalized_name is not None:
self._normalized_name_area_idx[
normalized_name
] = self._normalized_name_area_idx.pop(old.normalized_name)
self.async_schedule_save()
return new
@ -226,7 +256,7 @@ class AreaRegistry:
"""Load the area registry."""
data = await self._store.async_load()
areas: MutableMapping[str, AreaEntry] = OrderedDict()
areas = AreaRegistryItems()
if data is not None:
for area in data["areas"]:
@ -239,9 +269,9 @@ class AreaRegistry:
normalized_name=normalized_name,
picture=area["picture"],
)
self._normalized_name_area_idx[normalized_name] = area["id"]
self.areas = areas
self._area_data = areas.data
@callback
def async_schedule_save(self) -> None: