Get area and floor by alias (#126150)

* Add possibility to get area by alias

* Add ability to get floor by alias

* Moved alias lookup to separate function, adjusted templates.

* Changed registry to return all areas/floors with given alias

* Use normalize_name from normalized_name_base_registry
This commit is contained in:
Andrii Mitnovych 2025-03-27 08:02:47 -07:00 committed by GitHub
parent c30f17f592
commit dea00fac3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 120 additions and 8 deletions

View File

@ -20,6 +20,7 @@ from .json import json_bytes, json_fragment
from .normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry,
NormalizedNameBaseRegistryItems,
normalize_name,
)
from .registry import BaseRegistry, RegistryIndexType
from .singleton import singleton
@ -169,6 +170,7 @@ class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
super().__init__()
self._labels_index: RegistryIndexType = defaultdict(dict)
self._floors_index: RegistryIndexType = defaultdict(dict)
self._aliases_index: RegistryIndexType = defaultdict(dict)
def _index_entry(self, key: str, entry: AreaEntry) -> None:
"""Index an entry."""
@ -177,6 +179,9 @@ class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
self._floors_index[entry.floor_id][key] = True
for label in entry.labels:
self._labels_index[label][key] = True
for alias in entry.aliases:
normalized_alias = normalize_name(alias)
self._aliases_index[normalized_alias][key] = True
def _unindex_entry(
self, key: str, replacement_entry: AreaEntry | None = None
@ -184,6 +189,10 @@ class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
# always call base class before other indices
super()._unindex_entry(key, replacement_entry)
entry = self.data[key]
if aliases := entry.aliases:
for alias in aliases:
normalized_alias = normalize_name(alias)
self._unindex_entry_value(key, normalized_alias, self._aliases_index)
if labels := entry.labels:
for label in labels:
self._unindex_entry_value(key, label, self._labels_index)
@ -200,6 +209,12 @@ class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
data = self.data
return [data[key] for key in self._floors_index.get(floor, ())]
def get_areas_for_alias(self, alias: str) -> list[AreaEntry]:
"""Get areas for alias."""
data = self.data
normalized_alias = normalize_name(alias)
return [data[key] for key in self._aliases_index.get(normalized_alias, ())]
class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
"""Class to hold a registry of areas."""
@ -232,6 +247,11 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
"""Get area by name."""
return self.areas.get_by_name(name)
@callback
def async_get_areas_by_alias(self, alias: str) -> list[AreaEntry]:
"""Get areas by alias."""
return self.areas.get_areas_for_alias(alias)
@callback
def async_list_areas(self) -> Iterable[AreaEntry]:
"""Get all areas."""

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from collections import defaultdict
from collections.abc import Iterable
import dataclasses
from dataclasses import dataclass
@ -16,8 +17,9 @@ from homeassistant.util.hass_dict import HassKey
from .normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry,
NormalizedNameBaseRegistryItems,
normalize_name,
)
from .registry import BaseRegistry
from .registry import BaseRegistry, RegistryIndexType
from .singleton import singleton
from .storage import Store
from .typing import UNDEFINED, UndefinedType
@ -92,10 +94,43 @@ class FloorRegistryStore(Store[FloorRegistryStoreData]):
return old_data # type: ignore[return-value]
class FloorRegistryItems(NormalizedNameBaseRegistryItems[FloorEntry]):
"""Class to hold floor registry items."""
def __init__(self) -> None:
"""Initialize the floor registry items."""
super().__init__()
self._aliases_index: RegistryIndexType = defaultdict(dict)
def _index_entry(self, key: str, entry: FloorEntry) -> None:
"""Index an entry."""
super()._index_entry(key, entry)
for alias in entry.aliases:
normalized_alias = normalize_name(alias)
self._aliases_index[normalized_alias][key] = True
def _unindex_entry(
self, key: str, replacement_entry: FloorEntry | None = None
) -> None:
# always call base class before other indices
super()._unindex_entry(key, replacement_entry)
entry = self.data[key]
if aliases := entry.aliases:
for alias in aliases:
normalized_alias = normalize_name(alias)
self._unindex_entry_value(key, normalized_alias, self._aliases_index)
def get_floors_for_alias(self, alias: str) -> list[FloorEntry]:
"""Get floors for alias."""
data = self.data
normalized_alias = normalize_name(alias)
return [data[key] for key in self._aliases_index.get(normalized_alias, ())]
class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
"""Class to hold a registry of floors."""
floors: NormalizedNameBaseRegistryItems[FloorEntry]
floors: FloorRegistryItems
_floor_data: dict[str, FloorEntry]
def __init__(self, hass: HomeAssistant) -> None:
@ -123,6 +158,11 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
"""Get floor by name."""
return self.floors.get_by_name(name)
@callback
def async_get_floors_by_alias(self, alias: str) -> list[FloorEntry]:
"""Get floors by alias."""
return self.floors.get_floors_for_alias(alias)
@callback
def async_list_floors(self) -> Iterable[FloorEntry]:
"""Get all floors."""
@ -226,7 +266,7 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
async def async_load(self) -> None:
"""Load the floor registry."""
data = await self._store.async_load()
floors = NormalizedNameBaseRegistryItems[FloorEntry]()
floors = FloorRegistryItems()
if data is not None:
for floor in data["floors"]:

