Add created_at/modified_at to label registry (#122078)

This commit is contained in:
Robert Resch 2024-07-17 16:36:14 +02:00 committed by GitHub
parent 8ae4c4445d
commit 10c084c6e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 172 additions and 63 deletions

View File

@ -46,8 +46,8 @@ class _AreaStoreData(TypedDict):
labels: list[str]
name: str
picture: str | None
created_at: datetime
modified_at: datetime
created_at: str
modified_at: str
class AreasRegistryStoreData(TypedDict):
@ -87,8 +87,8 @@ class AreaEntry(NormalizedNameBaseRegistryEntry):
"labels": list(self.labels),
"name": self.name,
"picture": self.picture,
"created_at": self.created_at,
"modified_at": self.modified_at,
"created_at": self.created_at.isoformat(),
"modified_at": self.modified_at.isoformat(),
}
)
)
@ -134,7 +134,9 @@ class AreaRegistryStore(Store[AreasRegistryStoreData]):
if old_minor_version < 7:
# Version 1.7 adds created_at and modiefied_at
for area in old_data["areas"]:
area["created_at"] = area["modified_at"] = utc_from_timestamp(0)
area["created_at"] = area["modified_at"] = utc_from_timestamp(
0
).isoformat()
if old_major_version > 1:
raise NotImplementedError
@ -374,8 +376,8 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
name=area["name"],
normalized_name=normalized_name,
picture=area["picture"],
created_at=area["created_at"],
modified_at=area["modified_at"],
created_at=datetime.fromisoformat(area["created_at"]),
modified_at=datetime.fromisoformat(area["modified_at"]),
)
self.areas = areas
@ -394,8 +396,8 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
"labels": list(entry.labels),
"name": entry.name,
"picture": entry.picture,
"created_at": entry.created_at,
"modified_at": entry.modified_at,
"created_at": entry.created_at.isoformat(),
"modified_at": entry.modified_at.isoformat(),
}
for entry in self.areas.values()
]

View File

@ -41,8 +41,8 @@ class _FloorStoreData(TypedDict):
icon: str | None
level: int | None
name: str
created_at: datetime
modified_at: datetime
created_at: str
modified_at: str
class FloorRegistryStoreData(TypedDict):
@ -88,7 +88,9 @@ class FloorRegistryStore(Store[FloorRegistryStoreData]):
if old_minor_version < 2:
# Version 1.2 implements migration and adds created_at and modified_at
for floor in old_data["floors"]:
floor["created_at"] = floor["modified_at"] = utc_from_timestamp(0)
floor["created_at"] = floor["modified_at"] = utc_from_timestamp(
0
).isoformat()
return old_data # type: ignore[return-value]
@ -250,8 +252,8 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
name=floor["name"],
level=floor["level"],
normalized_name=normalized_name,
created_at=floor["created_at"],
modified_at=floor["modified_at"],
created_at=datetime.fromisoformat(floor["created_at"]),
modified_at=datetime.fromisoformat(floor["modified_at"]),
)
self.floors = floors
@ -268,8 +270,8 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
"icon": entry.icon,
"level": entry.level,
"name": entry.name,
"created_at": entry.created_at,
"modified_at": entry.modified_at,
"created_at": entry.created_at.isoformat(),
"modified_at": entry.modified_at.isoformat(),
}
for entry in self.floors.values()
]

View File

@ -5,10 +5,12 @@ from __future__ import annotations
from collections.abc import Iterable
import dataclasses
from dataclasses import dataclass
from typing import Literal, TypedDict
from datetime import datetime
from typing import Any, Literal, TypedDict
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.util import slugify
from homeassistant.util.dt import utc_from_timestamp, utcnow
from homeassistant.util.event_type import EventType
from homeassistant.util.hass_dict import HassKey
@ -28,6 +30,7 @@ EVENT_LABEL_REGISTRY_UPDATED: EventType[EventLabelRegistryUpdatedData] = EventTy
)
STORAGE_KEY = "core.label_registry"
STORAGE_VERSION_MAJOR = 1
STORAGE_VERSION_MINOR = 2
class _LabelStoreData(TypedDict):
@ -38,6 +41,8 @@ class _LabelStoreData(TypedDict):
icon: str | None
label_id: str
name: str
created_at: str
modified_at: str
class LabelRegistryStoreData(TypedDict):
@ -66,6 +71,30 @@ class LabelEntry(NormalizedNameBaseRegistryEntry):
icon: str | None = None
class LabelRegistryStore(Store[LabelRegistryStoreData]):
"""Store label registry data."""
async def _async_migrate_func(
self,
old_major_version: int,
old_minor_version: int,
old_data: dict[str, list[dict[str, Any]]],
) -> LabelRegistryStoreData:
"""Migrate to the new version."""
if old_major_version > STORAGE_VERSION_MAJOR:
raise ValueError("Can't migrate to future version")
if old_major_version == 1:
if old_minor_version < 2:
# Version 1.2 implements migration and adds created_at and modified_at
for label in old_data["labels"]:
label["created_at"] = label["modified_at"] = utc_from_timestamp(
0
).isoformat()
return old_data # type: ignore[return-value]
class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
"""Class to hold a registry of labels."""
@ -75,11 +104,12 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the label registry."""
self.hass = hass
self._store = Store(
self._store = LabelRegistryStore(
hass,
STORAGE_VERSION_MAJOR,
STORAGE_KEY,
atomic_writes=True,
minor_version=STORAGE_VERSION_MINOR,
)
@callback
@ -175,7 +205,7 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
) -> LabelEntry:
"""Update name of label."""
old = self.labels[label_id]
changes = {
changes: dict[str, Any] = {
attr_name: value
for attr_name, value in (
("color", color),
@ -192,8 +222,10 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
if not changes:
return old
changes["modified_at"] = utcnow()
self.hass.verify_event_loop_thread("label_registry.async_update")
new = self.labels[label_id] = dataclasses.replace(old, **changes) # type: ignore[arg-type]
new = self.labels[label_id] = dataclasses.replace(old, **changes)
self.async_schedule_save()
self.hass.bus.async_fire_internal(
@ -221,6 +253,8 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
label_id=label["label_id"],
name=label["name"],
normalized_name=normalized_name,
created_at=datetime.fromisoformat(label["created_at"]),
modified_at=datetime.fromisoformat(label["modified_at"]),
)
self.labels = labels
@ -237,6 +271,8 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
"icon": entry.icon,
"label_id": entry.label_id,
"name": entry.name,
"created_at": entry.created_at.isoformat(),
"modified_at": entry.modified_at.isoformat(),
}
for entry in self.labels.values()
]

View File

@ -264,15 +264,22 @@ async def test_update_floor_with_normalized_name_already_in_use(
async def test_load_floors(
hass: HomeAssistant, floor_registry: fr.FloorRegistry
hass: HomeAssistant,
floor_registry: fr.FloorRegistry,
freezer: FrozenDateTimeFactory,
) -> None:
"""Make sure that we can load/save data correctly."""
floor1_created = datetime.fromisoformat("2024-01-01T00:00:00+00:00")
freezer.move_to(floor1_created)
floor1 = floor_registry.async_create(
"First floor",
icon="mdi:home-floor-1",
aliases={"first", "ground"},
level=1,
)
floor2_created = datetime.fromisoformat("2024-02-01T00:00:00+00:00")
freezer.move_to(floor2_created)
floor2 = floor_registry.async_create(
"Second floor",
icon="mdi:home-floor-2",
@ -290,20 +297,10 @@ async def test_load_floors(
assert list(floor_registry.floors) == list(registry2.floors)
floor1_registry2 = registry2.async_get_floor_by_name("First floor")
assert floor1_registry2.floor_id == floor1.floor_id
assert floor1_registry2.name == floor1.name
assert floor1_registry2.icon == floor1.icon
assert floor1_registry2.aliases == floor1.aliases
assert floor1_registry2.level == floor1.level
assert floor1_registry2.normalized_name == floor1.normalized_name
assert floor1_registry2 == floor1
floor2_registry2 = registry2.async_get_floor_by_name("Second floor")
assert floor2_registry2.floor_id == floor2.floor_id
assert floor2_registry2.name == floor2.name
assert floor2_registry2.icon == floor2.icon
assert floor2_registry2.aliases == floor2.aliases
assert floor2_registry2.level == floor2.level
assert floor2_registry2.normalized_name == floor2.normalized_name
assert floor2_registry2 == floor2
@pytest.mark.parametrize("load_registries", [False])

View File

@ -1,9 +1,11 @@
"""Tests for the Label Registry."""
from datetime import datetime
from functools import partial
import re
from typing import Any
from freezegun.api import FrozenDateTimeFactory
import pytest
from homeassistant.core import HomeAssistant
@ -12,8 +14,9 @@ from homeassistant.helpers import (
entity_registry as er,
label_registry as lr,
)
from homeassistant.util.dt import utcnow
from tests.common import MockConfigEntry, async_capture_events, flush_store
from tests.common import ANY, MockConfigEntry, async_capture_events, flush_store
async def test_list_labels(label_registry: lr.LabelRegistry) -> None:
@ -22,6 +25,7 @@ async def test_list_labels(label_registry: lr.LabelRegistry) -> None:
assert len(list(labels)) == len(label_registry.labels)
@pytest.mark.usefixtures("freezer")
async def test_create_label(
hass: HomeAssistant, label_registry: lr.LabelRegistry
) -> None:
@ -34,11 +38,16 @@ async def test_create_label(
description="This label is for testing",
)
assert label.label_id == "my_label"
assert label.name == "My Label"
assert label.color == "#FF0000"
assert label.icon == "mdi:test"
assert label.description == "This label is for testing"
assert label == lr.LabelEntry(
label_id="my_label",
name="My Label",
color="#FF0000",
icon="mdi:test",
description="This label is for testing",
created_at=utcnow(),
modified_at=utcnow(),
normalized_name=ANY,
)
assert len(label_registry.labels) == 1
@ -119,19 +128,30 @@ async def test_delete_non_existing_label(label_registry: lr.LabelRegistry) -> No
async def test_update_label(
hass: HomeAssistant, label_registry: lr.LabelRegistry
hass: HomeAssistant,
label_registry: lr.LabelRegistry,
freezer: FrozenDateTimeFactory,
) -> None:
"""Make sure that we can update labels."""
created_at = datetime.fromisoformat("2024-01-01T01:00:00+00:00")
freezer.move_to(created_at)
update_events = async_capture_events(hass, lr.EVENT_LABEL_REGISTRY_UPDATED)
label = label_registry.async_create("Mock")
assert len(label_registry.labels) == 1
assert label.label_id == "mock"
assert label.name == "Mock"
assert label.color is None
assert label.icon is None
assert label.description is None
assert label == lr.LabelEntry(
label_id="mock",
name="Mock",
color=None,
icon=None,
description=None,
created_at=created_at,
modified_at=created_at,
normalized_name=ANY,
)
modified_at = datetime.fromisoformat("2024-02-01T01:00:00+00:00")
freezer.move_to(modified_at)
updated_label = label_registry.async_update(
label.label_id,
name="Updated",
@ -141,12 +161,16 @@ async def test_update_label(
)
assert updated_label != label
assert updated_label.label_id == "mock"
assert updated_label.name == "Updated"
assert updated_label.color == "#FFFFFF"
assert updated_label.icon == "mdi:update"
assert updated_label.description == "Updated description"
assert updated_label == lr.LabelEntry(
label_id="mock",
name="Updated",
color="#FFFFFF",
icon="mdi:update",
description="Updated description",
created_at=created_at,
modified_at=modified_at,
normalized_name=ANY,
)
assert len(label_registry.labels) == 1
await hass.async_block_till_done()
@ -242,15 +266,21 @@ async def test_update_label_with_normalized_name_already_in_use(
async def test_load_labels(
hass: HomeAssistant, label_registry: lr.LabelRegistry
hass: HomeAssistant,
label_registry: lr.LabelRegistry,
freezer: FrozenDateTimeFactory,
) -> None:
"""Make sure that we can load/save data correctly."""
label1_created = datetime.fromisoformat("2024-01-01T00:00:00+00:00")
freezer.move_to(label1_created)
label1 = label_registry.async_create(
"Label One",
color="#FF000",
icon="mdi:one",
description="This label is label one",
)
label2_created = datetime.fromisoformat("2024-02-01T00:00:00+00:00")
freezer.move_to(label2_created)
label2 = label_registry.async_create(
"Label Two",
color="#000FF",
@ -268,19 +298,10 @@ async def test_load_labels(
assert list(label_registry.labels) == list(registry2.labels)
label1_registry2 = registry2.async_get_label_by_name("Label One")
assert label1_registry2.label_id == label1.label_id
assert label1_registry2.name == label1.name
assert label1_registry2.color == label1.color
assert label1_registry2.description == label1.description
assert label1_registry2.icon == label1.icon
assert label1_registry2.normalized_name == label1.normalized_name
assert label1_registry2 == label1
label2_registry2 = registry2.async_get_label_by_name("Label Two")
assert label2_registry2.name == label2.name
assert label2_registry2.color == label2.color
assert label2_registry2.description == label2.description
assert label2_registry2.icon == label2.icon
assert label2_registry2.normalized_name == label2.normalized_name
assert label2_registry2 == label2
@pytest.mark.parametrize("load_registries", [False])
@ -298,6 +319,8 @@ async def test_loading_label_from_storage(
"icon": "mdi:test",
"label_id": "one",
"name": "One",
"created_at": "2024-01-01T00:00:00+00:00",
"modified_at": "2024-02-01T00:00:00+00:00",
}
]
},
@ -489,3 +512,52 @@ async def test_async_update_thread_safety(
await hass.async_add_executor_job(
partial(label_registry.async_update, any_label.label_id, name="new name")
)
@pytest.mark.parametrize("load_registries", [False])
async def test_migration_from_1_1(
hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
"""Test migration from version 1.1."""
hass_storage[lr.STORAGE_KEY] = {
"version": 1,
"data": {
"labels": [
{
"color": None,
"description": None,
"icon": None,
"label_id": "12345A",
"name": "mock",
}
]
},
}
await lr.async_load(hass)
registry = lr.async_get(hass)
# Test data was loaded
entry = registry.async_get_label_by_name("mock")
assert entry.label_id == "12345A"
# Check we store migrated data
await flush_store(registry._store)
assert hass_storage[lr.STORAGE_KEY] == {
"version": lr.STORAGE_VERSION_MAJOR,
"minor_version": lr.STORAGE_VERSION_MINOR,
"key": lr.STORAGE_KEY,
"data": {
"labels": [
{
"color": None,
"description": None,
"icon": None,
"label_id": "12345A",
"name": "mock",
"created_at": "1970-01-01T00:00:00+00:00",
"modified_at": "1970-01-01T00:00:00+00:00",
}
]
},
}