Add created_at/modified_at to label registry (#122078)

This commit is contained in:
Robert Resch 2024-07-17 16:36:14 +02:00 committed by GitHub
parent 8ae4c4445d
commit 10c084c6e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 172 additions and 63 deletions

View File

@ -46,8 +46,8 @@ class _AreaStoreData(TypedDict):
labels: list[str] labels: list[str]
name: str name: str
picture: str | None picture: str | None
created_at: datetime created_at: str
modified_at: datetime modified_at: str
class AreasRegistryStoreData(TypedDict): class AreasRegistryStoreData(TypedDict):
@ -87,8 +87,8 @@ class AreaEntry(NormalizedNameBaseRegistryEntry):
"labels": list(self.labels), "labels": list(self.labels),
"name": self.name, "name": self.name,
"picture": self.picture, "picture": self.picture,
"created_at": self.created_at, "created_at": self.created_at.isoformat(),
"modified_at": self.modified_at, "modified_at": self.modified_at.isoformat(),
} }
) )
) )
@ -134,7 +134,9 @@ class AreaRegistryStore(Store[AreasRegistryStoreData]):
if old_minor_version < 7: if old_minor_version < 7:
# Version 1.7 adds created_at and modiefied_at # Version 1.7 adds created_at and modiefied_at
for area in old_data["areas"]: for area in old_data["areas"]:
area["created_at"] = area["modified_at"] = utc_from_timestamp(0) area["created_at"] = area["modified_at"] = utc_from_timestamp(
0
).isoformat()
if old_major_version > 1: if old_major_version > 1:
raise NotImplementedError raise NotImplementedError
@ -374,8 +376,8 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
name=area["name"], name=area["name"],
normalized_name=normalized_name, normalized_name=normalized_name,
picture=area["picture"], picture=area["picture"],
created_at=area["created_at"], created_at=datetime.fromisoformat(area["created_at"]),
modified_at=area["modified_at"], modified_at=datetime.fromisoformat(area["modified_at"]),
) )
self.areas = areas self.areas = areas
@ -394,8 +396,8 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
"labels": list(entry.labels), "labels": list(entry.labels),
"name": entry.name, "name": entry.name,
"picture": entry.picture, "picture": entry.picture,
"created_at": entry.created_at, "created_at": entry.created_at.isoformat(),
"modified_at": entry.modified_at, "modified_at": entry.modified_at.isoformat(),
} }
for entry in self.areas.values() for entry in self.areas.values()
] ]

View File

@ -41,8 +41,8 @@ class _FloorStoreData(TypedDict):
icon: str | None icon: str | None
level: int | None level: int | None
name: str name: str
created_at: datetime created_at: str
modified_at: datetime modified_at: str
class FloorRegistryStoreData(TypedDict): class FloorRegistryStoreData(TypedDict):
@ -88,7 +88,9 @@ class FloorRegistryStore(Store[FloorRegistryStoreData]):
if old_minor_version < 2: if old_minor_version < 2:
# Version 1.2 implements migration and adds created_at and modified_at # Version 1.2 implements migration and adds created_at and modified_at
for floor in old_data["floors"]: for floor in old_data["floors"]:
floor["created_at"] = floor["modified_at"] = utc_from_timestamp(0) floor["created_at"] = floor["modified_at"] = utc_from_timestamp(
0
).isoformat()
return old_data # type: ignore[return-value] return old_data # type: ignore[return-value]
@ -250,8 +252,8 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
name=floor["name"], name=floor["name"],
level=floor["level"], level=floor["level"],
normalized_name=normalized_name, normalized_name=normalized_name,
created_at=floor["created_at"], created_at=datetime.fromisoformat(floor["created_at"]),
modified_at=floor["modified_at"], modified_at=datetime.fromisoformat(floor["modified_at"]),
) )
self.floors = floors self.floors = floors
@ -268,8 +270,8 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
"icon": entry.icon, "icon": entry.icon,
"level": entry.level, "level": entry.level,
"name": entry.name, "name": entry.name,
"created_at": entry.created_at, "created_at": entry.created_at.isoformat(),
"modified_at": entry.modified_at, "modified_at": entry.modified_at.isoformat(),
} }
for entry in self.floors.values() for entry in self.floors.values()
] ]

View File

@ -5,10 +5,12 @@ from __future__ import annotations
from collections.abc import Iterable from collections.abc import Iterable
import dataclasses import dataclasses
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal, TypedDict from datetime import datetime
from typing import Any, Literal, TypedDict
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.util import slugify from homeassistant.util import slugify
from homeassistant.util.dt import utc_from_timestamp, utcnow
from homeassistant.util.event_type import EventType from homeassistant.util.event_type import EventType
from homeassistant.util.hass_dict import HassKey from homeassistant.util.hass_dict import HassKey
@ -28,6 +30,7 @@ EVENT_LABEL_REGISTRY_UPDATED: EventType[EventLabelRegistryUpdatedData] = EventTy
) )
STORAGE_KEY = "core.label_registry" STORAGE_KEY = "core.label_registry"
STORAGE_VERSION_MAJOR = 1 STORAGE_VERSION_MAJOR = 1
STORAGE_VERSION_MINOR = 2
class _LabelStoreData(TypedDict): class _LabelStoreData(TypedDict):
@ -38,6 +41,8 @@ class _LabelStoreData(TypedDict):
icon: str | None icon: str | None
label_id: str label_id: str
name: str name: str
created_at: str
modified_at: str
class LabelRegistryStoreData(TypedDict): class LabelRegistryStoreData(TypedDict):
@ -66,6 +71,30 @@ class LabelEntry(NormalizedNameBaseRegistryEntry):
icon: str | None = None icon: str | None = None
class LabelRegistryStore(Store[LabelRegistryStoreData]):
"""Store label registry data."""
async def _async_migrate_func(
self,
old_major_version: int,
old_minor_version: int,
old_data: dict[str, list[dict[str, Any]]],
) -> LabelRegistryStoreData:
"""Migrate to the new version."""
if old_major_version > STORAGE_VERSION_MAJOR:
raise ValueError("Can't migrate to future version")
if old_major_version == 1:
if old_minor_version < 2:
# Version 1.2 implements migration and adds created_at and modified_at
for label in old_data["labels"]:
label["created_at"] = label["modified_at"] = utc_from_timestamp(
0
).isoformat()
return old_data # type: ignore[return-value]
class LabelRegistry(BaseRegistry[LabelRegistryStoreData]): class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
"""Class to hold a registry of labels.""" """Class to hold a registry of labels."""
@ -75,11 +104,12 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the label registry.""" """Initialize the label registry."""
self.hass = hass self.hass = hass
self._store = Store( self._store = LabelRegistryStore(
hass, hass,
STORAGE_VERSION_MAJOR, STORAGE_VERSION_MAJOR,
STORAGE_KEY, STORAGE_KEY,
atomic_writes=True, atomic_writes=True,
minor_version=STORAGE_VERSION_MINOR,
) )
@callback @callback
@ -175,7 +205,7 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
) -> LabelEntry: ) -> LabelEntry:
"""Update name of label.""" """Update name of label."""
old = self.labels[label_id] old = self.labels[label_id]
changes = { changes: dict[str, Any] = {
attr_name: value attr_name: value
for attr_name, value in ( for attr_name, value in (
("color", color), ("color", color),
@ -192,8 +222,10 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
if not changes: if not changes:
return old return old
changes["modified_at"] = utcnow()
self.hass.verify_event_loop_thread("label_registry.async_update") self.hass.verify_event_loop_thread("label_registry.async_update")
new = self.labels[label_id] = dataclasses.replace(old, **changes) # type: ignore[arg-type] new = self.labels[label_id] = dataclasses.replace(old, **changes)
self.async_schedule_save() self.async_schedule_save()
self.hass.bus.async_fire_internal( self.hass.bus.async_fire_internal(
@ -221,6 +253,8 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
label_id=label["label_id"], label_id=label["label_id"],
name=label["name"], name=label["name"],
normalized_name=normalized_name, normalized_name=normalized_name,
created_at=datetime.fromisoformat(label["created_at"]),
modified_at=datetime.fromisoformat(label["modified_at"]),
) )
self.labels = labels self.labels = labels
@ -237,6 +271,8 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
"icon": entry.icon, "icon": entry.icon,
"label_id": entry.label_id, "label_id": entry.label_id,
"name": entry.name, "name": entry.name,
"created_at": entry.created_at.isoformat(),
"modified_at": entry.modified_at.isoformat(),
} }
for entry in self.labels.values() for entry in self.labels.values()
] ]

View File

@ -264,15 +264,22 @@ async def test_update_floor_with_normalized_name_already_in_use(
async def test_load_floors( async def test_load_floors(
hass: HomeAssistant, floor_registry: fr.FloorRegistry hass: HomeAssistant,
floor_registry: fr.FloorRegistry,
freezer: FrozenDateTimeFactory,
) -> None: ) -> None:
"""Make sure that we can load/save data correctly.""" """Make sure that we can load/save data correctly."""
floor1_created = datetime.fromisoformat("2024-01-01T00:00:00+00:00")
freezer.move_to(floor1_created)
floor1 = floor_registry.async_create( floor1 = floor_registry.async_create(
"First floor", "First floor",
icon="mdi:home-floor-1", icon="mdi:home-floor-1",
aliases={"first", "ground"}, aliases={"first", "ground"},
level=1, level=1,
) )
floor2_created = datetime.fromisoformat("2024-02-01T00:00:00+00:00")
freezer.move_to(floor2_created)
floor2 = floor_registry.async_create( floor2 = floor_registry.async_create(
"Second floor", "Second floor",
icon="mdi:home-floor-2", icon="mdi:home-floor-2",
@ -290,20 +297,10 @@ async def test_load_floors(
assert list(floor_registry.floors) == list(registry2.floors) assert list(floor_registry.floors) == list(registry2.floors)
floor1_registry2 = registry2.async_get_floor_by_name("First floor") floor1_registry2 = registry2.async_get_floor_by_name("First floor")
assert floor1_registry2.floor_id == floor1.floor_id assert floor1_registry2 == floor1
assert floor1_registry2.name == floor1.name
assert floor1_registry2.icon == floor1.icon
assert floor1_registry2.aliases == floor1.aliases
assert floor1_registry2.level == floor1.level
assert floor1_registry2.normalized_name == floor1.normalized_name
floor2_registry2 = registry2.async_get_floor_by_name("Second floor") floor2_registry2 = registry2.async_get_floor_by_name("Second floor")
assert floor2_registry2.floor_id == floor2.floor_id assert floor2_registry2 == floor2
assert floor2_registry2.name == floor2.name
assert floor2_registry2.icon == floor2.icon
assert floor2_registry2.aliases == floor2.aliases
assert floor2_registry2.level == floor2.level
assert floor2_registry2.normalized_name == floor2.normalized_name
@pytest.mark.parametrize("load_registries", [False]) @pytest.mark.parametrize("load_registries", [False])

View File

@ -1,9 +1,11 @@
"""Tests for the Label Registry.""" """Tests for the Label Registry."""
from datetime import datetime
from functools import partial from functools import partial
import re import re
from typing import Any from typing import Any
from freezegun.api import FrozenDateTimeFactory
import pytest import pytest
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -12,8 +14,9 @@ from homeassistant.helpers import (
entity_registry as er, entity_registry as er,
label_registry as lr, label_registry as lr,
) )
from homeassistant.util.dt import utcnow
from tests.common import MockConfigEntry, async_capture_events, flush_store from tests.common import ANY, MockConfigEntry, async_capture_events, flush_store
async def test_list_labels(label_registry: lr.LabelRegistry) -> None: async def test_list_labels(label_registry: lr.LabelRegistry) -> None:
@ -22,6 +25,7 @@ async def test_list_labels(label_registry: lr.LabelRegistry) -> None:
assert len(list(labels)) == len(label_registry.labels) assert len(list(labels)) == len(label_registry.labels)
@pytest.mark.usefixtures("freezer")
async def test_create_label( async def test_create_label(
hass: HomeAssistant, label_registry: lr.LabelRegistry hass: HomeAssistant, label_registry: lr.LabelRegistry
) -> None: ) -> None:
@ -34,11 +38,16 @@ async def test_create_label(
description="This label is for testing", description="This label is for testing",
) )
assert label.label_id == "my_label" assert label == lr.LabelEntry(
assert label.name == "My Label" label_id="my_label",
assert label.color == "#FF0000" name="My Label",
assert label.icon == "mdi:test" color="#FF0000",
assert label.description == "This label is for testing" icon="mdi:test",
description="This label is for testing",
created_at=utcnow(),
modified_at=utcnow(),
normalized_name=ANY,
)
assert len(label_registry.labels) == 1 assert len(label_registry.labels) == 1
@ -119,19 +128,30 @@ async def test_delete_non_existing_label(label_registry: lr.LabelRegistry) -> No
async def test_update_label( async def test_update_label(
hass: HomeAssistant, label_registry: lr.LabelRegistry hass: HomeAssistant,
label_registry: lr.LabelRegistry,
freezer: FrozenDateTimeFactory,
) -> None: ) -> None:
"""Make sure that we can update labels.""" """Make sure that we can update labels."""
created_at = datetime.fromisoformat("2024-01-01T01:00:00+00:00")
freezer.move_to(created_at)
update_events = async_capture_events(hass, lr.EVENT_LABEL_REGISTRY_UPDATED) update_events = async_capture_events(hass, lr.EVENT_LABEL_REGISTRY_UPDATED)
label = label_registry.async_create("Mock") label = label_registry.async_create("Mock")
assert len(label_registry.labels) == 1 assert len(label_registry.labels) == 1
assert label.label_id == "mock" assert label == lr.LabelEntry(
assert label.name == "Mock" label_id="mock",
assert label.color is None name="Mock",
assert label.icon is None color=None,
assert label.description is None icon=None,
description=None,
created_at=created_at,
modified_at=created_at,
normalized_name=ANY,
)
modified_at = datetime.fromisoformat("2024-02-01T01:00:00+00:00")
freezer.move_to(modified_at)
updated_label = label_registry.async_update( updated_label = label_registry.async_update(
label.label_id, label.label_id,
name="Updated", name="Updated",
@ -141,12 +161,16 @@ async def test_update_label(
) )
assert updated_label != label assert updated_label != label
assert updated_label.label_id == "mock" assert updated_label == lr.LabelEntry(
assert updated_label.name == "Updated" label_id="mock",
assert updated_label.color == "#FFFFFF" name="Updated",
assert updated_label.icon == "mdi:update" color="#FFFFFF",
assert updated_label.description == "Updated description" icon="mdi:update",
description="Updated description",
created_at=created_at,
modified_at=modified_at,
normalized_name=ANY,
)
assert len(label_registry.labels) == 1 assert len(label_registry.labels) == 1
await hass.async_block_till_done() await hass.async_block_till_done()
@ -242,15 +266,21 @@ async def test_update_label_with_normalized_name_already_in_use(
async def test_load_labels( async def test_load_labels(
hass: HomeAssistant, label_registry: lr.LabelRegistry hass: HomeAssistant,
label_registry: lr.LabelRegistry,
freezer: FrozenDateTimeFactory,
) -> None: ) -> None:
"""Make sure that we can load/save data correctly.""" """Make sure that we can load/save data correctly."""
label1_created = datetime.fromisoformat("2024-01-01T00:00:00+00:00")
freezer.move_to(label1_created)
label1 = label_registry.async_create( label1 = label_registry.async_create(
"Label One", "Label One",
color="#FF000", color="#FF000",
icon="mdi:one", icon="mdi:one",
description="This label is label one", description="This label is label one",
) )
label2_created = datetime.fromisoformat("2024-02-01T00:00:00+00:00")
freezer.move_to(label2_created)
label2 = label_registry.async_create( label2 = label_registry.async_create(
"Label Two", "Label Two",
color="#000FF", color="#000FF",
@ -268,19 +298,10 @@ async def test_load_labels(
assert list(label_registry.labels) == list(registry2.labels) assert list(label_registry.labels) == list(registry2.labels)
label1_registry2 = registry2.async_get_label_by_name("Label One") label1_registry2 = registry2.async_get_label_by_name("Label One")
assert label1_registry2.label_id == label1.label_id assert label1_registry2 == label1
assert label1_registry2.name == label1.name
assert label1_registry2.color == label1.color
assert label1_registry2.description == label1.description
assert label1_registry2.icon == label1.icon
assert label1_registry2.normalized_name == label1.normalized_name
label2_registry2 = registry2.async_get_label_by_name("Label Two") label2_registry2 = registry2.async_get_label_by_name("Label Two")
assert label2_registry2.name == label2.name assert label2_registry2 == label2
assert label2_registry2.color == label2.color
assert label2_registry2.description == label2.description
assert label2_registry2.icon == label2.icon
assert label2_registry2.normalized_name == label2.normalized_name
@pytest.mark.parametrize("load_registries", [False]) @pytest.mark.parametrize("load_registries", [False])
@ -298,6 +319,8 @@ async def test_loading_label_from_storage(
"icon": "mdi:test", "icon": "mdi:test",
"label_id": "one", "label_id": "one",
"name": "One", "name": "One",
"created_at": "2024-01-01T00:00:00+00:00",
"modified_at": "2024-02-01T00:00:00+00:00",
} }
] ]
}, },
@ -489,3 +512,52 @@ async def test_async_update_thread_safety(
await hass.async_add_executor_job( await hass.async_add_executor_job(
partial(label_registry.async_update, any_label.label_id, name="new name") partial(label_registry.async_update, any_label.label_id, name="new name")
) )
@pytest.mark.parametrize("load_registries", [False])
async def test_migration_from_1_1(
hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
"""Test migration from version 1.1."""
hass_storage[lr.STORAGE_KEY] = {
"version": 1,
"data": {
"labels": [
{
"color": None,
"description": None,
"icon": None,
"label_id": "12345A",
"name": "mock",
}
]
},
}
await lr.async_load(hass)
registry = lr.async_get(hass)
# Test data was loaded
entry = registry.async_get_label_by_name("mock")
assert entry.label_id == "12345A"
# Check we store migrated data
await flush_store(registry._store)
assert hass_storage[lr.STORAGE_KEY] == {
"version": lr.STORAGE_VERSION_MAJOR,
"minor_version": lr.STORAGE_VERSION_MINOR,
"key": lr.STORAGE_KEY,
"data": {
"labels": [
{
"color": None,
"description": None,
"icon": None,
"label_id": "12345A",
"name": "mock",
"created_at": "1970-01-01T00:00:00+00:00",
"modified_at": "1970-01-01T00:00:00+00:00",
}
]
},
}