mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Store area registry entries in a UserDict (#108656)
* Store area registry entries in a UserDict * Address review comments
This commit is contained in:
parent
2eea658fd8
commit
329eca4918
@ -1,8 +1,8 @@
|
|||||||
"""Provide a way to connect devices to one physical location."""
|
"""Provide a way to connect devices to one physical location."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import UserDict
|
||||||
from collections.abc import Iterable, MutableMapping
|
from collections.abc import Iterable, ValuesView
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any, Literal, TypedDict, cast
|
from typing import Any, Literal, TypedDict, cast
|
||||||
|
|
||||||
@ -39,6 +39,52 @@ class AreaEntry:
|
|||||||
picture: str | None
|
picture: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class AreaRegistryItems(UserDict[str, AreaEntry]):
|
||||||
|
"""Container for area registry items, maps area id -> entry.
|
||||||
|
|
||||||
|
Maintains an additional index:
|
||||||
|
- normalized name -> entry
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize the container."""
|
||||||
|
super().__init__()
|
||||||
|
self._normalized_names: dict[str, AreaEntry] = {}
|
||||||
|
|
||||||
|
def values(self) -> ValuesView[AreaEntry]:
|
||||||
|
"""Return the underlying values to avoid __iter__ overhead."""
|
||||||
|
return self.data.values()
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, entry: AreaEntry) -> None:
|
||||||
|
"""Add an item."""
|
||||||
|
data = self.data
|
||||||
|
normalized_name = normalize_area_name(entry.name)
|
||||||
|
|
||||||
|
if key in data:
|
||||||
|
old_entry = data[key]
|
||||||
|
if (
|
||||||
|
normalized_name != old_entry.normalized_name
|
||||||
|
and normalized_name in self._normalized_names
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"The name {entry.name} ({normalized_name}) is already in use"
|
||||||
|
)
|
||||||
|
del self._normalized_names[old_entry.normalized_name]
|
||||||
|
data[key] = entry
|
||||||
|
self._normalized_names[normalized_name] = entry
|
||||||
|
|
||||||
|
def __delitem__(self, key: str) -> None:
|
||||||
|
"""Remove an item."""
|
||||||
|
entry = self[key]
|
||||||
|
normalized_name = normalize_area_name(entry.name)
|
||||||
|
del self._normalized_names[normalized_name]
|
||||||
|
super().__delitem__(key)
|
||||||
|
|
||||||
|
def get_area_by_name(self, name: str) -> AreaEntry | None:
|
||||||
|
"""Get area by name."""
|
||||||
|
return self._normalized_names.get(normalize_area_name(name))
|
||||||
|
|
||||||
|
|
||||||
class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
|
class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
|
||||||
"""Store area registry data."""
|
"""Store area registry data."""
|
||||||
|
|
||||||
@ -69,10 +115,12 @@ class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
|
|||||||
class AreaRegistry:
|
class AreaRegistry:
|
||||||
"""Class to hold a registry of areas."""
|
"""Class to hold a registry of areas."""
|
||||||
|
|
||||||
|
areas: AreaRegistryItems
|
||||||
|
_area_data: dict[str, AreaEntry]
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
"""Initialize the area registry."""
|
"""Initialize the area registry."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.areas: MutableMapping[str, AreaEntry] = {}
|
|
||||||
self._store = AreaRegistryStore(
|
self._store = AreaRegistryStore(
|
||||||
hass,
|
hass,
|
||||||
STORAGE_VERSION_MAJOR,
|
STORAGE_VERSION_MAJOR,
|
||||||
@ -80,20 +128,20 @@ class AreaRegistry:
|
|||||||
atomic_writes=True,
|
atomic_writes=True,
|
||||||
minor_version=STORAGE_VERSION_MINOR,
|
minor_version=STORAGE_VERSION_MINOR,
|
||||||
)
|
)
|
||||||
self._normalized_name_area_idx: dict[str, str] = {}
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_area(self, area_id: str) -> AreaEntry | None:
|
def async_get_area(self, area_id: str) -> AreaEntry | None:
|
||||||
"""Get area by id."""
|
"""Get area by id.
|
||||||
return self.areas.get(area_id)
|
|
||||||
|
We retrieve the DeviceEntry from the underlying dict to avoid
|
||||||
|
the overhead of the UserDict __getitem__.
|
||||||
|
"""
|
||||||
|
return self._area_data.get(area_id)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_area_by_name(self, name: str) -> AreaEntry | None:
|
def async_get_area_by_name(self, name: str) -> AreaEntry | None:
|
||||||
"""Get area by name."""
|
"""Get area by name."""
|
||||||
normalized_name = normalize_area_name(name)
|
return self.areas.get_area_by_name(name)
|
||||||
if normalized_name not in self._normalized_name_area_idx:
|
|
||||||
return None
|
|
||||||
return self.areas[self._normalized_name_area_idx[normalized_name]]
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_list_areas(self) -> Iterable[AreaEntry]:
|
def async_list_areas(self) -> Iterable[AreaEntry]:
|
||||||
@ -131,7 +179,6 @@ class AreaRegistry:
|
|||||||
)
|
)
|
||||||
assert area.id is not None
|
assert area.id is not None
|
||||||
self.areas[area.id] = area
|
self.areas[area.id] = area
|
||||||
self._normalized_name_area_idx[normalized_name] = area.id
|
|
||||||
self.async_schedule_save()
|
self.async_schedule_save()
|
||||||
self.hass.bus.async_fire(
|
self.hass.bus.async_fire(
|
||||||
EVENT_AREA_REGISTRY_UPDATED, {"action": "create", "area_id": area.id}
|
EVENT_AREA_REGISTRY_UPDATED, {"action": "create", "area_id": area.id}
|
||||||
@ -141,14 +188,12 @@ class AreaRegistry:
|
|||||||
@callback
|
@callback
|
||||||
def async_delete(self, area_id: str) -> None:
|
def async_delete(self, area_id: str) -> None:
|
||||||
"""Delete area."""
|
"""Delete area."""
|
||||||
area = self.areas[area_id]
|
|
||||||
device_registry = dr.async_get(self.hass)
|
device_registry = dr.async_get(self.hass)
|
||||||
entity_registry = er.async_get(self.hass)
|
entity_registry = er.async_get(self.hass)
|
||||||
device_registry.async_clear_area_id(area_id)
|
device_registry.async_clear_area_id(area_id)
|
||||||
entity_registry.async_clear_area_id(area_id)
|
entity_registry.async_clear_area_id(area_id)
|
||||||
|
|
||||||
del self.areas[area_id]
|
del self.areas[area_id]
|
||||||
del self._normalized_name_area_idx[area.normalized_name]
|
|
||||||
|
|
||||||
self.hass.bus.async_fire(
|
self.hass.bus.async_fire(
|
||||||
EVENT_AREA_REGISTRY_UPDATED, {"action": "remove", "area_id": area_id}
|
EVENT_AREA_REGISTRY_UPDATED, {"action": "remove", "area_id": area_id}
|
||||||
@ -195,29 +240,14 @@ class AreaRegistry:
|
|||||||
if value is not UNDEFINED and value != getattr(old, attr_name):
|
if value is not UNDEFINED and value != getattr(old, attr_name):
|
||||||
new_values[attr_name] = value
|
new_values[attr_name] = value
|
||||||
|
|
||||||
normalized_name = None
|
|
||||||
|
|
||||||
if name is not UNDEFINED and name != old.name:
|
if name is not UNDEFINED and name != old.name:
|
||||||
normalized_name = normalize_area_name(name)
|
|
||||||
|
|
||||||
if normalized_name != old.normalized_name and self.async_get_area_by_name(
|
|
||||||
name
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"The name {name} ({normalized_name}) is already in use"
|
|
||||||
)
|
|
||||||
|
|
||||||
new_values["name"] = name
|
new_values["name"] = name
|
||||||
new_values["normalized_name"] = normalized_name
|
new_values["normalized_name"] = normalize_area_name(name)
|
||||||
|
|
||||||
if not new_values:
|
if not new_values:
|
||||||
return old
|
return old
|
||||||
|
|
||||||
new = self.areas[area_id] = dataclasses.replace(old, **new_values) # type: ignore[arg-type]
|
new = self.areas[area_id] = dataclasses.replace(old, **new_values) # type: ignore[arg-type]
|
||||||
if normalized_name is not None:
|
|
||||||
self._normalized_name_area_idx[
|
|
||||||
normalized_name
|
|
||||||
] = self._normalized_name_area_idx.pop(old.normalized_name)
|
|
||||||
|
|
||||||
self.async_schedule_save()
|
self.async_schedule_save()
|
||||||
return new
|
return new
|
||||||
@ -226,7 +256,7 @@ class AreaRegistry:
|
|||||||
"""Load the area registry."""
|
"""Load the area registry."""
|
||||||
data = await self._store.async_load()
|
data = await self._store.async_load()
|
||||||
|
|
||||||
areas: MutableMapping[str, AreaEntry] = OrderedDict()
|
areas = AreaRegistryItems()
|
||||||
|
|
||||||
if data is not None:
|
if data is not None:
|
||||||
for area in data["areas"]:
|
for area in data["areas"]:
|
||||||
@ -239,9 +269,9 @@ class AreaRegistry:
|
|||||||
normalized_name=normalized_name,
|
normalized_name=normalized_name,
|
||||||
picture=area["picture"],
|
picture=area["picture"],
|
||||||
)
|
)
|
||||||
self._normalized_name_area_idx[normalized_name] = area["id"]
|
|
||||||
|
|
||||||
self.areas = areas
|
self.areas = areas
|
||||||
|
self._area_data = areas.data
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_schedule_save(self) -> None:
|
def async_schedule_save(self) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user