mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Get area and floor by alias (#126150)
* Add possibility to get area by alias * Add ability to get floor by alias * Moved alias lookup to separate function, adjusted templates. * Changed registry to return all areas/floors with given alias * Use normalize_name from normalized_name_base_registry
This commit is contained in:
parent
c30f17f592
commit
dea00fac3f
@ -20,6 +20,7 @@ from .json import json_bytes, json_fragment
|
|||||||
from .normalized_name_base_registry import (
|
from .normalized_name_base_registry import (
|
||||||
NormalizedNameBaseRegistryEntry,
|
NormalizedNameBaseRegistryEntry,
|
||||||
NormalizedNameBaseRegistryItems,
|
NormalizedNameBaseRegistryItems,
|
||||||
|
normalize_name,
|
||||||
)
|
)
|
||||||
from .registry import BaseRegistry, RegistryIndexType
|
from .registry import BaseRegistry, RegistryIndexType
|
||||||
from .singleton import singleton
|
from .singleton import singleton
|
||||||
@ -169,6 +170,7 @@ class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self._labels_index: RegistryIndexType = defaultdict(dict)
|
self._labels_index: RegistryIndexType = defaultdict(dict)
|
||||||
self._floors_index: RegistryIndexType = defaultdict(dict)
|
self._floors_index: RegistryIndexType = defaultdict(dict)
|
||||||
|
self._aliases_index: RegistryIndexType = defaultdict(dict)
|
||||||
|
|
||||||
def _index_entry(self, key: str, entry: AreaEntry) -> None:
|
def _index_entry(self, key: str, entry: AreaEntry) -> None:
|
||||||
"""Index an entry."""
|
"""Index an entry."""
|
||||||
@ -177,6 +179,9 @@ class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
|
|||||||
self._floors_index[entry.floor_id][key] = True
|
self._floors_index[entry.floor_id][key] = True
|
||||||
for label in entry.labels:
|
for label in entry.labels:
|
||||||
self._labels_index[label][key] = True
|
self._labels_index[label][key] = True
|
||||||
|
for alias in entry.aliases:
|
||||||
|
normalized_alias = normalize_name(alias)
|
||||||
|
self._aliases_index[normalized_alias][key] = True
|
||||||
|
|
||||||
def _unindex_entry(
|
def _unindex_entry(
|
||||||
self, key: str, replacement_entry: AreaEntry | None = None
|
self, key: str, replacement_entry: AreaEntry | None = None
|
||||||
@ -184,6 +189,10 @@ class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
|
|||||||
# always call base class before other indices
|
# always call base class before other indices
|
||||||
super()._unindex_entry(key, replacement_entry)
|
super()._unindex_entry(key, replacement_entry)
|
||||||
entry = self.data[key]
|
entry = self.data[key]
|
||||||
|
if aliases := entry.aliases:
|
||||||
|
for alias in aliases:
|
||||||
|
normalized_alias = normalize_name(alias)
|
||||||
|
self._unindex_entry_value(key, normalized_alias, self._aliases_index)
|
||||||
if labels := entry.labels:
|
if labels := entry.labels:
|
||||||
for label in labels:
|
for label in labels:
|
||||||
self._unindex_entry_value(key, label, self._labels_index)
|
self._unindex_entry_value(key, label, self._labels_index)
|
||||||
@ -200,6 +209,12 @@ class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
|
|||||||
data = self.data
|
data = self.data
|
||||||
return [data[key] for key in self._floors_index.get(floor, ())]
|
return [data[key] for key in self._floors_index.get(floor, ())]
|
||||||
|
|
||||||
|
def get_areas_for_alias(self, alias: str) -> list[AreaEntry]:
|
||||||
|
"""Get areas for alias."""
|
||||||
|
data = self.data
|
||||||
|
normalized_alias = normalize_name(alias)
|
||||||
|
return [data[key] for key in self._aliases_index.get(normalized_alias, ())]
|
||||||
|
|
||||||
|
|
||||||
class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
|
class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
|
||||||
"""Class to hold a registry of areas."""
|
"""Class to hold a registry of areas."""
|
||||||
@ -232,6 +247,11 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
|
|||||||
"""Get area by name."""
|
"""Get area by name."""
|
||||||
return self.areas.get_by_name(name)
|
return self.areas.get_by_name(name)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_areas_by_alias(self, alias: str) -> list[AreaEntry]:
|
||||||
|
"""Get areas by alias."""
|
||||||
|
return self.areas.get_areas_for_alias(alias)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_list_areas(self) -> Iterable[AreaEntry]:
|
def async_list_areas(self) -> Iterable[AreaEntry]:
|
||||||
"""Get all areas."""
|
"""Get all areas."""
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -16,8 +17,9 @@ from homeassistant.util.hass_dict import HassKey
|
|||||||
from .normalized_name_base_registry import (
|
from .normalized_name_base_registry import (
|
||||||
NormalizedNameBaseRegistryEntry,
|
NormalizedNameBaseRegistryEntry,
|
||||||
NormalizedNameBaseRegistryItems,
|
NormalizedNameBaseRegistryItems,
|
||||||
|
normalize_name,
|
||||||
)
|
)
|
||||||
from .registry import BaseRegistry
|
from .registry import BaseRegistry, RegistryIndexType
|
||||||
from .singleton import singleton
|
from .singleton import singleton
|
||||||
from .storage import Store
|
from .storage import Store
|
||||||
from .typing import UNDEFINED, UndefinedType
|
from .typing import UNDEFINED, UndefinedType
|
||||||
@ -92,10 +94,43 @@ class FloorRegistryStore(Store[FloorRegistryStoreData]):
|
|||||||
return old_data # type: ignore[return-value]
|
return old_data # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
|
class FloorRegistryItems(NormalizedNameBaseRegistryItems[FloorEntry]):
|
||||||
|
"""Class to hold floor registry items."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize the floor registry items."""
|
||||||
|
super().__init__()
|
||||||
|
self._aliases_index: RegistryIndexType = defaultdict(dict)
|
||||||
|
|
||||||
|
def _index_entry(self, key: str, entry: FloorEntry) -> None:
|
||||||
|
"""Index an entry."""
|
||||||
|
super()._index_entry(key, entry)
|
||||||
|
for alias in entry.aliases:
|
||||||
|
normalized_alias = normalize_name(alias)
|
||||||
|
self._aliases_index[normalized_alias][key] = True
|
||||||
|
|
||||||
|
def _unindex_entry(
|
||||||
|
self, key: str, replacement_entry: FloorEntry | None = None
|
||||||
|
) -> None:
|
||||||
|
# always call base class before other indices
|
||||||
|
super()._unindex_entry(key, replacement_entry)
|
||||||
|
entry = self.data[key]
|
||||||
|
if aliases := entry.aliases:
|
||||||
|
for alias in aliases:
|
||||||
|
normalized_alias = normalize_name(alias)
|
||||||
|
self._unindex_entry_value(key, normalized_alias, self._aliases_index)
|
||||||
|
|
||||||
|
def get_floors_for_alias(self, alias: str) -> list[FloorEntry]:
|
||||||
|
"""Get floors for alias."""
|
||||||
|
data = self.data
|
||||||
|
normalized_alias = normalize_name(alias)
|
||||||
|
return [data[key] for key in self._aliases_index.get(normalized_alias, ())]
|
||||||
|
|
||||||
|
|
||||||
class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
|
class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
|
||||||
"""Class to hold a registry of floors."""
|
"""Class to hold a registry of floors."""
|
||||||
|
|
||||||
floors: NormalizedNameBaseRegistryItems[FloorEntry]
|
floors: FloorRegistryItems
|
||||||
_floor_data: dict[str, FloorEntry]
|
_floor_data: dict[str, FloorEntry]
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
@ -123,6 +158,11 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
|
|||||||
"""Get floor by name."""
|
"""Get floor by name."""
|
||||||
return self.floors.get_by_name(name)
|
return self.floors.get_by_name(name)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_floors_by_alias(self, alias: str) -> list[FloorEntry]:
|
||||||
|
"""Get floors by alias."""
|
||||||
|
return self.floors.get_floors_for_alias(alias)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_list_floors(self) -> Iterable[FloorEntry]:
|
def async_list_floors(self) -> Iterable[FloorEntry]:
|
||||||
"""Get all floors."""
|
"""Get all floors."""
|
||||||
@ -226,7 +266,7 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
|
|||||||
async def async_load(self) -> None:
|
async def async_load(self) -> None:
|
||||||
"""Load the floor registry."""
|
"""Load the floor registry."""
|
||||||
data = await self._store.async_load()
|
data = await self._store.async_load()
|
||||||
floors = NormalizedNameBaseRegistryItems[FloorEntry]()
|
floors = FloorRegistryItems()
|
||||||
|
|
||||||
if data is not None:
|
if data is not None:
|
||||||
for floor in data["floors"]:
|
for floor in data["floors"]:
|
||||||
|
@ -1478,10 +1478,14 @@ def floors(hass: HomeAssistant) -> Iterable[str | None]:
|
|||||||
|
|
||||||
|
|
||||||
def floor_id(hass: HomeAssistant, lookup_value: Any) -> str | None:
|
def floor_id(hass: HomeAssistant, lookup_value: Any) -> str | None:
|
||||||
"""Get the floor ID from a floor name."""
|
"""Get the floor ID from a floor or area name, alias, device id, or entity id."""
|
||||||
floor_registry = fr.async_get(hass)
|
floor_registry = fr.async_get(hass)
|
||||||
if floor := floor_registry.async_get_floor_by_name(str(lookup_value)):
|
lookup_str = str(lookup_value)
|
||||||
|
if floor := floor_registry.async_get_floor_by_name(lookup_str):
|
||||||
return floor.floor_id
|
return floor.floor_id
|
||||||
|
floors_list = floor_registry.async_get_floors_by_alias(lookup_str)
|
||||||
|
if floors_list:
|
||||||
|
return floors_list[0].floor_id
|
||||||
|
|
||||||
if aid := area_id(hass, lookup_value):
|
if aid := area_id(hass, lookup_value):
|
||||||
area_reg = area_registry.async_get(hass)
|
area_reg = area_registry.async_get(hass)
|
||||||
@ -1541,10 +1545,14 @@ def areas(hass: HomeAssistant) -> Iterable[str | None]:
|
|||||||
|
|
||||||
|
|
||||||
def area_id(hass: HomeAssistant, lookup_value: str) -> str | None:
|
def area_id(hass: HomeAssistant, lookup_value: str) -> str | None:
|
||||||
"""Get the area ID from an area name, device id, or entity id."""
|
"""Get the area ID from an area name, alias, device id, or entity id."""
|
||||||
area_reg = area_registry.async_get(hass)
|
area_reg = area_registry.async_get(hass)
|
||||||
if area := area_reg.async_get_area_by_name(str(lookup_value)):
|
lookup_str = str(lookup_value)
|
||||||
|
if area := area_reg.async_get_area_by_name(lookup_str):
|
||||||
return area.id
|
return area.id
|
||||||
|
areas_list = area_reg.async_get_areas_by_alias(lookup_str)
|
||||||
|
if areas_list:
|
||||||
|
return areas_list[0].id
|
||||||
|
|
||||||
ent_reg = entity_registry.async_get(hass)
|
ent_reg = entity_registry.async_get(hass)
|
||||||
dev_reg = device_registry.async_get(hass)
|
dev_reg = device_registry.async_get(hass)
|
||||||
|
@ -494,6 +494,29 @@ async def test_async_get_area_by_name(area_registry: ar.AreaRegistry) -> None:
|
|||||||
assert area_registry.async_get_area_by_name("M o c k 1").normalized_name == "mock1"
|
assert area_registry.async_get_area_by_name("M o c k 1").normalized_name == "mock1"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_get_areas_by_alias(
|
||||||
|
area_registry: ar.AreaRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Make sure we can get the areas by alias."""
|
||||||
|
area1 = area_registry.async_create("Mock1", aliases=("alias_1", "alias_2"))
|
||||||
|
area2 = area_registry.async_create("Mock2", aliases=("alias_1", "alias_3"))
|
||||||
|
|
||||||
|
assert len(area_registry.areas) == 2
|
||||||
|
|
||||||
|
alias1_list = area_registry.async_get_areas_by_alias("A l i a s_1")
|
||||||
|
alias2_list = area_registry.async_get_areas_by_alias("A l i a s_2")
|
||||||
|
alias3_list = area_registry.async_get_areas_by_alias("A l i a s_3")
|
||||||
|
|
||||||
|
assert len(alias1_list) == 2
|
||||||
|
assert len(alias2_list) == 1
|
||||||
|
assert len(alias3_list) == 1
|
||||||
|
|
||||||
|
assert area1 in alias1_list
|
||||||
|
assert area1 in alias2_list
|
||||||
|
assert area2 in alias1_list
|
||||||
|
assert area2 in alias3_list
|
||||||
|
|
||||||
|
|
||||||
async def test_async_get_area_by_name_not_found(area_registry: ar.AreaRegistry) -> None:
|
async def test_async_get_area_by_name_not_found(area_registry: ar.AreaRegistry) -> None:
|
||||||
"""Make sure we return None for non-existent areas."""
|
"""Make sure we return None for non-existent areas."""
|
||||||
area_registry.async_create("Mock1")
|
area_registry.async_create("Mock1")
|
||||||
|
@ -327,7 +327,7 @@ async def test_loading_floors_from_storage(
|
|||||||
assert len(registry.floors) == 1
|
assert len(registry.floors) == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_getting_floor(floor_registry: fr.FloorRegistry) -> None:
|
async def test_getting_floor_by_name(floor_registry: fr.FloorRegistry) -> None:
|
||||||
"""Make sure we can get the floors by name."""
|
"""Make sure we can get the floors by name."""
|
||||||
floor = floor_registry.async_create("First floor")
|
floor = floor_registry.async_create("First floor")
|
||||||
floor2 = floor_registry.async_get_floor_by_name("first floor")
|
floor2 = floor_registry.async_get_floor_by_name("first floor")
|
||||||
@ -341,6 +341,27 @@ async def test_getting_floor(floor_registry: fr.FloorRegistry) -> None:
|
|||||||
assert get_floor == floor
|
assert get_floor == floor
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_get_floors_by_alias(
|
||||||
|
floor_registry: fr.FloorRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Make sure we can get the floors by alias."""
|
||||||
|
floor1 = floor_registry.async_create("First floor", aliases=("alias_1", "alias_2"))
|
||||||
|
floor2 = floor_registry.async_create("Second floor", aliases=("alias_1", "alias_3"))
|
||||||
|
|
||||||
|
alias1_list = floor_registry.async_get_floors_by_alias("A l i a s_1")
|
||||||
|
alias2_list = floor_registry.async_get_floors_by_alias("A l i a s_2")
|
||||||
|
alias3_list = floor_registry.async_get_floors_by_alias("A l i a s_3")
|
||||||
|
|
||||||
|
assert len(alias1_list) == 2
|
||||||
|
assert len(alias2_list) == 1
|
||||||
|
assert len(alias3_list) == 1
|
||||||
|
|
||||||
|
assert floor1 in alias1_list
|
||||||
|
assert floor1 in alias2_list
|
||||||
|
assert floor2 in alias1_list
|
||||||
|
assert floor2 in alias3_list
|
||||||
|
|
||||||
|
|
||||||
async def test_async_get_floor_by_name_not_found(
|
async def test_async_get_floor_by_name_not_found(
|
||||||
floor_registry: fr.FloorRegistry,
|
floor_registry: fr.FloorRegistry,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user