View File

@ -1478,10 +1478,14 @@ def floors(hass: HomeAssistant) -> Iterable[str | None]:
def floor_id(hass: HomeAssistant, lookup_value: Any) -> str | None:
"""Get the floor ID from a floor name."""
"""Get the floor ID from a floor or area name, alias, device id, or entity id."""
floor_registry = fr.async_get(hass)
if floor := floor_registry.async_get_floor_by_name(str(lookup_value)):
lookup_str = str(lookup_value)
if floor := floor_registry.async_get_floor_by_name(lookup_str):
return floor.floor_id
floors_list = floor_registry.async_get_floors_by_alias(lookup_str)
if floors_list:
return floors_list[0].floor_id
if aid := area_id(hass, lookup_value):
area_reg = area_registry.async_get(hass)
@ -1541,10 +1545,14 @@ def areas(hass: HomeAssistant) -> Iterable[str | None]:
def area_id(hass: HomeAssistant, lookup_value: str) -> str | None:
"""Get the area ID from an area name, device id, or entity id."""
"""Get the area ID from an area name, alias, device id, or entity id."""
area_reg = area_registry.async_get(hass)
if area := area_reg.async_get_area_by_name(str(lookup_value)):
lookup_str = str(lookup_value)
if area := area_reg.async_get_area_by_name(lookup_str):
return area.id
areas_list = area_reg.async_get_areas_by_alias(lookup_str)
if areas_list:
return areas_list[0].id
ent_reg = entity_registry.async_get(hass)
dev_reg = device_registry.async_get(hass)

View File

@ -494,6 +494,29 @@ async def test_async_get_area_by_name(area_registry: ar.AreaRegistry) -> None:
assert area_registry.async_get_area_by_name("M o c k 1").normalized_name == "mock1"
async def test_async_get_areas_by_alias(
area_registry: ar.AreaRegistry,
) -> None:
"""Make sure we can get the areas by alias."""
area1 = area_registry.async_create("Mock1", aliases=("alias_1", "alias_2"))
area2 = area_registry.async_create("Mock2", aliases=("alias_1", "alias_3"))
assert len(area_registry.areas) == 2
alias1_list = area_registry.async_get_areas_by_alias("A l i a s_1")
alias2_list = area_registry.async_get_areas_by_alias("A l i a s_2")
alias3_list = area_registry.async_get_areas_by_alias("A l i a s_3")
assert len(alias1_list) == 2
assert len(alias2_list) == 1
assert len(alias3_list) == 1
assert area1 in alias1_list
assert area1 in alias2_list
assert area2 in alias1_list
assert area2 in alias3_list
async def test_async_get_area_by_name_not_found(area_registry: ar.AreaRegistry) -> None:
"""Make sure we return None for non-existent areas."""
area_registry.async_create("Mock1")

View File

@ -327,7 +327,7 @@ async def test_loading_floors_from_storage(
assert len(registry.floors) == 1
async def test_getting_floor(floor_registry: fr.FloorRegistry) -> None:
async def test_getting_floor_by_name(floor_registry: fr.FloorRegistry) -> None:
"""Make sure we can get the floors by name."""
floor = floor_registry.async_create("First floor")
floor2 = floor_registry.async_get_floor_by_name("first floor")
@ -341,6 +341,27 @@ async def test_getting_floor(floor_registry: fr.FloorRegistry) -> None:
assert get_floor == floor
async def test_async_get_floors_by_alias(
floor_registry: fr.FloorRegistry,
) -> None:
"""Make sure we can get the floors by alias."""
floor1 = floor_registry.async_create("First floor", aliases=("alias_1", "alias_2"))
floor2 = floor_registry.async_create("Second floor", aliases=("alias_1", "alias_3"))
alias1_list = floor_registry.async_get_floors_by_alias("A l i a s_1")
alias2_list = floor_registry.async_get_floors_by_alias("A l i a s_2")
alias3_list = floor_registry.async_get_floors_by_alias("A l i a s_3")
assert len(alias1_list) == 2
assert len(alias2_list) == 1
assert len(alias3_list) == 1
assert floor1 in alias1_list
assert floor1 in alias2_list
assert floor2 in alias1_list
assert floor2 in alias3_list
async def test_async_get_floor_by_name_not_found(
floor_registry: fr.FloorRegistry,
) -> None: