Get area and floor by alias (#126150)

* Add possibility to get area by alias

* Add ability to get floor by alias

* Moved alias lookup to separate function, adjusted templates.

* Changed registry to return all areas/floors with given alias

* Use normalize_name from normalized_name_base_registry
This commit is contained in:
Andrii Mitnovych 2025-03-27 08:02:47 -07:00 committed by GitHub
parent c30f17f592
commit dea00fac3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 120 additions and 8 deletions

View File

@ -20,6 +20,7 @@ from .json import json_bytes, json_fragment
from .normalized_name_base_registry import ( from .normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry, NormalizedNameBaseRegistryEntry,
NormalizedNameBaseRegistryItems, NormalizedNameBaseRegistryItems,
normalize_name,
) )
from .registry import BaseRegistry, RegistryIndexType from .registry import BaseRegistry, RegistryIndexType
from .singleton import singleton from .singleton import singleton
@ -169,6 +170,7 @@ class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
super().__init__() super().__init__()
self._labels_index: RegistryIndexType = defaultdict(dict) self._labels_index: RegistryIndexType = defaultdict(dict)
self._floors_index: RegistryIndexType = defaultdict(dict) self._floors_index: RegistryIndexType = defaultdict(dict)
self._aliases_index: RegistryIndexType = defaultdict(dict)
def _index_entry(self, key: str, entry: AreaEntry) -> None: def _index_entry(self, key: str, entry: AreaEntry) -> None:
"""Index an entry.""" """Index an entry."""
@ -177,6 +179,9 @@ class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
self._floors_index[entry.floor_id][key] = True self._floors_index[entry.floor_id][key] = True
for label in entry.labels: for label in entry.labels:
self._labels_index[label][key] = True self._labels_index[label][key] = True
for alias in entry.aliases:
normalized_alias = normalize_name(alias)
self._aliases_index[normalized_alias][key] = True
def _unindex_entry( def _unindex_entry(
self, key: str, replacement_entry: AreaEntry | None = None self, key: str, replacement_entry: AreaEntry | None = None
@ -184,6 +189,10 @@ class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
# always call base class before other indices # always call base class before other indices
super()._unindex_entry(key, replacement_entry) super()._unindex_entry(key, replacement_entry)
entry = self.data[key] entry = self.data[key]
if aliases := entry.aliases:
for alias in aliases:
normalized_alias = normalize_name(alias)
self._unindex_entry_value(key, normalized_alias, self._aliases_index)
if labels := entry.labels: if labels := entry.labels:
for label in labels: for label in labels:
self._unindex_entry_value(key, label, self._labels_index) self._unindex_entry_value(key, label, self._labels_index)
@ -200,6 +209,12 @@ class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
data = self.data data = self.data
return [data[key] for key in self._floors_index.get(floor, ())] return [data[key] for key in self._floors_index.get(floor, ())]
def get_areas_for_alias(self, alias: str) -> list[AreaEntry]:
"""Get areas for alias."""
data = self.data
normalized_alias = normalize_name(alias)
return [data[key] for key in self._aliases_index.get(normalized_alias, ())]
class AreaRegistry(BaseRegistry[AreasRegistryStoreData]): class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
"""Class to hold a registry of areas.""" """Class to hold a registry of areas."""
@ -232,6 +247,11 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
"""Get area by name.""" """Get area by name."""
return self.areas.get_by_name(name) return self.areas.get_by_name(name)
@callback
def async_get_areas_by_alias(self, alias: str) -> list[AreaEntry]:
"""Get areas by alias."""
return self.areas.get_areas_for_alias(alias)
@callback @callback
def async_list_areas(self) -> Iterable[AreaEntry]: def async_list_areas(self) -> Iterable[AreaEntry]:
"""Get all areas.""" """Get all areas."""

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
import dataclasses import dataclasses
from dataclasses import dataclass from dataclasses import dataclass
@ -16,8 +17,9 @@ from homeassistant.util.hass_dict import HassKey
from .normalized_name_base_registry import ( from .normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry, NormalizedNameBaseRegistryEntry,
NormalizedNameBaseRegistryItems, NormalizedNameBaseRegistryItems,
normalize_name,
) )
from .registry import BaseRegistry from .registry import BaseRegistry, RegistryIndexType
from .singleton import singleton from .singleton import singleton
from .storage import Store from .storage import Store
from .typing import UNDEFINED, UndefinedType from .typing import UNDEFINED, UndefinedType
@ -92,10 +94,43 @@ class FloorRegistryStore(Store[FloorRegistryStoreData]):
return old_data # type: ignore[return-value] return old_data # type: ignore[return-value]
class FloorRegistryItems(NormalizedNameBaseRegistryItems[FloorEntry]):
"""Class to hold floor registry items."""
def __init__(self) -> None:
"""Initialize the floor registry items."""
super().__init__()
self._aliases_index: RegistryIndexType = defaultdict(dict)
def _index_entry(self, key: str, entry: FloorEntry) -> None:
"""Index an entry."""
super()._index_entry(key, entry)
for alias in entry.aliases:
normalized_alias = normalize_name(alias)
self._aliases_index[normalized_alias][key] = True
def _unindex_entry(
self, key: str, replacement_entry: FloorEntry | None = None
) -> None:
# always call base class before other indices
super()._unindex_entry(key, replacement_entry)
entry = self.data[key]
if aliases := entry.aliases:
for alias in aliases:
normalized_alias = normalize_name(alias)
self._unindex_entry_value(key, normalized_alias, self._aliases_index)
def get_floors_for_alias(self, alias: str) -> list[FloorEntry]:
"""Get floors for alias."""
data = self.data
normalized_alias = normalize_name(alias)
return [data[key] for key in self._aliases_index.get(normalized_alias, ())]
class FloorRegistry(BaseRegistry[FloorRegistryStoreData]): class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
"""Class to hold a registry of floors.""" """Class to hold a registry of floors."""
floors: NormalizedNameBaseRegistryItems[FloorEntry] floors: FloorRegistryItems
_floor_data: dict[str, FloorEntry] _floor_data: dict[str, FloorEntry]
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
@ -123,6 +158,11 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
"""Get floor by name.""" """Get floor by name."""
return self.floors.get_by_name(name) return self.floors.get_by_name(name)
@callback
def async_get_floors_by_alias(self, alias: str) -> list[FloorEntry]:
"""Get floors by alias."""
return self.floors.get_floors_for_alias(alias)
@callback @callback
def async_list_floors(self) -> Iterable[FloorEntry]: def async_list_floors(self) -> Iterable[FloorEntry]:
"""Get all floors.""" """Get all floors."""
@ -226,7 +266,7 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
async def async_load(self) -> None: async def async_load(self) -> None:
"""Load the floor registry.""" """Load the floor registry."""
data = await self._store.async_load() data = await self._store.async_load()
floors = NormalizedNameBaseRegistryItems[FloorEntry]() floors = FloorRegistryItems()
if data is not None: if data is not None:
for floor in data["floors"]: for floor in data["floors"]:

View File

@ -1478,10 +1478,14 @@ def floors(hass: HomeAssistant) -> Iterable[str | None]:
def floor_id(hass: HomeAssistant, lookup_value: Any) -> str | None: def floor_id(hass: HomeAssistant, lookup_value: Any) -> str | None:
"""Get the floor ID from a floor name.""" """Get the floor ID from a floor or area name, alias, device id, or entity id."""
floor_registry = fr.async_get(hass) floor_registry = fr.async_get(hass)
if floor := floor_registry.async_get_floor_by_name(str(lookup_value)): lookup_str = str(lookup_value)
if floor := floor_registry.async_get_floor_by_name(lookup_str):
return floor.floor_id return floor.floor_id
floors_list = floor_registry.async_get_floors_by_alias(lookup_str)
if floors_list:
return floors_list[0].floor_id
if aid := area_id(hass, lookup_value): if aid := area_id(hass, lookup_value):
area_reg = area_registry.async_get(hass) area_reg = area_registry.async_get(hass)
@ -1541,10 +1545,14 @@ def areas(hass: HomeAssistant) -> Iterable[str | None]:
def area_id(hass: HomeAssistant, lookup_value: str) -> str | None: def area_id(hass: HomeAssistant, lookup_value: str) -> str | None:
"""Get the area ID from an area name, device id, or entity id.""" """Get the area ID from an area name, alias, device id, or entity id."""
area_reg = area_registry.async_get(hass) area_reg = area_registry.async_get(hass)
if area := area_reg.async_get_area_by_name(str(lookup_value)): lookup_str = str(lookup_value)
if area := area_reg.async_get_area_by_name(lookup_str):
return area.id return area.id
areas_list = area_reg.async_get_areas_by_alias(lookup_str)
if areas_list:
return areas_list[0].id
ent_reg = entity_registry.async_get(hass) ent_reg = entity_registry.async_get(hass)
dev_reg = device_registry.async_get(hass) dev_reg = device_registry.async_get(hass)

View File

@ -494,6 +494,29 @@ async def test_async_get_area_by_name(area_registry: ar.AreaRegistry) -> None:
assert area_registry.async_get_area_by_name("M o c k 1").normalized_name == "mock1" assert area_registry.async_get_area_by_name("M o c k 1").normalized_name == "mock1"
async def test_async_get_areas_by_alias(
area_registry: ar.AreaRegistry,
) -> None:
"""Make sure we can get the areas by alias."""
area1 = area_registry.async_create("Mock1", aliases=("alias_1", "alias_2"))
area2 = area_registry.async_create("Mock2", aliases=("alias_1", "alias_3"))
assert len(area_registry.areas) == 2
alias1_list = area_registry.async_get_areas_by_alias("A l i a s_1")
alias2_list = area_registry.async_get_areas_by_alias("A l i a s_2")
alias3_list = area_registry.async_get_areas_by_alias("A l i a s_3")
assert len(alias1_list) == 2
assert len(alias2_list) == 1
assert len(alias3_list) == 1
assert area1 in alias1_list
assert area1 in alias2_list
assert area2 in alias1_list
assert area2 in alias3_list
async def test_async_get_area_by_name_not_found(area_registry: ar.AreaRegistry) -> None: async def test_async_get_area_by_name_not_found(area_registry: ar.AreaRegistry) -> None:
"""Make sure we return None for non-existent areas.""" """Make sure we return None for non-existent areas."""
area_registry.async_create("Mock1") area_registry.async_create("Mock1")

View File

@ -327,7 +327,7 @@ async def test_loading_floors_from_storage(
assert len(registry.floors) == 1 assert len(registry.floors) == 1
async def test_getting_floor(floor_registry: fr.FloorRegistry) -> None: async def test_getting_floor_by_name(floor_registry: fr.FloorRegistry) -> None:
"""Make sure we can get the floors by name.""" """Make sure we can get the floors by name."""
floor = floor_registry.async_create("First floor") floor = floor_registry.async_create("First floor")
floor2 = floor_registry.async_get_floor_by_name("first floor") floor2 = floor_registry.async_get_floor_by_name("first floor")
@ -341,6 +341,27 @@ async def test_getting_floor(floor_registry: fr.FloorRegistry) -> None:
assert get_floor == floor assert get_floor == floor
async def test_async_get_floors_by_alias(
floor_registry: fr.FloorRegistry,
) -> None:
"""Make sure we can get the floors by alias."""
floor1 = floor_registry.async_create("First floor", aliases=("alias_1", "alias_2"))
floor2 = floor_registry.async_create("Second floor", aliases=("alias_1", "alias_3"))
alias1_list = floor_registry.async_get_floors_by_alias("A l i a s_1")
alias2_list = floor_registry.async_get_floors_by_alias("A l i a s_2")
alias3_list = floor_registry.async_get_floors_by_alias("A l i a s_3")
assert len(alias1_list) == 2
assert len(alias2_list) == 1
assert len(alias3_list) == 1
assert floor1 in alias1_list
assert floor1 in alias2_list
assert floor2 in alias1_list
assert floor2 in alias3_list
async def test_async_get_floor_by_name_not_found( async def test_async_get_floor_by_name_not_found(
floor_registry: fr.FloorRegistry, floor_registry: fr.FloorRegistry,
) -> None: ) -> None: