Allow storing other items than untyped dict in StorageCollection (#90932)

Allow storing other items than untyped dict in StorageCollection
This commit is contained in:
Erik Montnemery 2023-04-06 16:57:00 +02:00 committed by GitHub
parent 8025fbf398
commit 59a02cd08c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 134 additions and 91 deletions

View File

@ -75,7 +75,7 @@ class AuthorizationServer:
token_url: str
class ApplicationCredentialsStorageCollection(collection.StorageCollection):
class ApplicationCredentialsStorageCollection(collection.DictStorageCollection):
"""Application credential collection stored in storage."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)
@ -94,7 +94,7 @@ class ApplicationCredentialsStorageCollection(collection.StorageCollection):
return f"{info[CONF_DOMAIN]}.{info[CONF_CLIENT_ID]}"
async def _update_data(
self, data: dict[str, str], update_data: dict[str, str]
self, item: dict[str, str], update_data: dict[str, str]
) -> dict[str, str]:
"""Return a new updated data object."""
raise ValueError("Updates not supported")

View File

@ -139,7 +139,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
class CounterStorageCollection(collection.StorageCollection):
class CounterStorageCollection(collection.DictStorageCollection):
"""Input storage based collection."""
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
@ -153,10 +153,10 @@ class CounterStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config."""
return info[CONF_NAME]
async def _update_data(self, data: dict, update_data: dict) -> dict:
async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
update_data = self.CREATE_UPDATE_SCHEMA(update_data)
return {CONF_ID: data[CONF_ID]} | update_data
return {CONF_ID: item[CONF_ID]} | update_data
class Counter(collection.CollectionEntity, RestoreEntity):

View File

@ -57,7 +57,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
class ImageStorageCollection(collection.StorageCollection):
class ImageStorageCollection(collection.DictStorageCollection):
"""Image collection stored in storage."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)
@ -125,11 +125,11 @@ class ImageStorageCollection(collection.StorageCollection):
async def _update_data(
self,
data: dict[str, Any],
item: dict[str, Any],
update_data: dict[str, Any],
) -> dict[str, Any]:
"""Return a new updated data object."""
return {**data, **self.UPDATE_SCHEMA(update_data)}
return {**item, **self.UPDATE_SCHEMA(update_data)}
async def _change_listener(
self,

View File

@ -65,7 +65,7 @@ STORAGE_KEY = DOMAIN
STORAGE_VERSION = 1
class InputBooleanStorageCollection(collection.StorageCollection):
class InputBooleanStorageCollection(collection.DictStorageCollection):
"""Input boolean collection stored in storage."""
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
@ -79,10 +79,10 @@ class InputBooleanStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config."""
return info[CONF_NAME]
async def _update_data(self, data: dict, update_data: dict) -> dict:
async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
update_data = self.CREATE_UPDATE_SCHEMA(update_data)
return {CONF_ID: data[CONF_ID]} | update_data
return {CONF_ID: item[CONF_ID]} | update_data
@bind_hass

View File

@ -56,7 +56,7 @@ STORAGE_KEY = DOMAIN
STORAGE_VERSION = 1
class InputButtonStorageCollection(collection.StorageCollection):
class InputButtonStorageCollection(collection.DictStorageCollection):
"""Input button collection stored in storage."""
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
@ -70,10 +70,10 @@ class InputButtonStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config."""
return cast(str, info[CONF_NAME])
async def _update_data(self, data: dict, update_data: dict) -> dict:
async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
update_data = self.CREATE_UPDATE_SCHEMA(update_data)
return {CONF_ID: data[CONF_ID]} | update_data
return {CONF_ID: item[CONF_ID]} | update_data
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:

View File

@ -203,7 +203,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
class DateTimeStorageCollection(collection.StorageCollection):
class DateTimeStorageCollection(collection.DictStorageCollection):
"""Input storage based collection."""
CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, has_date_or_time))
@ -217,10 +217,10 @@ class DateTimeStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config."""
return info[CONF_NAME]
async def _update_data(self, data: dict, update_data: dict) -> dict:
async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
update_data = self.CREATE_UPDATE_SCHEMA(update_data)
return {CONF_ID: data[CONF_ID]} | update_data
return {CONF_ID: item[CONF_ID]} | update_data
class InputDatetime(collection.CollectionEntity, RestoreEntity):

View File

@ -170,7 +170,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
class NumberStorageCollection(collection.StorageCollection):
class NumberStorageCollection(collection.DictStorageCollection):
"""Input storage based collection."""
SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_number))
@ -184,7 +184,7 @@ class NumberStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config."""
return info[CONF_NAME]
async def _async_load_data(self) -> dict | None:
async def _async_load_data(self) -> collection.SerializedStorageCollection | None:
"""Load the data.
A past bug caused frontend to add initial value to all input numbers.
@ -200,10 +200,10 @@ class NumberStorageCollection(collection.StorageCollection):
return data
async def _update_data(self, data: dict, update_data: dict) -> dict:
async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
update_data = self.SCHEMA(update_data)
return {CONF_ID: data[CONF_ID]} | update_data
return {CONF_ID: item[CONF_ID]} | update_data
class InputNumber(collection.CollectionEntity, RestoreEntity):

View File

@ -231,7 +231,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
class InputSelectStorageCollection(collection.StorageCollection):
class InputSelectStorageCollection(collection.DictStorageCollection):
"""Input storage based collection."""
CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_select))
@ -246,11 +246,11 @@ class InputSelectStorageCollection(collection.StorageCollection):
return cast(str, info[CONF_NAME])
async def _update_data(
self, data: dict[str, Any], update_data: dict[str, Any]
self, item: dict[str, Any], update_data: dict[str, Any]
) -> dict[str, Any]:
"""Return a new updated data object."""
update_data = self.CREATE_UPDATE_SCHEMA(update_data)
return {CONF_ID: data[CONF_ID]} | update_data
return {CONF_ID: item[CONF_ID]} | update_data
class InputSelect(collection.CollectionEntity, SelectEntity, RestoreEntity):

View File

@ -164,7 +164,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
class InputTextStorageCollection(collection.StorageCollection):
class InputTextStorageCollection(collection.DictStorageCollection):
"""Input storage based collection."""
CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_text))
@ -178,10 +178,10 @@ class InputTextStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config."""
return info[CONF_NAME]
async def _update_data(self, data: dict, update_data: dict) -> dict:
async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
update_data = self.CREATE_UPDATE_SCHEMA(update_data)
return {CONF_ID: data[CONF_ID]} | update_data
return {CONF_ID: item[CONF_ID]} | update_data
class InputText(collection.CollectionEntity, RestoreEntity):

View File

@ -6,7 +6,6 @@ import logging
import os
from pathlib import Path
import time
from typing import cast
import voluptuous as vol
@ -218,7 +217,7 @@ def _config_info(mode, config):
}
class DashboardsCollection(collection.StorageCollection):
class DashboardsCollection(collection.DictStorageCollection):
"""Collection of dashboards."""
CREATE_SCHEMA = vol.Schema(STORAGE_DASHBOARD_CREATE_FIELDS)
@ -230,10 +229,10 @@ class DashboardsCollection(collection.StorageCollection):
storage.Store(hass, DASHBOARDS_STORAGE_VERSION, DASHBOARDS_STORAGE_KEY),
)
async def _async_load_data(self) -> dict | None:
async def _async_load_data(self) -> collection.SerializedStorageCollection | None:
"""Load the data."""
if (data := await self.store.async_load()) is None:
return cast(dict | None, data)
return data
updated = False
@ -245,7 +244,7 @@ class DashboardsCollection(collection.StorageCollection):
if updated:
await self.store.async_save(data)
return cast(dict | None, data)
return data
async def _process_create_data(self, data: dict) -> dict:
"""Validate the config is valid."""
@ -262,10 +261,10 @@ class DashboardsCollection(collection.StorageCollection):
"""Suggest an ID based on the config."""
return info[CONF_URL_PATH]
async def _update_data(self, data: dict, update_data: dict) -> dict:
async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
update_data = self.UPDATE_SCHEMA(update_data)
updated = {**data, **update_data}
updated = {**item, **update_data}
if CONF_ICON in updated and updated[CONF_ICON] is None:
updated.pop(CONF_ICON)

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import logging
from typing import cast
from typing import Any
import uuid
import voluptuous as vol
@ -45,7 +45,7 @@ class ResourceYAMLCollection:
return self.data
class ResourceStorageCollection(collection.StorageCollection):
class ResourceStorageCollection(collection.DictStorageCollection):
"""Collection to store resources."""
loaded = False
@ -67,10 +67,10 @@ class ResourceStorageCollection(collection.StorageCollection):
return {"resources": len(self.async_items() or [])}
async def _async_load_data(self) -> dict | None:
async def _async_load_data(self) -> collection.SerializedStorageCollection | None:
"""Load the data."""
if (data := await self.store.async_load()) is not None:
return cast(dict | None, data)
if (store_data := await self.store.async_load()) is not None:
return store_data
# Import it from config.
try:
@ -82,20 +82,20 @@ class ResourceStorageCollection(collection.StorageCollection):
return None
# Remove it from config and save both resources + config
data = conf[CONF_RESOURCES]
resources: list[dict[str, Any]] = conf[CONF_RESOURCES]
try:
vol.Schema([RESOURCE_SCHEMA])(data)
vol.Schema([RESOURCE_SCHEMA])(resources)
except vol.Invalid as err:
_LOGGER.warning("Resource import failed. Data invalid: %s", err)
return None
conf.pop(CONF_RESOURCES)
for item in data:
for item in resources:
item[CONF_ID] = uuid.uuid4().hex
data = {"items": data}
data: collection.SerializedStorageCollection = {"items": resources}
await self.store.async_save(data)
await self.ll_config.async_save(conf)
@ -113,7 +113,7 @@ class ResourceStorageCollection(collection.StorageCollection):
"""Return unique ID."""
return uuid.uuid4().hex
async def _update_data(self, data: dict, update_data: dict) -> dict:
async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
if not self.loaded:
await self.async_load()
@ -123,4 +123,4 @@ class ResourceStorageCollection(collection.StorageCollection):
if CONF_RESOURCE_TYPE_WS in update_data:
update_data[CONF_TYPE] = update_data.pop(CONF_RESOURCE_TYPE_WS)
return {**data, **update_data}
return {**item, **update_data}

View File

@ -188,7 +188,7 @@ class PersonStore(Store):
return {"items": old_data["persons"]}
class PersonStorageCollection(collection.StorageCollection):
class PersonStorageCollection(collection.DictStorageCollection):
"""Person collection stored in storage."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)
@ -204,7 +204,7 @@ class PersonStorageCollection(collection.StorageCollection):
super().__init__(store, id_manager)
self.yaml_collection = yaml_collection
async def _async_load_data(self) -> dict | None:
async def _async_load_data(self) -> collection.SerializedStorageCollection | None:
"""Load the data.
A past bug caused onboarding to create invalid person objects.
@ -270,16 +270,16 @@ class PersonStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config."""
return info[CONF_NAME]
async def _update_data(self, data: dict, update_data: dict) -> dict:
async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
update_data = self.UPDATE_SCHEMA(update_data)
user_id = update_data.get(CONF_USER_ID)
if user_id is not None and user_id != data.get(CONF_USER_ID):
if user_id is not None and user_id != item.get(CONF_USER_ID):
await self._validate_user_id(user_id)
return {**data, **update_data}
return {**item, **update_data}
async def _validate_user_id(self, user_id):
"""Validate the used user_id."""

View File

@ -20,8 +20,9 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.helpers.collection import (
CollectionEntity,
DictStorageCollection,
IDManager,
StorageCollection,
SerializedStorageCollection,
StorageCollectionWebsocket,
YamlCollection,
sync_entity_lifecycle,
@ -208,7 +209,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
class ScheduleStorageCollection(StorageCollection):
class ScheduleStorageCollection(DictStorageCollection):
"""Schedules stored in storage."""
SCHEMA = vol.Schema(BASE_SCHEMA | STORAGE_SCHEDULE_SCHEMA)
@ -224,12 +225,12 @@ class ScheduleStorageCollection(StorageCollection):
name: str = info[CONF_NAME]
return name
async def _update_data(self, data: dict, update_data: dict) -> dict:
async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
self.SCHEMA(update_data)
return data | update_data
return item | update_data
async def _async_load_data(self) -> dict | None:
async def _async_load_data(self) -> SerializedStorageCollection | None:
"""Load the data."""
if data := await super()._async_load_data():
data["items"] = [STORAGE_SCHEMA(item) for item in data["items"]]

View File

@ -59,7 +59,7 @@ class TagIDManager(collection.IDManager):
return suggestion
class TagStorageCollection(collection.StorageCollection):
class TagStorageCollection(collection.DictStorageCollection):
"""Tag collection stored in storage."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)
@ -80,9 +80,9 @@ class TagStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config."""
return info[TAG_ID]
async def _update_data(self, data: dict, update_data: dict) -> dict:
async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
data = {**data, **self.UPDATE_SCHEMA(update_data)}
data = {**item, **self.UPDATE_SCHEMA(update_data)}
# make last_scanned JSON serializeable
if LAST_SCANNED in update_data:
data[LAST_SCANNED] = data[LAST_SCANNED].isoformat()

View File

@ -162,7 +162,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
class TimerStorageCollection(collection.StorageCollection):
class TimerStorageCollection(collection.DictStorageCollection):
"""Timer storage based collection."""
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
@ -179,9 +179,9 @@ class TimerStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config."""
return info[CONF_NAME]
async def _update_data(self, data: dict, update_data: dict) -> dict:
async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
data = {CONF_ID: data[CONF_ID]} | self.CREATE_UPDATE_SCHEMA(update_data)
data = {CONF_ID: item[CONF_ID]} | self.CREATE_UPDATE_SCHEMA(update_data)
# make duration JSON serializeable
if CONF_DURATION in update_data:
data[CONF_DURATION] = _format_timedelta(data[CONF_DURATION])

View File

@ -163,7 +163,7 @@ def in_zone(zone: State, latitude: float, longitude: float, radius: float = 0) -
return zone_dist - radius < cast(float, zone.attributes[ATTR_RADIUS])
class ZoneStorageCollection(collection.StorageCollection):
class ZoneStorageCollection(collection.DictStorageCollection):
"""Zone collection stored in storage."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)
@ -178,10 +178,10 @@ class ZoneStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config."""
return cast(str, info[CONF_NAME])
async def _update_data(self, data: dict, update_data: dict) -> dict:
async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
update_data = self.UPDATE_SCHEMA(update_data)
return {**data, **update_data}
return {**item, **update_data}
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:

View File

@ -8,7 +8,7 @@ from dataclasses import dataclass
from itertools import groupby
import logging
from operator import attrgetter
from typing import Any, cast
from typing import Any, Generic, TypedDict, TypeVar
import voluptuous as vol
from voluptuous.humanize import humanize_error
@ -32,6 +32,8 @@ CHANGE_ADDED = "added"
CHANGE_UPDATED = "updated"
CHANGE_REMOVED = "removed"
_T = TypeVar("_T")
@dataclass
class CollectionChangeSet:
@ -121,20 +123,20 @@ class CollectionEntity(Entity):
"""Handle updated configuration."""
class ObservableCollection(ABC):
class ObservableCollection(ABC, Generic[_T]):
"""Base collection type that can be observed."""
def __init__(self, id_manager: IDManager | None) -> None:
"""Initialize the base collection."""
self.id_manager = id_manager or IDManager()
self.data: dict[str, dict] = {}
self.data: dict[str, _T] = {}
self.listeners: list[ChangeListener] = []
self.change_set_listeners: list[ChangeSetListener] = []
self.id_manager.add_collection(self.data)
@callback
def async_items(self) -> list[dict]:
def async_items(self) -> list[_T]:
"""Return list of items in collection."""
return list(self.data.values())
@ -169,7 +171,7 @@ class ObservableCollection(ABC):
)
class YamlCollection(ObservableCollection):
class YamlCollection(ObservableCollection[dict]):
"""Offer a collection based on static data."""
def __init__(
@ -218,12 +220,18 @@ class YamlCollection(ObservableCollection):
await self.notify_changes(change_sets)
class StorageCollection(ObservableCollection, ABC):
class SerializedStorageCollection(TypedDict):
"""Serialized storage collection."""
items: list[dict[str, Any]]
class StorageCollection(ObservableCollection[_T], ABC):
"""Offer a CRUD interface on top of JSON storage."""
def __init__(
self,
store: Store,
store: Store[SerializedStorageCollection],
id_manager: IDManager | None = None,
) -> None:
"""Initialize the storage collection."""
@ -242,9 +250,9 @@ class StorageCollection(ObservableCollection, ABC):
"""Home Assistant object."""
return self.store.hass
async def _async_load_data(self) -> dict | None:
async def _async_load_data(self) -> SerializedStorageCollection | None:
"""Load the data."""
return cast(dict | None, await self.store.async_load())
return await self.store.async_load()
async def async_load(self) -> None:
"""Load the storage Manager."""
@ -254,7 +262,7 @@ class StorageCollection(ObservableCollection, ABC):
raw_storage = {"items": []}
for item in raw_storage["items"]:
self.data[item[CONF_ID]] = item
self.data[item[CONF_ID]] = self._deserialize_item(item)
await self.notify_changes(
[
@ -273,21 +281,35 @@ class StorageCollection(ObservableCollection, ABC):
"""Suggest an ID based on the config."""
@abstractmethod
async def _update_data(self, data: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
async def _update_data(self, item: _T, update_data: dict) -> _T:
"""Return a new updated item."""
async def async_create_item(self, data: dict) -> dict:
@abstractmethod
def _create_item(self, item_id: str, data: dict) -> _T:
"""Create an item from validated config."""
@abstractmethod
def _deserialize_item(self, data: dict) -> _T:
"""Create an item from its serialized representation."""
@abstractmethod
def _serialize_item(self, item_id: str, item: _T) -> dict:
"""Return the serialized representation of an item.
The serialized representation must include the item_id in the "id" key.
"""
async def async_create_item(self, data: dict) -> _T:
"""Create a new item."""
item = await self._process_create_data(data)
item[CONF_ID] = self.id_manager.generate_id(self._get_suggested_id(item))
self.data[item[CONF_ID]] = item
validated_data = await self._process_create_data(data)
item_id = self.id_manager.generate_id(self._get_suggested_id(validated_data))
item = self._create_item(item_id, validated_data)
self.data[item_id] = item
self._async_schedule_save()
await self.notify_changes(
[CollectionChangeSet(CHANGE_ADDED, item[CONF_ID], item)]
)
await self.notify_changes([CollectionChangeSet(CHANGE_ADDED, item_id, item)])
return item
async def async_update_item(self, item_id: str, updates: dict) -> dict:
async def async_update_item(self, item_id: str, updates: dict) -> _T:
"""Update item."""
if item_id not in self.data:
raise ItemNotFound(item_id)
@ -320,13 +342,34 @@ class StorageCollection(ObservableCollection, ABC):
@callback
def _async_schedule_save(self) -> None:
"""Schedule saving the area registry."""
"""Schedule saving the collection."""
self.store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback
def _data_to_save(self) -> dict:
"""Return data of area registry to store in a file."""
return {"items": list(self.data.values())}
def _data_to_save(self) -> SerializedStorageCollection:
"""Return JSON-compatible date for storing to file."""
return {
"items": [
self._serialize_item(item_id, item)
for item_id, item in self.data.items()
]
}
class DictStorageCollection(StorageCollection[dict]):
"""A specialized StorageCollection where the items are untyped dicts."""
def _create_item(self, item_id: str, data: dict) -> dict:
"""Create an item from its validated, serialized representation."""
return {CONF_ID: item_id} | data
def _deserialize_item(self, data: dict) -> dict:
"""Create an item from its validated, serialized representation."""
return data
def _serialize_item(self, item_id: str, item: dict) -> dict:
"""Return the serialized representation of an item."""
return item
class IDLessCollection(YamlCollection):

View File

@ -82,7 +82,7 @@ class MockObservableCollection(collection.ObservableCollection):
return entity_class.from_storage(config)
class MockStorageCollection(collection.StorageCollection):
class MockStorageCollection(collection.DictStorageCollection):
"""Mock storage collection."""
async def _process_create_data(self, data: dict) -> dict:
@ -96,9 +96,9 @@ class MockStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config."""
return info["name"]
async def _update_data(self, data: dict, update_data: dict) -> dict:
async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
return {**data, **update_data}
return {**item, **update_data}
def test_id_manager() -> None: