diff --git a/homeassistant/helpers/area_registry.py b/homeassistant/helpers/area_registry.py index 38c554ffda3..f1731f43473 100644 --- a/homeassistant/helpers/area_registry.py +++ b/homeassistant/helpers/area_registry.py @@ -1,8 +1,7 @@ """Provide a way to connect devices to one physical location.""" from __future__ import annotations -from collections import UserDict -from collections.abc import Iterable, ValuesView +from collections.abc import Iterable import dataclasses from typing import Any, Literal, TypedDict, cast @@ -10,6 +9,11 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.util import slugify from . import device_registry as dr, entity_registry as er +from .normalized_name_base_registry import ( + NormalizedNameBaseRegistryEntry, + NormalizedNameBaseRegistryItems, + normalize_name, +) from .storage import Store from .typing import UNDEFINED, UndefinedType @@ -29,7 +33,7 @@ class EventAreaRegistryUpdatedData(TypedDict): @dataclasses.dataclass(frozen=True, kw_only=True, slots=True) -class AreaEntry: +class AreaEntry(NormalizedNameBaseRegistryEntry): """Area Registry Entry.""" aliases: set[str] @@ -37,57 +41,9 @@ class AreaEntry: icon: str | None id: str labels: set[str] = dataclasses.field(default_factory=set) - name: str - normalized_name: str picture: str | None -class AreaRegistryItems(UserDict[str, AreaEntry]): - """Container for area registry items, maps area id -> entry. - - Maintains an additional index: - - normalized name -> entry - """ - - def __init__(self) -> None: - """Initialize the container.""" - super().__init__() - self._normalized_names: dict[str, AreaEntry] = {} - - def values(self) -> ValuesView[AreaEntry]: - """Return the underlying values to avoid __iter__ overhead.""" - return self.data.values() - - def __setitem__(self, key: str, entry: AreaEntry) -> None: - """Add an item.""" - data = self.data - normalized_name = normalize_area_name(entry.name) - - if key in data: - old_entry = data[key] - if ( - normalized_name != old_entry.normalized_name - and normalized_name in self._normalized_names - ): - raise ValueError( - f"The name {entry.name} ({normalized_name}) is already in use" - ) - del self._normalized_names[old_entry.normalized_name] - data[key] = entry - self._normalized_names[normalized_name] = entry - - def __delitem__(self, key: str) -> None: - """Remove an item.""" - entry = self[key] - normalized_name = normalize_area_name(entry.name) - del self._normalized_names[normalized_name] - super().__delitem__(key) - - def get_area_by_name(self, name: str) -> AreaEntry | None: - """Get area by name.""" - return self._normalized_names.get(normalize_area_name(name)) - - class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]): """Store area registry data.""" @@ -133,7 +89,7 @@ class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]): class AreaRegistry: """Class to hold a registry of areas.""" - areas: AreaRegistryItems + areas: NormalizedNameBaseRegistryItems[AreaEntry] _area_data: dict[str, AreaEntry] def __init__(self, hass: HomeAssistant) -> None: @@ -159,7 +115,7 @@ class AreaRegistry: @callback def async_get_area_by_name(self, name: str) -> AreaEntry | None: """Get area by name.""" - return self.areas.get_area_by_name(name) + return self.areas.get_by_name(name) @callback def async_list_areas(self) -> Iterable[AreaEntry]: @@ -185,7 +141,7 @@ class AreaRegistry: picture: str | None = None, ) -> AreaEntry: """Create a new area.""" - normalized_name = normalize_area_name(name) + normalized_name = normalize_name(name) if self.async_get_area_by_name(name): raise ValueError(f"The name {name} ({normalized_name}) is already in use") @@ -281,7 +237,7 @@ class AreaRegistry: if name is not UNDEFINED and name != old.name: new_values["name"] = name - new_values["normalized_name"] = normalize_area_name(name) + new_values["normalized_name"] = normalize_name(name) if not new_values: return old @@ -297,12 +253,12 @@ class AreaRegistry: data = await self._store.async_load() - areas = AreaRegistryItems() + areas = NormalizedNameBaseRegistryItems[AreaEntry]() if data is not None: for area in data["areas"]: assert area["name"] is not None and area["id"] is not None - normalized_name = normalize_area_name(area["name"]) + normalized_name = normalize_name(area["name"]) areas[area["id"]] = AreaEntry( aliases=set(area["aliases"]), floor_id=area["floor_id"], @@ -421,8 +377,3 @@ def async_entries_for_floor(registry: AreaRegistry, floor_id: str) -> list[AreaE def async_entries_for_label(registry: AreaRegistry, label_id: str) -> list[AreaEntry]: """Return entries that match a label.""" return [area for area in registry.areas.values() if label_id in area.labels] - - -def normalize_area_name(area_name: str) -> str: - """Normalize an area name by removing whitespace and case folding.""" - return area_name.casefold().replace(" ", "") diff --git a/homeassistant/helpers/floor_registry.py b/homeassistant/helpers/floor_registry.py index 1149bbd1729..978471d7cd2 100644 --- a/homeassistant/helpers/floor_registry.py +++ b/homeassistant/helpers/floor_registry.py @@ -1,8 +1,7 @@ """Provide a way to assign areas to floors in one's home.""" from __future__ import annotations -from collections import UserDict -from collections.abc import Iterable, ValuesView +from collections.abc import Iterable import dataclasses from dataclasses import dataclass from typing import TYPE_CHECKING, Literal, TypedDict, cast @@ -10,6 +9,11 @@ from typing import TYPE_CHECKING, Literal, TypedDict, cast from homeassistant.core import HomeAssistant, callback from homeassistant.util import slugify +from .normalized_name_base_registry import ( + NormalizedNameBaseRegistryEntry, + NormalizedNameBaseRegistryItems, + normalize_name, +) from .storage import Store from .typing import UNDEFINED, EventType, UndefinedType @@ -31,67 +35,19 @@ EventFloorRegistryUpdated = EventType[EventFloorRegistryUpdatedData] @dataclass(slots=True, kw_only=True, frozen=True) -class FloorEntry: +class FloorEntry(NormalizedNameBaseRegistryEntry): """Floor registry entry.""" aliases: set[str] floor_id: str icon: str | None = None level: int = 0 - name: str - normalized_name: str - - -class FloorRegistryItems(UserDict[str, FloorEntry]): - """Container for floor registry items, maps floor id -> entry. - - Maintains an additional index: - - normalized name -> entry - """ - - def __init__(self) -> None: - """Initialize the container.""" - super().__init__() - self._normalized_names: dict[str, FloorEntry] = {} - - def values(self) -> ValuesView[FloorEntry]: - """Return the underlying values to avoid __iter__ overhead.""" - return self.data.values() - - def __setitem__(self, key: str, entry: FloorEntry) -> None: - """Add an item.""" - data = self.data - normalized_name = _normalize_floor_name(entry.name) - - if key in data: - old_entry = data[key] - if ( - normalized_name != old_entry.normalized_name - and normalized_name in self._normalized_names - ): - raise ValueError( - f"The name {entry.name} ({normalized_name}) is already in use" - ) - del self._normalized_names[old_entry.normalized_name] - data[key] = entry - self._normalized_names[normalized_name] = entry - - def __delitem__(self, key: str) -> None: - """Remove an item.""" - entry = self[key] - normalized_name = _normalize_floor_name(entry.name) - del self._normalized_names[normalized_name] - super().__delitem__(key) - - def get_floor_by_name(self, name: str) -> FloorEntry | None: - """Get floor by name.""" - return self._normalized_names.get(_normalize_floor_name(name)) class FloorRegistry: """Class to hold a registry of floors.""" - floors: FloorRegistryItems + floors: NormalizedNameBaseRegistryItems[FloorEntry] _floor_data: dict[str, FloorEntry] def __init__(self, hass: HomeAssistant) -> None: @@ -118,7 +74,7 @@ class FloorRegistry: @callback def async_get_floor_by_name(self, name: str) -> FloorEntry | None: """Get floor by name.""" - return self.floors.get_floor_by_name(name) + return self.floors.get_by_name(name) @callback def async_list_floors(self) -> Iterable[FloorEntry]: @@ -150,7 +106,7 @@ class FloorRegistry: f"The name {name} ({floor.normalized_name}) is already in use" ) - normalized_name = _normalize_floor_name(name) + normalized_name = normalize_name(name) floor = FloorEntry( aliases=aliases or set(), @@ -208,7 +164,7 @@ class FloorRegistry: } if name is not UNDEFINED and name != old.name: changes["name"] = name - changes["normalized_name"] = _normalize_floor_name(name) + changes["normalized_name"] = normalize_name(name) if not changes: return old @@ -229,7 +185,7 @@ class FloorRegistry: async def async_load(self) -> None: """Load the floor registry.""" data = await self._store.async_load() - floors = FloorRegistryItems() + floors = NormalizedNameBaseRegistryItems[FloorEntry]() if data is not None: for floor in data["floors"]: @@ -240,7 +196,7 @@ class FloorRegistry: assert isinstance(floor["name"], str) assert isinstance(floor["floor_id"], str) - normalized_name = _normalize_floor_name(floor["name"]) + normalized_name = normalize_name(floor["name"]) floors[floor["floor_id"]] = FloorEntry( aliases=set(floor["aliases"]), icon=floor["icon"], @@ -286,8 +242,3 @@ async def async_load(hass: HomeAssistant) -> None: assert DATA_REGISTRY not in hass.data hass.data[DATA_REGISTRY] = FloorRegistry(hass) await hass.data[DATA_REGISTRY].async_load() - - -def _normalize_floor_name(floor_name: str) -> str: - """Normalize a floor name by removing whitespace and case folding.""" - return floor_name.casefold().replace(" ", "") diff --git a/homeassistant/helpers/label_registry.py b/homeassistant/helpers/label_registry.py index 9c7f20a6515..ef3abc19d8c 100644 --- a/homeassistant/helpers/label_registry.py +++ b/homeassistant/helpers/label_registry.py @@ -1,8 +1,7 @@ """Provide a way to label and group anything.""" from __future__ import annotations -from collections import UserDict -from collections.abc import Iterable, ValuesView +from collections.abc import Iterable import dataclasses from dataclasses import dataclass from typing import Literal, TypedDict, cast @@ -10,6 +9,11 @@ from typing import Literal, TypedDict, cast from homeassistant.core import HomeAssistant, callback from homeassistant.util import slugify +from .normalized_name_base_registry import ( + NormalizedNameBaseRegistryEntry, + NormalizedNameBaseRegistryItems, + normalize_name, +) from .storage import Store from .typing import UNDEFINED, EventType, UndefinedType @@ -30,68 +34,20 @@ class EventLabelRegistryUpdatedData(TypedDict): EventLabelRegistryUpdated = EventType[EventLabelRegistryUpdatedData] -@dataclass(slots=True, frozen=True) -class LabelEntry: +@dataclass(slots=True, frozen=True, kw_only=True) +class LabelEntry(NormalizedNameBaseRegistryEntry): """Label Registry Entry.""" label_id: str - name: str - normalized_name: str description: str | None = None color: str | None = None icon: str | None = None -class LabelRegistryItems(UserDict[str, LabelEntry]): - """Container for label registry items, maps label id -> entry. - - Maintains an additional index: - - normalized name -> entry - """ - - def __init__(self) -> None: - """Initialize the container.""" - super().__init__() - self._normalized_names: dict[str, LabelEntry] = {} - - def values(self) -> ValuesView[LabelEntry]: - """Return the underlying values to avoid __iter__ overhead.""" - return self.data.values() - - def __setitem__(self, key: str, entry: LabelEntry) -> None: - """Add an item.""" - data = self.data - normalized_name = _normalize_label_name(entry.name) - - if key in data: - old_entry = data[key] - if ( - normalized_name != old_entry.normalized_name - and normalized_name in self._normalized_names - ): - raise ValueError( - f"The name {entry.name} ({normalized_name}) is already in use" - ) - del self._normalized_names[old_entry.normalized_name] - data[key] = entry - self._normalized_names[normalized_name] = entry - - def __delitem__(self, key: str) -> None: - """Remove an item.""" - entry = self[key] - normalized_name = _normalize_label_name(entry.name) - del self._normalized_names[normalized_name] - super().__delitem__(key) - - def get_label_by_name(self, name: str) -> LabelEntry | None: - """Get label by name.""" - return self._normalized_names.get(_normalize_label_name(name)) - - class LabelRegistry: """Class to hold a registry of labels.""" - labels: LabelRegistryItems + labels: NormalizedNameBaseRegistryItems[LabelEntry] _label_data: dict[str, LabelEntry] def __init__(self, hass: HomeAssistant) -> None: @@ -116,7 +72,7 @@ class LabelRegistry: @callback def async_get_label_by_name(self, name: str) -> LabelEntry | None: """Get label by name.""" - return self.labels.get_label_by_name(name) + return self.labels.get_by_name(name) @callback def async_list_labels(self) -> Iterable[LabelEntry]: @@ -148,7 +104,7 @@ class LabelRegistry: f"The name {name} ({label.normalized_name}) is already in use" ) - normalized_name = _normalize_label_name(name) + normalized_name = normalize_name(name) label = LabelEntry( color=color, @@ -207,7 +163,7 @@ class LabelRegistry: if name is not UNDEFINED and name != old.name: changes["name"] = name - changes["normalized_name"] = _normalize_label_name(name) + changes["normalized_name"] = normalize_name(name) if not changes: return old @@ -228,7 +184,7 @@ class LabelRegistry: async def async_load(self) -> None: """Load the label registry.""" data = await self._store.async_load() - labels = LabelRegistryItems() + labels = NormalizedNameBaseRegistryItems[LabelEntry]() if data is not None: for label in data["labels"]: @@ -236,7 +192,7 @@ class LabelRegistry: if label["label_id"] is None or label["name"] is None: continue - normalized_name = _normalize_label_name(label["name"]) + normalized_name = normalize_name(label["name"]) labels[label["label_id"]] = LabelEntry( color=label["color"], description=label["description"], @@ -282,8 +238,3 @@ async def async_load(hass: HomeAssistant) -> None: assert DATA_REGISTRY not in hass.data hass.data[DATA_REGISTRY] = LabelRegistry(hass) await hass.data[DATA_REGISTRY].async_load() - - -def _normalize_label_name(label_name: str) -> str: - """Normalize a label name by removing whitespace and case folding.""" - return label_name.casefold().replace(" ", "") diff --git a/homeassistant/helpers/normalized_name_base_registry.py b/homeassistant/helpers/normalized_name_base_registry.py new file mode 100644 index 00000000000..13a4cb10312 --- /dev/null +++ b/homeassistant/helpers/normalized_name_base_registry.py @@ -0,0 +1,67 @@ +"""Provide a base class for registries that use a normalized name index.""" +from collections import UserDict +from collections.abc import ValuesView +from dataclasses import dataclass +from typing import TypeVar + + +@dataclass(slots=True, frozen=True, kw_only=True) +class NormalizedNameBaseRegistryEntry: + """Normalized Name Base Registry Entry.""" + + name: str + normalized_name: str + + +_VT = TypeVar("_VT", bound=NormalizedNameBaseRegistryEntry) + + +def normalize_name(name: str) -> str: + """Normalize a name by removing whitespace and case folding.""" + return name.casefold().replace(" ", "") + + +class NormalizedNameBaseRegistryItems(UserDict[str, _VT]): + """Base container for normalized name registry items, maps key -> entry. + + Maintains an additional index: + - normalized name -> entry + """ + + def __init__(self) -> None: + """Initialize the container.""" + super().__init__() + self._normalized_names: dict[str, _VT] = {} + + def values(self) -> ValuesView[_VT]: + """Return the underlying values to avoid __iter__ overhead.""" + return self.data.values() + + def __setitem__(self, key: str, entry: _VT) -> None: + """Add an item.""" + data = self.data + normalized_name = normalize_name(entry.name) + + if key in data: + old_entry = data[key] + if ( + normalized_name != old_entry.normalized_name + and normalized_name in self._normalized_names + ): + raise ValueError( + f"The name {entry.name} ({normalized_name}) is already in use" + ) + del self._normalized_names[old_entry.normalized_name] + data[key] = entry + self._normalized_names[normalized_name] = entry + + def __delitem__(self, key: str) -> None: + """Remove an item.""" + entry = self[key] + normalized_name = normalize_name(entry.name) + del self._normalized_names[normalized_name] + super().__delitem__(key) + + def get_by_name(self, name: str) -> _VT | None: + """Get entry by name.""" + return self._normalized_names.get(normalize_name(name)) diff --git a/tests/helpers/test_normalized_name_base_registry.py b/tests/helpers/test_normalized_name_base_registry.py new file mode 100644 index 00000000000..0b0e53abe83 --- /dev/null +++ b/tests/helpers/test_normalized_name_base_registry.py @@ -0,0 +1,67 @@ +"""Tests for the normalized name base registry helper.""" +import pytest + +from homeassistant.helpers.normalized_name_base_registry import ( + NormalizedNameBaseRegistryEntry, + NormalizedNameBaseRegistryItems, + normalize_name, +) + + +@pytest.fixture +def registry_items(): + """Fixture for registry items.""" + return NormalizedNameBaseRegistryItems[NormalizedNameBaseRegistryEntry]() + + +def test_normalize_name(): + """Test normalize_name.""" + assert normalize_name("Hello World") == "helloworld" + assert normalize_name("HELLO WORLD") == "helloworld" + assert normalize_name(" Hello World ") == "helloworld" + + +def test_registry_items( + registry_items: NormalizedNameBaseRegistryItems[NormalizedNameBaseRegistryEntry], +): + """Test registry items.""" + entry = NormalizedNameBaseRegistryEntry( + name="Hello World", normalized_name="helloworld" + ) + registry_items["key"] = entry + assert registry_items["key"] == entry + assert list(registry_items.values()) == [entry] + assert registry_items.get_by_name("Hello World") == entry + + # test update entry + entry2 = NormalizedNameBaseRegistryEntry( + name="Hello World 2", normalized_name="helloworld2" + ) + registry_items["key"] = entry2 + assert registry_items["key"] == entry2 + assert list(registry_items.values()) == [entry2] + assert registry_items.get_by_name("Hello World 2") == entry2 + + # test delete entry + del registry_items["key"] + assert "key" not in registry_items + assert list(registry_items.values()) == [] + + +def test_key_already_in_use( + registry_items: NormalizedNameBaseRegistryItems[NormalizedNameBaseRegistryEntry], +): + """Test key already in use.""" + entry = NormalizedNameBaseRegistryEntry( + name="Hello World", normalized_name="helloworld" + ) + registry_items["key"] = entry + + # should raise ValueError if we update a + # key with a entry with the same normalized name + with pytest.raises(ValueError): + entry = NormalizedNameBaseRegistryEntry( + name="Hello World 2", normalized_name="helloworld2" + ) + registry_items["key2"] = entry + registry_items["key"] = entry