mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Store area registry entries in a UserDict (#108656)
* Store area registry entries in a UserDict * Address review comments
This commit is contained in:
parent
2eea658fd8
commit
329eca4918
@ -1,8 +1,8 @@
|
||||
"""Provide a way to connect devices to one physical location."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Iterable, MutableMapping
|
||||
from collections import UserDict
|
||||
from collections.abc import Iterable, ValuesView
|
||||
import dataclasses
|
||||
from typing import Any, Literal, TypedDict, cast
|
||||
|
||||
@ -39,6 +39,52 @@ class AreaEntry:
|
||||
picture: str | None
|
||||
|
||||
|
||||
class AreaRegistryItems(UserDict[str, AreaEntry]):
|
||||
"""Container for area registry items, maps area id -> entry.
|
||||
|
||||
Maintains an additional index:
|
||||
- normalized name -> entry
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the container."""
|
||||
super().__init__()
|
||||
self._normalized_names: dict[str, AreaEntry] = {}
|
||||
|
||||
def values(self) -> ValuesView[AreaEntry]:
|
||||
"""Return the underlying values to avoid __iter__ overhead."""
|
||||
return self.data.values()
|
||||
|
||||
def __setitem__(self, key: str, entry: AreaEntry) -> None:
|
||||
"""Add an item."""
|
||||
data = self.data
|
||||
normalized_name = normalize_area_name(entry.name)
|
||||
|
||||
if key in data:
|
||||
old_entry = data[key]
|
||||
if (
|
||||
normalized_name != old_entry.normalized_name
|
||||
and normalized_name in self._normalized_names
|
||||
):
|
||||
raise ValueError(
|
||||
f"The name {entry.name} ({normalized_name}) is already in use"
|
||||
)
|
||||
del self._normalized_names[old_entry.normalized_name]
|
||||
data[key] = entry
|
||||
self._normalized_names[normalized_name] = entry
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
"""Remove an item."""
|
||||
entry = self[key]
|
||||
normalized_name = normalize_area_name(entry.name)
|
||||
del self._normalized_names[normalized_name]
|
||||
super().__delitem__(key)
|
||||
|
||||
def get_area_by_name(self, name: str) -> AreaEntry | None:
|
||||
"""Get area by name."""
|
||||
return self._normalized_names.get(normalize_area_name(name))
|
||||
|
||||
|
||||
class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
|
||||
"""Store area registry data."""
|
||||
|
||||
@ -69,10 +115,12 @@ class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
|
||||
class AreaRegistry:
|
||||
"""Class to hold a registry of areas."""
|
||||
|
||||
areas: AreaRegistryItems
|
||||
_area_data: dict[str, AreaEntry]
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the area registry."""
|
||||
self.hass = hass
|
||||
self.areas: MutableMapping[str, AreaEntry] = {}
|
||||
self._store = AreaRegistryStore(
|
||||
hass,
|
||||
STORAGE_VERSION_MAJOR,
|
||||
@ -80,20 +128,20 @@ class AreaRegistry:
|
||||
atomic_writes=True,
|
||||
minor_version=STORAGE_VERSION_MINOR,
|
||||
)
|
||||
self._normalized_name_area_idx: dict[str, str] = {}
|
||||
|
||||
@callback
|
||||
def async_get_area(self, area_id: str) -> AreaEntry | None:
|
||||
"""Get area by id."""
|
||||
return self.areas.get(area_id)
|
||||
"""Get area by id.
|
||||
|
||||
We retrieve the DeviceEntry from the underlying dict to avoid
|
||||
the overhead of the UserDict __getitem__.
|
||||
"""
|
||||
return self._area_data.get(area_id)
|
||||
|
||||
@callback
|
||||
def async_get_area_by_name(self, name: str) -> AreaEntry | None:
|
||||
"""Get area by name."""
|
||||
normalized_name = normalize_area_name(name)
|
||||
if normalized_name not in self._normalized_name_area_idx:
|
||||
return None
|
||||
return self.areas[self._normalized_name_area_idx[normalized_name]]
|
||||
return self.areas.get_area_by_name(name)
|
||||
|
||||
@callback
|
||||
def async_list_areas(self) -> Iterable[AreaEntry]:
|
||||
@ -131,7 +179,6 @@ class AreaRegistry:
|
||||
)
|
||||
assert area.id is not None
|
||||
self.areas[area.id] = area
|
||||
self._normalized_name_area_idx[normalized_name] = area.id
|
||||
self.async_schedule_save()
|
||||
self.hass.bus.async_fire(
|
||||
EVENT_AREA_REGISTRY_UPDATED, {"action": "create", "area_id": area.id}
|
||||
@ -141,14 +188,12 @@ class AreaRegistry:
|
||||
@callback
|
||||
def async_delete(self, area_id: str) -> None:
|
||||
"""Delete area."""
|
||||
area = self.areas[area_id]
|
||||
device_registry = dr.async_get(self.hass)
|
||||
entity_registry = er.async_get(self.hass)
|
||||
device_registry.async_clear_area_id(area_id)
|
||||
entity_registry.async_clear_area_id(area_id)
|
||||
|
||||
del self.areas[area_id]
|
||||
del self._normalized_name_area_idx[area.normalized_name]
|
||||
|
||||
self.hass.bus.async_fire(
|
||||
EVENT_AREA_REGISTRY_UPDATED, {"action": "remove", "area_id": area_id}
|
||||
@ -195,29 +240,14 @@ class AreaRegistry:
|
||||
if value is not UNDEFINED and value != getattr(old, attr_name):
|
||||
new_values[attr_name] = value
|
||||
|
||||
normalized_name = None
|
||||
|
||||
if name is not UNDEFINED and name != old.name:
|
||||
normalized_name = normalize_area_name(name)
|
||||
|
||||
if normalized_name != old.normalized_name and self.async_get_area_by_name(
|
||||
name
|
||||
):
|
||||
raise ValueError(
|
||||
f"The name {name} ({normalized_name}) is already in use"
|
||||
)
|
||||
|
||||
new_values["name"] = name
|
||||
new_values["normalized_name"] = normalized_name
|
||||
new_values["normalized_name"] = normalize_area_name(name)
|
||||
|
||||
if not new_values:
|
||||
return old
|
||||
|
||||
new = self.areas[area_id] = dataclasses.replace(old, **new_values) # type: ignore[arg-type]
|
||||
if normalized_name is not None:
|
||||
self._normalized_name_area_idx[
|
||||
normalized_name
|
||||
] = self._normalized_name_area_idx.pop(old.normalized_name)
|
||||
|
||||
self.async_schedule_save()
|
||||
return new
|
||||
@ -226,7 +256,7 @@ class AreaRegistry:
|
||||
"""Load the area registry."""
|
||||
data = await self._store.async_load()
|
||||
|
||||
areas: MutableMapping[str, AreaEntry] = OrderedDict()
|
||||
areas = AreaRegistryItems()
|
||||
|
||||
if data is not None:
|
||||
for area in data["areas"]:
|
||||
@ -239,9 +269,9 @@ class AreaRegistry:
|
||||
normalized_name=normalized_name,
|
||||
picture=area["picture"],
|
||||
)
|
||||
self._normalized_name_area_idx[normalized_name] = area["id"]
|
||||
|
||||
self.areas = areas
|
||||
self._area_data = areas.data
|
||||
|
||||
@callback
|
||||
def async_schedule_save(self) -> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user