mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 09:17:53 +00:00
Improve code of normalized name registry (#125282)
This commit is contained in:
parent
4060705d87
commit
98a86c7636
@ -5,12 +5,12 @@ from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
import dataclasses
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.util import slugify
|
||||
from homeassistant.util.dt import utc_from_timestamp, utcnow
|
||||
from homeassistant.util.event_type import EventType
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
@ -20,7 +20,6 @@ from .json import json_bytes, json_fragment
|
||||
from .normalized_name_base_registry import (
|
||||
NormalizedNameBaseRegistryEntry,
|
||||
NormalizedNameBaseRegistryItems,
|
||||
normalize_name,
|
||||
)
|
||||
from .registry import BaseRegistry, RegistryIndexType
|
||||
from .singleton import singleton
|
||||
@ -63,7 +62,7 @@ class EventAreaRegistryUpdatedData(TypedDict):
|
||||
area_id: str
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class AreaEntry(NormalizedNameBaseRegistryEntry):
|
||||
"""Area Registry Entry."""
|
||||
|
||||
@ -71,7 +70,7 @@ class AreaEntry(NormalizedNameBaseRegistryEntry):
|
||||
floor_id: str | None
|
||||
icon: str | None
|
||||
id: str
|
||||
labels: set[str] = dataclasses.field(default_factory=set)
|
||||
labels: set[str] = field(default_factory=set)
|
||||
picture: str | None
|
||||
|
||||
@cached_property
|
||||
@ -225,6 +224,10 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
|
||||
return area
|
||||
return self.async_create(name)
|
||||
|
||||
def _generate_id(self, name: str) -> str:
|
||||
"""Generate area ID."""
|
||||
return self.areas.generate_id_from_name(name)
|
||||
|
||||
@callback
|
||||
def async_create(
|
||||
self,
|
||||
@ -238,28 +241,28 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
|
||||
) -> AreaEntry:
|
||||
"""Create a new area."""
|
||||
self.hass.verify_event_loop_thread("area_registry.async_create")
|
||||
normalized_name = normalize_name(name)
|
||||
|
||||
if self.async_get_area_by_name(name):
|
||||
raise ValueError(f"The name {name} ({normalized_name}) is already in use")
|
||||
if area := self.async_get_area_by_name(name):
|
||||
raise ValueError(
|
||||
f"The name {name} ({area.normalized_name}) is already in use"
|
||||
)
|
||||
|
||||
area_id = self._generate_area_id(name)
|
||||
area = AreaEntry(
|
||||
aliases=aliases or set(),
|
||||
floor_id=floor_id,
|
||||
icon=icon,
|
||||
id=area_id,
|
||||
id=self._generate_id(name),
|
||||
labels=labels or set(),
|
||||
name=name,
|
||||
normalized_name=normalized_name,
|
||||
picture=picture,
|
||||
)
|
||||
assert area.id is not None
|
||||
self.areas[area.id] = area
|
||||
area_id = area.id
|
||||
self.areas[area_id] = area
|
||||
self.async_schedule_save()
|
||||
|
||||
self.hass.bus.async_fire_internal(
|
||||
EVENT_AREA_REGISTRY_UPDATED,
|
||||
EventAreaRegistryUpdatedData(action="create", area_id=area.id),
|
||||
EventAreaRegistryUpdatedData(action="create", area_id=area_id),
|
||||
)
|
||||
return area
|
||||
|
||||
@ -342,7 +345,6 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
|
||||
|
||||
if name is not UNDEFINED and name != old.name:
|
||||
new_values["name"] = name
|
||||
new_values["normalized_name"] = normalize_name(name)
|
||||
|
||||
if not new_values:
|
||||
return old
|
||||
@ -366,7 +368,6 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
|
||||
if data is not None:
|
||||
for area in data["areas"]:
|
||||
assert area["name"] is not None and area["id"] is not None
|
||||
normalized_name = normalize_name(area["name"])
|
||||
areas[area["id"]] = AreaEntry(
|
||||
aliases=set(area["aliases"]),
|
||||
floor_id=area["floor_id"],
|
||||
@ -374,7 +375,6 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
|
||||
id=area["id"],
|
||||
labels=set(area["labels"]),
|
||||
name=area["name"],
|
||||
normalized_name=normalized_name,
|
||||
picture=area["picture"],
|
||||
created_at=datetime.fromisoformat(area["created_at"]),
|
||||
modified_at=datetime.fromisoformat(area["modified_at"]),
|
||||
@ -403,15 +403,6 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
|
||||
]
|
||||
}
|
||||
|
||||
def _generate_area_id(self, name: str) -> str:
|
||||
"""Generate area ID."""
|
||||
suggestion = suggestion_base = slugify(name)
|
||||
tries = 1
|
||||
while suggestion in self.areas:
|
||||
tries += 1
|
||||
suggestion = f"{suggestion_base}_{tries}"
|
||||
return suggestion
|
||||
|
||||
@callback
|
||||
def _async_setup_cleanup(self) -> None:
|
||||
"""Set up the area registry cleanup."""
|
||||
|
@ -9,7 +9,6 @@ from datetime import datetime
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.util import slugify
|
||||
from homeassistant.util.dt import utc_from_timestamp, utcnow
|
||||
from homeassistant.util.event_type import EventType
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
@ -17,7 +16,6 @@ from homeassistant.util.hass_dict import HassKey
|
||||
from .normalized_name_base_registry import (
|
||||
NormalizedNameBaseRegistryEntry,
|
||||
NormalizedNameBaseRegistryItems,
|
||||
normalize_name,
|
||||
)
|
||||
from .registry import BaseRegistry
|
||||
from .singleton import singleton
|
||||
@ -130,15 +128,9 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
|
||||
"""Get all floors."""
|
||||
return self.floors.values()
|
||||
|
||||
@callback
|
||||
def _generate_id(self, name: str) -> str:
|
||||
"""Generate floor ID."""
|
||||
suggestion = suggestion_base = slugify(name)
|
||||
tries = 1
|
||||
while suggestion in self.floors:
|
||||
tries += 1
|
||||
suggestion = f"{suggestion_base}_{tries}"
|
||||
return suggestion
|
||||
return self.floors.generate_id_from_name(name)
|
||||
|
||||
@callback
|
||||
def async_create(
|
||||
@ -151,30 +143,26 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
|
||||
) -> FloorEntry:
|
||||
"""Create a new floor."""
|
||||
self.hass.verify_event_loop_thread("floor_registry.async_create")
|
||||
|
||||
if floor := self.async_get_floor_by_name(name):
|
||||
raise ValueError(
|
||||
f"The name {name} ({floor.normalized_name}) is already in use"
|
||||
)
|
||||
|
||||
normalized_name = normalize_name(name)
|
||||
|
||||
floor = FloorEntry(
|
||||
aliases=aliases or set(),
|
||||
icon=icon,
|
||||
floor_id=self._generate_id(name),
|
||||
name=name,
|
||||
normalized_name=normalized_name,
|
||||
level=level,
|
||||
)
|
||||
floor_id = floor.floor_id
|
||||
self.floors[floor_id] = floor
|
||||
self.async_schedule_save()
|
||||
|
||||
self.hass.bus.async_fire_internal(
|
||||
EVENT_FLOOR_REGISTRY_UPDATED,
|
||||
EventFloorRegistryUpdatedData(
|
||||
action="create",
|
||||
floor_id=floor_id,
|
||||
),
|
||||
EventFloorRegistryUpdatedData(action="create", floor_id=floor_id),
|
||||
)
|
||||
return floor
|
||||
|
||||
@ -215,7 +203,6 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
|
||||
}
|
||||
if name is not UNDEFINED and name != old.name:
|
||||
changes["name"] = name
|
||||
changes["normalized_name"] = normalize_name(name)
|
||||
|
||||
if not changes:
|
||||
return old
|
||||
@ -243,14 +230,12 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
|
||||
|
||||
if data is not None:
|
||||
for floor in data["floors"]:
|
||||
normalized_name = normalize_name(floor["name"])
|
||||
floors[floor["floor_id"]] = FloorEntry(
|
||||
aliases=set(floor["aliases"]),
|
||||
icon=floor["icon"],
|
||||
floor_id=floor["floor_id"],
|
||||
name=floor["name"],
|
||||
level=floor["level"],
|
||||
normalized_name=normalized_name,
|
||||
created_at=datetime.fromisoformat(floor["created_at"]),
|
||||
modified_at=datetime.fromisoformat(floor["modified_at"]),
|
||||
)
|
||||
|
@ -9,7 +9,6 @@ from datetime import datetime
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.util import slugify
|
||||
from homeassistant.util.dt import utc_from_timestamp, utcnow
|
||||
from homeassistant.util.event_type import EventType
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
@ -17,7 +16,6 @@ from homeassistant.util.hass_dict import HassKey
|
||||
from .normalized_name_base_registry import (
|
||||
NormalizedNameBaseRegistryEntry,
|
||||
NormalizedNameBaseRegistryItems,
|
||||
normalize_name,
|
||||
)
|
||||
from .registry import BaseRegistry
|
||||
from .singleton import singleton
|
||||
@ -130,15 +128,9 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
|
||||
"""Get all labels."""
|
||||
return self.labels.values()
|
||||
|
||||
@callback
|
||||
def _generate_id(self, name: str) -> str:
|
||||
"""Initialize ID."""
|
||||
suggestion = suggestion_base = slugify(name)
|
||||
tries = 1
|
||||
while suggestion in self.labels:
|
||||
tries += 1
|
||||
suggestion = f"{suggestion_base}_{tries}"
|
||||
return suggestion
|
||||
"""Generate label ID."""
|
||||
return self.labels.generate_id_from_name(name)
|
||||
|
||||
@callback
|
||||
def async_create(
|
||||
@ -151,30 +143,26 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
|
||||
) -> LabelEntry:
|
||||
"""Create a new label."""
|
||||
self.hass.verify_event_loop_thread("label_registry.async_create")
|
||||
|
||||
if label := self.async_get_label_by_name(name):
|
||||
raise ValueError(
|
||||
f"The name {name} ({label.normalized_name}) is already in use"
|
||||
)
|
||||
|
||||
normalized_name = normalize_name(name)
|
||||
|
||||
label = LabelEntry(
|
||||
color=color,
|
||||
description=description,
|
||||
icon=icon,
|
||||
label_id=self._generate_id(name),
|
||||
name=name,
|
||||
normalized_name=normalized_name,
|
||||
)
|
||||
label_id = label.label_id
|
||||
self.labels[label_id] = label
|
||||
self.async_schedule_save()
|
||||
|
||||
self.hass.bus.async_fire_internal(
|
||||
EVENT_LABEL_REGISTRY_UPDATED,
|
||||
EventLabelRegistryUpdatedData(
|
||||
action="create",
|
||||
label_id=label_id,
|
||||
),
|
||||
EventLabelRegistryUpdatedData(action="create", label_id=label_id),
|
||||
)
|
||||
return label
|
||||
|
||||
@ -216,7 +204,6 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
|
||||
|
||||
if name is not UNDEFINED and name != old.name:
|
||||
changes["name"] = name
|
||||
changes["normalized_name"] = normalize_name(name)
|
||||
|
||||
if not changes:
|
||||
return old
|
||||
@ -244,14 +231,12 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
|
||||
|
||||
if data is not None:
|
||||
for label in data["labels"]:
|
||||
normalized_name = normalize_name(label["name"])
|
||||
labels[label["label_id"]] = LabelEntry(
|
||||
color=label["color"],
|
||||
description=label["description"],
|
||||
icon=label["icon"],
|
||||
label_id=label["label_id"],
|
||||
name=label["name"],
|
||||
normalized_name=normalized_name,
|
||||
created_at=datetime.fromisoformat(label["created_at"]),
|
||||
modified_at=datetime.fromisoformat(label["modified_at"]),
|
||||
)
|
||||
|
@ -4,7 +4,7 @@ from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
|
||||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util import dt as dt_util, slugify
|
||||
|
||||
from .registry import BaseRegistryItems
|
||||
|
||||
@ -14,10 +14,14 @@ class NormalizedNameBaseRegistryEntry:
|
||||
"""Normalized Name Base Registry Entry."""
|
||||
|
||||
name: str
|
||||
normalized_name: str
|
||||
normalized_name: str = field(init=False)
|
||||
created_at: datetime = field(default_factory=dt_util.utcnow)
|
||||
modified_at: datetime = field(default_factory=dt_util.utcnow)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Post init."""
|
||||
object.__setattr__(self, "normalized_name", normalize_name(self.name))
|
||||
|
||||
|
||||
@lru_cache(maxsize=1024)
|
||||
def normalize_name(name: str) -> str:
|
||||
@ -43,7 +47,7 @@ class NormalizedNameBaseRegistryItems[_VT: NormalizedNameBaseRegistryEntry](
|
||||
old_entry = self.data[key]
|
||||
if (
|
||||
replacement_entry is not None
|
||||
and (normalized_name := normalize_name(replacement_entry.name))
|
||||
and (normalized_name := replacement_entry.normalized_name)
|
||||
!= old_entry.normalized_name
|
||||
and normalized_name in self._normalized_names
|
||||
):
|
||||
@ -53,8 +57,17 @@ class NormalizedNameBaseRegistryItems[_VT: NormalizedNameBaseRegistryEntry](
|
||||
del self._normalized_names[old_entry.normalized_name]
|
||||
|
||||
def _index_entry(self, key: str, entry: _VT) -> None:
|
||||
self._normalized_names[normalize_name(entry.name)] = entry
|
||||
self._normalized_names[entry.normalized_name] = entry
|
||||
|
||||
def get_by_name(self, name: str) -> _VT | None:
|
||||
"""Get entry by name."""
|
||||
return self._normalized_names.get(normalize_name(name))
|
||||
|
||||
def generate_id_from_name(self, name: str) -> str:
|
||||
"""Generate ID from name."""
|
||||
suggestion = suggestion_base = slugify(name)
|
||||
tries = 1
|
||||
while suggestion in self:
|
||||
tries += 1
|
||||
suggestion = f"{suggestion_base}_{tries}"
|
||||
return suggestion
|
||||
|
@ -45,7 +45,6 @@ async def test_create_area(
|
||||
id=ANY,
|
||||
labels=set(),
|
||||
name="mock",
|
||||
normalized_name=ANY,
|
||||
picture=None,
|
||||
created_at=utcnow(),
|
||||
modified_at=utcnow(),
|
||||
@ -77,7 +76,6 @@ async def test_create_area(
|
||||
id=ANY,
|
||||
labels={"label1", "label2"},
|
||||
name="mock 2",
|
||||
normalized_name=ANY,
|
||||
picture="/image/example.png",
|
||||
created_at=utcnow(),
|
||||
modified_at=utcnow(),
|
||||
@ -196,7 +194,6 @@ async def test_update_area(
|
||||
id=ANY,
|
||||
labels={"label1", "label2"},
|
||||
name="mock1",
|
||||
normalized_name=ANY,
|
||||
picture="/image/example.png",
|
||||
created_at=created_at,
|
||||
modified_at=modified_at,
|
||||
|
@ -12,7 +12,7 @@ from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import area_registry as ar, floor_registry as fr
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
from tests.common import ANY, async_capture_events, flush_store
|
||||
from tests.common import async_capture_events, flush_store
|
||||
|
||||
|
||||
async def test_list_floors(floor_registry: fr.FloorRegistry) -> None:
|
||||
@ -43,7 +43,6 @@ async def test_create_floor(
|
||||
level=1,
|
||||
created_at=utcnow(),
|
||||
modified_at=utcnow(),
|
||||
normalized_name=ANY,
|
||||
)
|
||||
|
||||
assert len(floor_registry.floors) == 1
|
||||
@ -145,7 +144,6 @@ async def test_update_floor(
|
||||
level=None,
|
||||
created_at=created_at,
|
||||
modified_at=created_at,
|
||||
normalized_name=ANY,
|
||||
)
|
||||
assert len(floor_registry.floors) == 1
|
||||
|
||||
@ -169,7 +167,6 @@ async def test_update_floor(
|
||||
level=2,
|
||||
created_at=created_at,
|
||||
modified_at=modified_at,
|
||||
normalized_name=ANY,
|
||||
)
|
||||
|
||||
assert len(floor_registry.floors) == 1
|
||||
|
@ -16,7 +16,7 @@ from homeassistant.helpers import (
|
||||
)
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
from tests.common import ANY, MockConfigEntry, async_capture_events, flush_store
|
||||
from tests.common import MockConfigEntry, async_capture_events, flush_store
|
||||
|
||||
|
||||
async def test_list_labels(label_registry: lr.LabelRegistry) -> None:
|
||||
@ -46,7 +46,6 @@ async def test_create_label(
|
||||
description="This label is for testing",
|
||||
created_at=utcnow(),
|
||||
modified_at=utcnow(),
|
||||
normalized_name=ANY,
|
||||
)
|
||||
|
||||
assert len(label_registry.labels) == 1
|
||||
@ -147,7 +146,6 @@ async def test_update_label(
|
||||
description=None,
|
||||
created_at=created_at,
|
||||
modified_at=created_at,
|
||||
normalized_name=ANY,
|
||||
)
|
||||
|
||||
modified_at = datetime.fromisoformat("2024-02-01T01:00:00+00:00")
|
||||
@ -169,7 +167,6 @@ async def test_update_label(
|
||||
description="Updated description",
|
||||
created_at=created_at,
|
||||
modified_at=modified_at,
|
||||
normalized_name=ANY,
|
||||
)
|
||||
assert len(label_registry.labels) == 1
|
||||
|
||||
|
@ -26,18 +26,14 @@ def test_registry_items(
|
||||
registry_items: NormalizedNameBaseRegistryItems[NormalizedNameBaseRegistryEntry],
|
||||
) -> None:
|
||||
"""Test registry items."""
|
||||
entry = NormalizedNameBaseRegistryEntry(
|
||||
name="Hello World", normalized_name="helloworld"
|
||||
)
|
||||
entry = NormalizedNameBaseRegistryEntry(name="Hello World")
|
||||
registry_items["key"] = entry
|
||||
assert registry_items["key"] == entry
|
||||
assert list(registry_items.values()) == [entry]
|
||||
assert registry_items.get_by_name("Hello World") == entry
|
||||
|
||||
# test update entry
|
||||
entry2 = NormalizedNameBaseRegistryEntry(
|
||||
name="Hello World 2", normalized_name="helloworld2"
|
||||
)
|
||||
entry2 = NormalizedNameBaseRegistryEntry(name="Hello World 2")
|
||||
registry_items["key"] = entry2
|
||||
assert registry_items["key"] == entry2
|
||||
assert list(registry_items.values()) == [entry2]
|
||||
@ -53,16 +49,12 @@ def test_key_already_in_use(
|
||||
registry_items: NormalizedNameBaseRegistryItems[NormalizedNameBaseRegistryEntry],
|
||||
) -> None:
|
||||
"""Test key already in use."""
|
||||
entry = NormalizedNameBaseRegistryEntry(
|
||||
name="Hello World", normalized_name="helloworld"
|
||||
)
|
||||
entry = NormalizedNameBaseRegistryEntry(name="Hello World")
|
||||
registry_items["key"] = entry
|
||||
|
||||
# should raise ValueError if we update a
|
||||
# key with a entry with the same normalized name
|
||||
entry = NormalizedNameBaseRegistryEntry(
|
||||
name="Hello World 2", normalized_name="helloworld2"
|
||||
)
|
||||
entry = NormalizedNameBaseRegistryEntry(name="Hello World 2")
|
||||
registry_items["key2"] = entry
|
||||
with pytest.raises(ValueError):
|
||||
registry_items["key"] = entry
|
||||
|
@ -119,7 +119,6 @@ def floor_area_mock(hass: HomeAssistant) -> None:
|
||||
id="test-area",
|
||||
name="Test area",
|
||||
aliases={},
|
||||
normalized_name="test-area",
|
||||
floor_id="test-floor",
|
||||
icon=None,
|
||||
picture=None,
|
||||
@ -128,7 +127,6 @@ def floor_area_mock(hass: HomeAssistant) -> None:
|
||||
id="area-a",
|
||||
name="Area A",
|
||||
aliases={},
|
||||
normalized_name="area-a",
|
||||
floor_id="floor-a",
|
||||
icon=None,
|
||||
picture=None,
|
||||
@ -282,7 +280,6 @@ def label_mock(hass: HomeAssistant) -> None:
|
||||
id="area-with-labels",
|
||||
name="Area with labels",
|
||||
aliases={},
|
||||
normalized_name="with_labels",
|
||||
floor_id=None,
|
||||
icon=None,
|
||||
labels={"label_area"},
|
||||
@ -292,7 +289,6 @@ def label_mock(hass: HomeAssistant) -> None:
|
||||
id="area-no-labels",
|
||||
name="Area without labels",
|
||||
aliases={},
|
||||
normalized_name="without_labels",
|
||||
floor_id=None,
|
||||
icon=None,
|
||||
labels=set(),
|
||||
|
Loading…
x
Reference in New Issue
Block a user