Improve code of normalized name registry (#125282)

This commit is contained in:
Artur Pragacz 2024-10-01 18:20:52 +02:00 committed by GitHub
parent 4060705d87
commit 98a86c7636
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 48 additions and 95 deletions

View File

@ -5,12 +5,12 @@ from __future__ import annotations
from collections import defaultdict
from collections.abc import Iterable
import dataclasses
from dataclasses import dataclass, field
from datetime import datetime
from functools import cached_property
from typing import Any, Literal, TypedDict
from homeassistant.core import HomeAssistant, callback
from homeassistant.util import slugify
from homeassistant.util.dt import utc_from_timestamp, utcnow
from homeassistant.util.event_type import EventType
from homeassistant.util.hass_dict import HassKey
@ -20,7 +20,6 @@ from .json import json_bytes, json_fragment
from .normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry,
NormalizedNameBaseRegistryItems,
normalize_name,
)
from .registry import BaseRegistry, RegistryIndexType
from .singleton import singleton
@ -63,7 +62,7 @@ class EventAreaRegistryUpdatedData(TypedDict):
area_id: str
@dataclasses.dataclass(frozen=True, kw_only=True)
@dataclass(frozen=True, kw_only=True)
class AreaEntry(NormalizedNameBaseRegistryEntry):
"""Area Registry Entry."""
@ -71,7 +70,7 @@ class AreaEntry(NormalizedNameBaseRegistryEntry):
floor_id: str | None
icon: str | None
id: str
labels: set[str] = dataclasses.field(default_factory=set)
labels: set[str] = field(default_factory=set)
picture: str | None
@cached_property
@ -225,6 +224,10 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
return area
return self.async_create(name)
def _generate_id(self, name: str) -> str:
"""Generate area ID."""
return self.areas.generate_id_from_name(name)
@callback
def async_create(
self,
@ -238,28 +241,28 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
) -> AreaEntry:
"""Create a new area."""
self.hass.verify_event_loop_thread("area_registry.async_create")
normalized_name = normalize_name(name)
if self.async_get_area_by_name(name):
raise ValueError(f"The name {name} ({normalized_name}) is already in use")
if area := self.async_get_area_by_name(name):
raise ValueError(
f"The name {name} ({area.normalized_name}) is already in use"
)
area_id = self._generate_area_id(name)
area = AreaEntry(
aliases=aliases or set(),
floor_id=floor_id,
icon=icon,
id=area_id,
id=self._generate_id(name),
labels=labels or set(),
name=name,
normalized_name=normalized_name,
picture=picture,
)
assert area.id is not None
self.areas[area.id] = area
area_id = area.id
self.areas[area_id] = area
self.async_schedule_save()
self.hass.bus.async_fire_internal(
EVENT_AREA_REGISTRY_UPDATED,
EventAreaRegistryUpdatedData(action="create", area_id=area.id),
EventAreaRegistryUpdatedData(action="create", area_id=area_id),
)
return area
@ -342,7 +345,6 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
if name is not UNDEFINED and name != old.name:
new_values["name"] = name
new_values["normalized_name"] = normalize_name(name)
if not new_values:
return old
@ -366,7 +368,6 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
if data is not None:
for area in data["areas"]:
assert area["name"] is not None and area["id"] is not None
normalized_name = normalize_name(area["name"])
areas[area["id"]] = AreaEntry(
aliases=set(area["aliases"]),
floor_id=area["floor_id"],
@ -374,7 +375,6 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
id=area["id"],
labels=set(area["labels"]),
name=area["name"],
normalized_name=normalized_name,
picture=area["picture"],
created_at=datetime.fromisoformat(area["created_at"]),
modified_at=datetime.fromisoformat(area["modified_at"]),
@ -403,15 +403,6 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
]
}
def _generate_area_id(self, name: str) -> str:
"""Generate area ID."""
suggestion = suggestion_base = slugify(name)
tries = 1
while suggestion in self.areas:
tries += 1
suggestion = f"{suggestion_base}_{tries}"
return suggestion
@callback
def _async_setup_cleanup(self) -> None:
"""Set up the area registry cleanup."""

View File

@ -9,7 +9,6 @@ from datetime import datetime
from typing import Any, Literal, TypedDict
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.util import slugify
from homeassistant.util.dt import utc_from_timestamp, utcnow
from homeassistant.util.event_type import EventType
from homeassistant.util.hass_dict import HassKey
@ -17,7 +16,6 @@ from homeassistant.util.hass_dict import HassKey
from .normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry,
NormalizedNameBaseRegistryItems,
normalize_name,
)
from .registry import BaseRegistry
from .singleton import singleton
@ -130,15 +128,9 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
"""Get all floors."""
return self.floors.values()
@callback
def _generate_id(self, name: str) -> str:
"""Generate floor ID."""
suggestion = suggestion_base = slugify(name)
tries = 1
while suggestion in self.floors:
tries += 1
suggestion = f"{suggestion_base}_{tries}"
return suggestion
return self.floors.generate_id_from_name(name)
@callback
def async_create(
@ -151,30 +143,26 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
) -> FloorEntry:
"""Create a new floor."""
self.hass.verify_event_loop_thread("floor_registry.async_create")
if floor := self.async_get_floor_by_name(name):
raise ValueError(
f"The name {name} ({floor.normalized_name}) is already in use"
)
normalized_name = normalize_name(name)
floor = FloorEntry(
aliases=aliases or set(),
icon=icon,
floor_id=self._generate_id(name),
name=name,
normalized_name=normalized_name,
level=level,
)
floor_id = floor.floor_id
self.floors[floor_id] = floor
self.async_schedule_save()
self.hass.bus.async_fire_internal(
EVENT_FLOOR_REGISTRY_UPDATED,
EventFloorRegistryUpdatedData(
action="create",
floor_id=floor_id,
),
EventFloorRegistryUpdatedData(action="create", floor_id=floor_id),
)
return floor
@ -215,7 +203,6 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
}
if name is not UNDEFINED and name != old.name:
changes["name"] = name
changes["normalized_name"] = normalize_name(name)
if not changes:
return old
@ -243,14 +230,12 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
if data is not None:
for floor in data["floors"]:
normalized_name = normalize_name(floor["name"])
floors[floor["floor_id"]] = FloorEntry(
aliases=set(floor["aliases"]),
icon=floor["icon"],
floor_id=floor["floor_id"],
name=floor["name"],
level=floor["level"],
normalized_name=normalized_name,
created_at=datetime.fromisoformat(floor["created_at"]),
modified_at=datetime.fromisoformat(floor["modified_at"]),
)

View File

@ -9,7 +9,6 @@ from datetime import datetime
from typing import Any, Literal, TypedDict
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.util import slugify
from homeassistant.util.dt import utc_from_timestamp, utcnow
from homeassistant.util.event_type import EventType
from homeassistant.util.hass_dict import HassKey
@ -17,7 +16,6 @@ from homeassistant.util.hass_dict import HassKey
from .normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry,
NormalizedNameBaseRegistryItems,
normalize_name,
)
from .registry import BaseRegistry
from .singleton import singleton
@ -130,15 +128,9 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
"""Get all labels."""
return self.labels.values()
@callback
def _generate_id(self, name: str) -> str:
"""Initialize ID."""
suggestion = suggestion_base = slugify(name)
tries = 1
while suggestion in self.labels:
tries += 1
suggestion = f"{suggestion_base}_{tries}"
return suggestion
"""Generate label ID."""
return self.labels.generate_id_from_name(name)
@callback
def async_create(
@ -151,30 +143,26 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
) -> LabelEntry:
"""Create a new label."""
self.hass.verify_event_loop_thread("label_registry.async_create")
if label := self.async_get_label_by_name(name):
raise ValueError(
f"The name {name} ({label.normalized_name}) is already in use"
)
normalized_name = normalize_name(name)
label = LabelEntry(
color=color,
description=description,
icon=icon,
label_id=self._generate_id(name),
name=name,
normalized_name=normalized_name,
)
label_id = label.label_id
self.labels[label_id] = label
self.async_schedule_save()
self.hass.bus.async_fire_internal(
EVENT_LABEL_REGISTRY_UPDATED,
EventLabelRegistryUpdatedData(
action="create",
label_id=label_id,
),
EventLabelRegistryUpdatedData(action="create", label_id=label_id),
)
return label
@ -216,7 +204,6 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
if name is not UNDEFINED and name != old.name:
changes["name"] = name
changes["normalized_name"] = normalize_name(name)
if not changes:
return old
@ -244,14 +231,12 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
if data is not None:
for label in data["labels"]:
normalized_name = normalize_name(label["name"])
labels[label["label_id"]] = LabelEntry(
color=label["color"],
description=label["description"],
icon=label["icon"],
label_id=label["label_id"],
name=label["name"],
normalized_name=normalized_name,
created_at=datetime.fromisoformat(label["created_at"]),
modified_at=datetime.fromisoformat(label["modified_at"]),
)

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass, field
from datetime import datetime
from functools import lru_cache
from homeassistant.util import dt as dt_util
from homeassistant.util import dt as dt_util, slugify
from .registry import BaseRegistryItems
@ -14,10 +14,14 @@ class NormalizedNameBaseRegistryEntry:
"""Normalized Name Base Registry Entry."""
name: str
normalized_name: str
normalized_name: str = field(init=False)
created_at: datetime = field(default_factory=dt_util.utcnow)
modified_at: datetime = field(default_factory=dt_util.utcnow)
def __post_init__(self) -> None:
"""Post init."""
object.__setattr__(self, "normalized_name", normalize_name(self.name))
@lru_cache(maxsize=1024)
def normalize_name(name: str) -> str:
@ -43,7 +47,7 @@ class NormalizedNameBaseRegistryItems[_VT: NormalizedNameBaseRegistryEntry](
old_entry = self.data[key]
if (
replacement_entry is not None
and (normalized_name := normalize_name(replacement_entry.name))
and (normalized_name := replacement_entry.normalized_name)
!= old_entry.normalized_name
and normalized_name in self._normalized_names
):
@ -53,8 +57,17 @@ class NormalizedNameBaseRegistryItems[_VT: NormalizedNameBaseRegistryEntry](
del self._normalized_names[old_entry.normalized_name]
def _index_entry(self, key: str, entry: _VT) -> None:
self._normalized_names[normalize_name(entry.name)] = entry
self._normalized_names[entry.normalized_name] = entry
def get_by_name(self, name: str) -> _VT | None:
"""Get entry by name."""
return self._normalized_names.get(normalize_name(name))
def generate_id_from_name(self, name: str) -> str:
"""Generate ID from name."""
suggestion = suggestion_base = slugify(name)
tries = 1
while suggestion in self:
tries += 1
suggestion = f"{suggestion_base}_{tries}"
return suggestion

View File

@ -45,7 +45,6 @@ async def test_create_area(
id=ANY,
labels=set(),
name="mock",
normalized_name=ANY,
picture=None,
created_at=utcnow(),
modified_at=utcnow(),
@ -77,7 +76,6 @@ async def test_create_area(
id=ANY,
labels={"label1", "label2"},
name="mock 2",
normalized_name=ANY,
picture="/image/example.png",
created_at=utcnow(),
modified_at=utcnow(),
@ -196,7 +194,6 @@ async def test_update_area(
id=ANY,
labels={"label1", "label2"},
name="mock1",
normalized_name=ANY,
picture="/image/example.png",
created_at=created_at,
modified_at=modified_at,

View File

@ -12,7 +12,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers import area_registry as ar, floor_registry as fr
from homeassistant.util.dt import utcnow
from tests.common import ANY, async_capture_events, flush_store
from tests.common import async_capture_events, flush_store
async def test_list_floors(floor_registry: fr.FloorRegistry) -> None:
@ -43,7 +43,6 @@ async def test_create_floor(
level=1,
created_at=utcnow(),
modified_at=utcnow(),
normalized_name=ANY,
)
assert len(floor_registry.floors) == 1
@ -145,7 +144,6 @@ async def test_update_floor(
level=None,
created_at=created_at,
modified_at=created_at,
normalized_name=ANY,
)
assert len(floor_registry.floors) == 1
@ -169,7 +167,6 @@ async def test_update_floor(
level=2,
created_at=created_at,
modified_at=modified_at,
normalized_name=ANY,
)
assert len(floor_registry.floors) == 1

View File

@ -16,7 +16,7 @@ from homeassistant.helpers import (
)
from homeassistant.util.dt import utcnow
from tests.common import ANY, MockConfigEntry, async_capture_events, flush_store
from tests.common import MockConfigEntry, async_capture_events, flush_store
async def test_list_labels(label_registry: lr.LabelRegistry) -> None:
@ -46,7 +46,6 @@ async def test_create_label(
description="This label is for testing",
created_at=utcnow(),
modified_at=utcnow(),
normalized_name=ANY,
)
assert len(label_registry.labels) == 1
@ -147,7 +146,6 @@ async def test_update_label(
description=None,
created_at=created_at,
modified_at=created_at,
normalized_name=ANY,
)
modified_at = datetime.fromisoformat("2024-02-01T01:00:00+00:00")
@ -169,7 +167,6 @@ async def test_update_label(
description="Updated description",
created_at=created_at,
modified_at=modified_at,
normalized_name=ANY,
)
assert len(label_registry.labels) == 1

View File

@ -26,18 +26,14 @@ def test_registry_items(
registry_items: NormalizedNameBaseRegistryItems[NormalizedNameBaseRegistryEntry],
) -> None:
"""Test registry items."""
entry = NormalizedNameBaseRegistryEntry(
name="Hello World", normalized_name="helloworld"
)
entry = NormalizedNameBaseRegistryEntry(name="Hello World")
registry_items["key"] = entry
assert registry_items["key"] == entry
assert list(registry_items.values()) == [entry]
assert registry_items.get_by_name("Hello World") == entry
# test update entry
entry2 = NormalizedNameBaseRegistryEntry(
name="Hello World 2", normalized_name="helloworld2"
)
entry2 = NormalizedNameBaseRegistryEntry(name="Hello World 2")
registry_items["key"] = entry2
assert registry_items["key"] == entry2
assert list(registry_items.values()) == [entry2]
@ -53,16 +49,12 @@ def test_key_already_in_use(
registry_items: NormalizedNameBaseRegistryItems[NormalizedNameBaseRegistryEntry],
) -> None:
"""Test key already in use."""
entry = NormalizedNameBaseRegistryEntry(
name="Hello World", normalized_name="helloworld"
)
entry = NormalizedNameBaseRegistryEntry(name="Hello World")
registry_items["key"] = entry
# should raise ValueError if we update a
# key with a entry with the same normalized name
entry = NormalizedNameBaseRegistryEntry(
name="Hello World 2", normalized_name="helloworld2"
)
entry = NormalizedNameBaseRegistryEntry(name="Hello World 2")
registry_items["key2"] = entry
with pytest.raises(ValueError):
registry_items["key"] = entry

View File

@ -119,7 +119,6 @@ def floor_area_mock(hass: HomeAssistant) -> None:
id="test-area",
name="Test area",
aliases={},
normalized_name="test-area",
floor_id="test-floor",
icon=None,
picture=None,
@ -128,7 +127,6 @@ def floor_area_mock(hass: HomeAssistant) -> None:
id="area-a",
name="Area A",
aliases={},
normalized_name="area-a",
floor_id="floor-a",
icon=None,
picture=None,
@ -282,7 +280,6 @@ def label_mock(hass: HomeAssistant) -> None:
id="area-with-labels",
name="Area with labels",
aliases={},
normalized_name="with_labels",
floor_id=None,
icon=None,
labels={"label_area"},
@ -292,7 +289,6 @@ def label_mock(hass: HomeAssistant) -> None:
id="area-no-labels",
name="Area without labels",
aliases={},
normalized_name="without_labels",
floor_id=None,
icon=None,
labels=set(),