Allow storing other items than untyped dict in StorageCollection (#90932)

Allow storing other items than untyped dict in StorageCollection
This commit is contained in:
Erik Montnemery 2023-04-06 16:57:00 +02:00 committed by GitHub
parent 8025fbf398
commit 59a02cd08c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 134 additions and 91 deletions

View File

@ -75,7 +75,7 @@ class AuthorizationServer:
token_url: str token_url: str
class ApplicationCredentialsStorageCollection(collection.StorageCollection): class ApplicationCredentialsStorageCollection(collection.DictStorageCollection):
"""Application credential collection stored in storage.""" """Application credential collection stored in storage."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS) CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)
@ -94,7 +94,7 @@ class ApplicationCredentialsStorageCollection(collection.StorageCollection):
return f"{info[CONF_DOMAIN]}.{info[CONF_CLIENT_ID]}" return f"{info[CONF_DOMAIN]}.{info[CONF_CLIENT_ID]}"
async def _update_data( async def _update_data(
self, data: dict[str, str], update_data: dict[str, str] self, item: dict[str, str], update_data: dict[str, str]
) -> dict[str, str]: ) -> dict[str, str]:
"""Return a new updated data object.""" """Return a new updated data object."""
raise ValueError("Updates not supported") raise ValueError("Updates not supported")

View File

@ -139,7 +139,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
class CounterStorageCollection(collection.StorageCollection): class CounterStorageCollection(collection.DictStorageCollection):
"""Input storage based collection.""" """Input storage based collection."""
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS) CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
@ -153,10 +153,10 @@ class CounterStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config.""" """Suggest an ID based on the config."""
return info[CONF_NAME] return info[CONF_NAME]
async def _update_data(self, data: dict, update_data: dict) -> dict: async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object.""" """Return a new updated data object."""
update_data = self.CREATE_UPDATE_SCHEMA(update_data) update_data = self.CREATE_UPDATE_SCHEMA(update_data)
return {CONF_ID: data[CONF_ID]} | update_data return {CONF_ID: item[CONF_ID]} | update_data
class Counter(collection.CollectionEntity, RestoreEntity): class Counter(collection.CollectionEntity, RestoreEntity):

View File

@ -57,7 +57,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
class ImageStorageCollection(collection.StorageCollection): class ImageStorageCollection(collection.DictStorageCollection):
"""Image collection stored in storage.""" """Image collection stored in storage."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS) CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)
@ -125,11 +125,11 @@ class ImageStorageCollection(collection.StorageCollection):
async def _update_data( async def _update_data(
self, self,
data: dict[str, Any], item: dict[str, Any],
update_data: dict[str, Any], update_data: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return a new updated data object.""" """Return a new updated data object."""
return {**data, **self.UPDATE_SCHEMA(update_data)} return {**item, **self.UPDATE_SCHEMA(update_data)}
async def _change_listener( async def _change_listener(
self, self,

View File

@ -65,7 +65,7 @@ STORAGE_KEY = DOMAIN
STORAGE_VERSION = 1 STORAGE_VERSION = 1
class InputBooleanStorageCollection(collection.StorageCollection): class InputBooleanStorageCollection(collection.DictStorageCollection):
"""Input boolean collection stored in storage.""" """Input boolean collection stored in storage."""
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS) CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
@ -79,10 +79,10 @@ class InputBooleanStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config.""" """Suggest an ID based on the config."""
return info[CONF_NAME] return info[CONF_NAME]
async def _update_data(self, data: dict, update_data: dict) -> dict: async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object.""" """Return a new updated data object."""
update_data = self.CREATE_UPDATE_SCHEMA(update_data) update_data = self.CREATE_UPDATE_SCHEMA(update_data)
return {CONF_ID: data[CONF_ID]} | update_data return {CONF_ID: item[CONF_ID]} | update_data
@bind_hass @bind_hass

View File

@ -56,7 +56,7 @@ STORAGE_KEY = DOMAIN
STORAGE_VERSION = 1 STORAGE_VERSION = 1
class InputButtonStorageCollection(collection.StorageCollection): class InputButtonStorageCollection(collection.DictStorageCollection):
"""Input button collection stored in storage.""" """Input button collection stored in storage."""
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS) CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
@ -70,10 +70,10 @@ class InputButtonStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config.""" """Suggest an ID based on the config."""
return cast(str, info[CONF_NAME]) return cast(str, info[CONF_NAME])
async def _update_data(self, data: dict, update_data: dict) -> dict: async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object.""" """Return a new updated data object."""
update_data = self.CREATE_UPDATE_SCHEMA(update_data) update_data = self.CREATE_UPDATE_SCHEMA(update_data)
return {CONF_ID: data[CONF_ID]} | update_data return {CONF_ID: item[CONF_ID]} | update_data
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:

View File

@ -203,7 +203,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
class DateTimeStorageCollection(collection.StorageCollection): class DateTimeStorageCollection(collection.DictStorageCollection):
"""Input storage based collection.""" """Input storage based collection."""
CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, has_date_or_time)) CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, has_date_or_time))
@ -217,10 +217,10 @@ class DateTimeStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config.""" """Suggest an ID based on the config."""
return info[CONF_NAME] return info[CONF_NAME]
async def _update_data(self, data: dict, update_data: dict) -> dict: async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object.""" """Return a new updated data object."""
update_data = self.CREATE_UPDATE_SCHEMA(update_data) update_data = self.CREATE_UPDATE_SCHEMA(update_data)
return {CONF_ID: data[CONF_ID]} | update_data return {CONF_ID: item[CONF_ID]} | update_data
class InputDatetime(collection.CollectionEntity, RestoreEntity): class InputDatetime(collection.CollectionEntity, RestoreEntity):

View File

@ -170,7 +170,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
class NumberStorageCollection(collection.StorageCollection): class NumberStorageCollection(collection.DictStorageCollection):
"""Input storage based collection.""" """Input storage based collection."""
SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_number)) SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_number))
@ -184,7 +184,7 @@ class NumberStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config.""" """Suggest an ID based on the config."""
return info[CONF_NAME] return info[CONF_NAME]
async def _async_load_data(self) -> dict | None: async def _async_load_data(self) -> collection.SerializedStorageCollection | None:
"""Load the data. """Load the data.
A past bug caused frontend to add initial value to all input numbers. A past bug caused frontend to add initial value to all input numbers.
@ -200,10 +200,10 @@ class NumberStorageCollection(collection.StorageCollection):
return data return data
async def _update_data(self, data: dict, update_data: dict) -> dict: async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object.""" """Return a new updated data object."""
update_data = self.SCHEMA(update_data) update_data = self.SCHEMA(update_data)
return {CONF_ID: data[CONF_ID]} | update_data return {CONF_ID: item[CONF_ID]} | update_data
class InputNumber(collection.CollectionEntity, RestoreEntity): class InputNumber(collection.CollectionEntity, RestoreEntity):

View File

@ -231,7 +231,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
class InputSelectStorageCollection(collection.StorageCollection): class InputSelectStorageCollection(collection.DictStorageCollection):
"""Input storage based collection.""" """Input storage based collection."""
CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_select)) CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_select))
@ -246,11 +246,11 @@ class InputSelectStorageCollection(collection.StorageCollection):
return cast(str, info[CONF_NAME]) return cast(str, info[CONF_NAME])
async def _update_data( async def _update_data(
self, data: dict[str, Any], update_data: dict[str, Any] self, item: dict[str, Any], update_data: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return a new updated data object.""" """Return a new updated data object."""
update_data = self.CREATE_UPDATE_SCHEMA(update_data) update_data = self.CREATE_UPDATE_SCHEMA(update_data)
return {CONF_ID: data[CONF_ID]} | update_data return {CONF_ID: item[CONF_ID]} | update_data
class InputSelect(collection.CollectionEntity, SelectEntity, RestoreEntity): class InputSelect(collection.CollectionEntity, SelectEntity, RestoreEntity):

View File

@ -164,7 +164,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
class InputTextStorageCollection(collection.StorageCollection): class InputTextStorageCollection(collection.DictStorageCollection):
"""Input storage based collection.""" """Input storage based collection."""
CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_text)) CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_text))
@ -178,10 +178,10 @@ class InputTextStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config.""" """Suggest an ID based on the config."""
return info[CONF_NAME] return info[CONF_NAME]
async def _update_data(self, data: dict, update_data: dict) -> dict: async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object.""" """Return a new updated data object."""
update_data = self.CREATE_UPDATE_SCHEMA(update_data) update_data = self.CREATE_UPDATE_SCHEMA(update_data)
return {CONF_ID: data[CONF_ID]} | update_data return {CONF_ID: item[CONF_ID]} | update_data
class InputText(collection.CollectionEntity, RestoreEntity): class InputText(collection.CollectionEntity, RestoreEntity):

View File

@ -6,7 +6,6 @@ import logging
import os import os
from pathlib import Path from pathlib import Path
import time import time
from typing import cast
import voluptuous as vol import voluptuous as vol
@ -218,7 +217,7 @@ def _config_info(mode, config):
} }
class DashboardsCollection(collection.StorageCollection): class DashboardsCollection(collection.DictStorageCollection):
"""Collection of dashboards.""" """Collection of dashboards."""
CREATE_SCHEMA = vol.Schema(STORAGE_DASHBOARD_CREATE_FIELDS) CREATE_SCHEMA = vol.Schema(STORAGE_DASHBOARD_CREATE_FIELDS)
@ -230,10 +229,10 @@ class DashboardsCollection(collection.StorageCollection):
storage.Store(hass, DASHBOARDS_STORAGE_VERSION, DASHBOARDS_STORAGE_KEY), storage.Store(hass, DASHBOARDS_STORAGE_VERSION, DASHBOARDS_STORAGE_KEY),
) )
async def _async_load_data(self) -> dict | None: async def _async_load_data(self) -> collection.SerializedStorageCollection | None:
"""Load the data.""" """Load the data."""
if (data := await self.store.async_load()) is None: if (data := await self.store.async_load()) is None:
return cast(dict | None, data) return data
updated = False updated = False
@ -245,7 +244,7 @@ class DashboardsCollection(collection.StorageCollection):
if updated: if updated:
await self.store.async_save(data) await self.store.async_save(data)
return cast(dict | None, data) return data
async def _process_create_data(self, data: dict) -> dict: async def _process_create_data(self, data: dict) -> dict:
"""Validate the config is valid.""" """Validate the config is valid."""
@ -262,10 +261,10 @@ class DashboardsCollection(collection.StorageCollection):
"""Suggest an ID based on the config.""" """Suggest an ID based on the config."""
return info[CONF_URL_PATH] return info[CONF_URL_PATH]
async def _update_data(self, data: dict, update_data: dict) -> dict: async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object.""" """Return a new updated data object."""
update_data = self.UPDATE_SCHEMA(update_data) update_data = self.UPDATE_SCHEMA(update_data)
updated = {**data, **update_data} updated = {**item, **update_data}
if CONF_ICON in updated and updated[CONF_ICON] is None: if CONF_ICON in updated and updated[CONF_ICON] is None:
updated.pop(CONF_ICON) updated.pop(CONF_ICON)

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import cast from typing import Any
import uuid import uuid
import voluptuous as vol import voluptuous as vol
@ -45,7 +45,7 @@ class ResourceYAMLCollection:
return self.data return self.data
class ResourceStorageCollection(collection.StorageCollection): class ResourceStorageCollection(collection.DictStorageCollection):
"""Collection to store resources.""" """Collection to store resources."""
loaded = False loaded = False
@ -67,10 +67,10 @@ class ResourceStorageCollection(collection.StorageCollection):
return {"resources": len(self.async_items() or [])} return {"resources": len(self.async_items() or [])}
async def _async_load_data(self) -> dict | None: async def _async_load_data(self) -> collection.SerializedStorageCollection | None:
"""Load the data.""" """Load the data."""
if (data := await self.store.async_load()) is not None: if (store_data := await self.store.async_load()) is not None:
return cast(dict | None, data) return store_data
# Import it from config. # Import it from config.
try: try:
@ -82,20 +82,20 @@ class ResourceStorageCollection(collection.StorageCollection):
return None return None
# Remove it from config and save both resources + config # Remove it from config and save both resources + config
data = conf[CONF_RESOURCES] resources: list[dict[str, Any]] = conf[CONF_RESOURCES]
try: try:
vol.Schema([RESOURCE_SCHEMA])(data) vol.Schema([RESOURCE_SCHEMA])(resources)
except vol.Invalid as err: except vol.Invalid as err:
_LOGGER.warning("Resource import failed. Data invalid: %s", err) _LOGGER.warning("Resource import failed. Data invalid: %s", err)
return None return None
conf.pop(CONF_RESOURCES) conf.pop(CONF_RESOURCES)
for item in data: for item in resources:
item[CONF_ID] = uuid.uuid4().hex item[CONF_ID] = uuid.uuid4().hex
data = {"items": data} data: collection.SerializedStorageCollection = {"items": resources}
await self.store.async_save(data) await self.store.async_save(data)
await self.ll_config.async_save(conf) await self.ll_config.async_save(conf)
@ -113,7 +113,7 @@ class ResourceStorageCollection(collection.StorageCollection):
"""Return unique ID.""" """Return unique ID."""
return uuid.uuid4().hex return uuid.uuid4().hex
async def _update_data(self, data: dict, update_data: dict) -> dict: async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object.""" """Return a new updated data object."""
if not self.loaded: if not self.loaded:
await self.async_load() await self.async_load()
@ -123,4 +123,4 @@ class ResourceStorageCollection(collection.StorageCollection):
if CONF_RESOURCE_TYPE_WS in update_data: if CONF_RESOURCE_TYPE_WS in update_data:
update_data[CONF_TYPE] = update_data.pop(CONF_RESOURCE_TYPE_WS) update_data[CONF_TYPE] = update_data.pop(CONF_RESOURCE_TYPE_WS)
return {**data, **update_data} return {**item, **update_data}

View File

@ -188,7 +188,7 @@ class PersonStore(Store):
return {"items": old_data["persons"]} return {"items": old_data["persons"]}
class PersonStorageCollection(collection.StorageCollection): class PersonStorageCollection(collection.DictStorageCollection):
"""Person collection stored in storage.""" """Person collection stored in storage."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS) CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)
@ -204,7 +204,7 @@ class PersonStorageCollection(collection.StorageCollection):
super().__init__(store, id_manager) super().__init__(store, id_manager)
self.yaml_collection = yaml_collection self.yaml_collection = yaml_collection
async def _async_load_data(self) -> dict | None: async def _async_load_data(self) -> collection.SerializedStorageCollection | None:
"""Load the data. """Load the data.
A past bug caused onboarding to create invalid person objects. A past bug caused onboarding to create invalid person objects.
@ -270,16 +270,16 @@ class PersonStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config.""" """Suggest an ID based on the config."""
return info[CONF_NAME] return info[CONF_NAME]
async def _update_data(self, data: dict, update_data: dict) -> dict: async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object.""" """Return a new updated data object."""
update_data = self.UPDATE_SCHEMA(update_data) update_data = self.UPDATE_SCHEMA(update_data)
user_id = update_data.get(CONF_USER_ID) user_id = update_data.get(CONF_USER_ID)
if user_id is not None and user_id != data.get(CONF_USER_ID): if user_id is not None and user_id != item.get(CONF_USER_ID):
await self._validate_user_id(user_id) await self._validate_user_id(user_id)
return {**data, **update_data} return {**item, **update_data}
async def _validate_user_id(self, user_id): async def _validate_user_id(self, user_id):
"""Validate the used user_id.""" """Validate the used user_id."""

View File

@ -20,8 +20,9 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.helpers.collection import ( from homeassistant.helpers.collection import (
CollectionEntity, CollectionEntity,
DictStorageCollection,
IDManager, IDManager,
StorageCollection, SerializedStorageCollection,
StorageCollectionWebsocket, StorageCollectionWebsocket,
YamlCollection, YamlCollection,
sync_entity_lifecycle, sync_entity_lifecycle,
@ -208,7 +209,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
class ScheduleStorageCollection(StorageCollection): class ScheduleStorageCollection(DictStorageCollection):
"""Schedules stored in storage.""" """Schedules stored in storage."""
SCHEMA = vol.Schema(BASE_SCHEMA | STORAGE_SCHEDULE_SCHEMA) SCHEMA = vol.Schema(BASE_SCHEMA | STORAGE_SCHEDULE_SCHEMA)
@ -224,12 +225,12 @@ class ScheduleStorageCollection(StorageCollection):
name: str = info[CONF_NAME] name: str = info[CONF_NAME]
return name return name
async def _update_data(self, data: dict, update_data: dict) -> dict: async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object.""" """Return a new updated data object."""
self.SCHEMA(update_data) self.SCHEMA(update_data)
return data | update_data return item | update_data
async def _async_load_data(self) -> dict | None: async def _async_load_data(self) -> SerializedStorageCollection | None:
"""Load the data.""" """Load the data."""
if data := await super()._async_load_data(): if data := await super()._async_load_data():
data["items"] = [STORAGE_SCHEMA(item) for item in data["items"]] data["items"] = [STORAGE_SCHEMA(item) for item in data["items"]]

View File

@ -59,7 +59,7 @@ class TagIDManager(collection.IDManager):
return suggestion return suggestion
class TagStorageCollection(collection.StorageCollection): class TagStorageCollection(collection.DictStorageCollection):
"""Tag collection stored in storage.""" """Tag collection stored in storage."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS) CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)
@ -80,9 +80,9 @@ class TagStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config.""" """Suggest an ID based on the config."""
return info[TAG_ID] return info[TAG_ID]
async def _update_data(self, data: dict, update_data: dict) -> dict: async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object.""" """Return a new updated data object."""
data = {**data, **self.UPDATE_SCHEMA(update_data)} data = {**item, **self.UPDATE_SCHEMA(update_data)}
# make last_scanned JSON serializeable # make last_scanned JSON serializeable
if LAST_SCANNED in update_data: if LAST_SCANNED in update_data:
data[LAST_SCANNED] = data[LAST_SCANNED].isoformat() data[LAST_SCANNED] = data[LAST_SCANNED].isoformat()

View File

@ -162,7 +162,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
class TimerStorageCollection(collection.StorageCollection): class TimerStorageCollection(collection.DictStorageCollection):
"""Timer storage based collection.""" """Timer storage based collection."""
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS) CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
@ -179,9 +179,9 @@ class TimerStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config.""" """Suggest an ID based on the config."""
return info[CONF_NAME] return info[CONF_NAME]
async def _update_data(self, data: dict, update_data: dict) -> dict: async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object.""" """Return a new updated data object."""
data = {CONF_ID: data[CONF_ID]} | self.CREATE_UPDATE_SCHEMA(update_data) data = {CONF_ID: item[CONF_ID]} | self.CREATE_UPDATE_SCHEMA(update_data)
# make duration JSON serializeable # make duration JSON serializeable
if CONF_DURATION in update_data: if CONF_DURATION in update_data:
data[CONF_DURATION] = _format_timedelta(data[CONF_DURATION]) data[CONF_DURATION] = _format_timedelta(data[CONF_DURATION])

View File

@ -163,7 +163,7 @@ def in_zone(zone: State, latitude: float, longitude: float, radius: float = 0) -
return zone_dist - radius < cast(float, zone.attributes[ATTR_RADIUS]) return zone_dist - radius < cast(float, zone.attributes[ATTR_RADIUS])
class ZoneStorageCollection(collection.StorageCollection): class ZoneStorageCollection(collection.DictStorageCollection):
"""Zone collection stored in storage.""" """Zone collection stored in storage."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS) CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)
@ -178,10 +178,10 @@ class ZoneStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config.""" """Suggest an ID based on the config."""
return cast(str, info[CONF_NAME]) return cast(str, info[CONF_NAME])
async def _update_data(self, data: dict, update_data: dict) -> dict: async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object.""" """Return a new updated data object."""
update_data = self.UPDATE_SCHEMA(update_data) update_data = self.UPDATE_SCHEMA(update_data)
return {**data, **update_data} return {**item, **update_data}
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:

View File

@ -8,7 +8,7 @@ from dataclasses import dataclass
from itertools import groupby from itertools import groupby
import logging import logging
from operator import attrgetter from operator import attrgetter
from typing import Any, cast from typing import Any, Generic, TypedDict, TypeVar
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
@ -32,6 +32,8 @@ CHANGE_ADDED = "added"
CHANGE_UPDATED = "updated" CHANGE_UPDATED = "updated"
CHANGE_REMOVED = "removed" CHANGE_REMOVED = "removed"
_T = TypeVar("_T")
@dataclass @dataclass
class CollectionChangeSet: class CollectionChangeSet:
@ -121,20 +123,20 @@ class CollectionEntity(Entity):
"""Handle updated configuration.""" """Handle updated configuration."""
class ObservableCollection(ABC): class ObservableCollection(ABC, Generic[_T]):
"""Base collection type that can be observed.""" """Base collection type that can be observed."""
def __init__(self, id_manager: IDManager | None) -> None: def __init__(self, id_manager: IDManager | None) -> None:
"""Initialize the base collection.""" """Initialize the base collection."""
self.id_manager = id_manager or IDManager() self.id_manager = id_manager or IDManager()
self.data: dict[str, dict] = {} self.data: dict[str, _T] = {}
self.listeners: list[ChangeListener] = [] self.listeners: list[ChangeListener] = []
self.change_set_listeners: list[ChangeSetListener] = [] self.change_set_listeners: list[ChangeSetListener] = []
self.id_manager.add_collection(self.data) self.id_manager.add_collection(self.data)
@callback @callback
def async_items(self) -> list[dict]: def async_items(self) -> list[_T]:
"""Return list of items in collection.""" """Return list of items in collection."""
return list(self.data.values()) return list(self.data.values())
@ -169,7 +171,7 @@ class ObservableCollection(ABC):
) )
class YamlCollection(ObservableCollection): class YamlCollection(ObservableCollection[dict]):
"""Offer a collection based on static data.""" """Offer a collection based on static data."""
def __init__( def __init__(
@ -218,12 +220,18 @@ class YamlCollection(ObservableCollection):
await self.notify_changes(change_sets) await self.notify_changes(change_sets)
class StorageCollection(ObservableCollection, ABC): class SerializedStorageCollection(TypedDict):
"""Serialized storage collection."""
items: list[dict[str, Any]]
class StorageCollection(ObservableCollection[_T], ABC):
"""Offer a CRUD interface on top of JSON storage.""" """Offer a CRUD interface on top of JSON storage."""
def __init__( def __init__(
self, self,
store: Store, store: Store[SerializedStorageCollection],
id_manager: IDManager | None = None, id_manager: IDManager | None = None,
) -> None: ) -> None:
"""Initialize the storage collection.""" """Initialize the storage collection."""
@ -242,9 +250,9 @@ class StorageCollection(ObservableCollection, ABC):
"""Home Assistant object.""" """Home Assistant object."""
return self.store.hass return self.store.hass
async def _async_load_data(self) -> dict | None: async def _async_load_data(self) -> SerializedStorageCollection | None:
"""Load the data.""" """Load the data."""
return cast(dict | None, await self.store.async_load()) return await self.store.async_load()
async def async_load(self) -> None: async def async_load(self) -> None:
"""Load the storage Manager.""" """Load the storage Manager."""
@ -254,7 +262,7 @@ class StorageCollection(ObservableCollection, ABC):
raw_storage = {"items": []} raw_storage = {"items": []}
for item in raw_storage["items"]: for item in raw_storage["items"]:
self.data[item[CONF_ID]] = item self.data[item[CONF_ID]] = self._deserialize_item(item)
await self.notify_changes( await self.notify_changes(
[ [
@ -273,21 +281,35 @@ class StorageCollection(ObservableCollection, ABC):
"""Suggest an ID based on the config.""" """Suggest an ID based on the config."""
@abstractmethod @abstractmethod
async def _update_data(self, data: dict, update_data: dict) -> dict: async def _update_data(self, item: _T, update_data: dict) -> _T:
"""Return a new updated data object.""" """Return a new updated item."""
async def async_create_item(self, data: dict) -> dict: @abstractmethod
def _create_item(self, item_id: str, data: dict) -> _T:
"""Create an item from validated config."""
@abstractmethod
def _deserialize_item(self, data: dict) -> _T:
"""Create an item from its serialized representation."""
@abstractmethod
def _serialize_item(self, item_id: str, item: _T) -> dict:
"""Return the serialized representation of an item.
The serialized representation must include the item_id in the "id" key.
"""
async def async_create_item(self, data: dict) -> _T:
"""Create a new item.""" """Create a new item."""
item = await self._process_create_data(data) validated_data = await self._process_create_data(data)
item[CONF_ID] = self.id_manager.generate_id(self._get_suggested_id(item)) item_id = self.id_manager.generate_id(self._get_suggested_id(validated_data))
self.data[item[CONF_ID]] = item item = self._create_item(item_id, validated_data)
self.data[item_id] = item
self._async_schedule_save() self._async_schedule_save()
await self.notify_changes( await self.notify_changes([CollectionChangeSet(CHANGE_ADDED, item_id, item)])
[CollectionChangeSet(CHANGE_ADDED, item[CONF_ID], item)]
)
return item return item
async def async_update_item(self, item_id: str, updates: dict) -> dict: async def async_update_item(self, item_id: str, updates: dict) -> _T:
"""Update item.""" """Update item."""
if item_id not in self.data: if item_id not in self.data:
raise ItemNotFound(item_id) raise ItemNotFound(item_id)
@ -320,13 +342,34 @@ class StorageCollection(ObservableCollection, ABC):
@callback @callback
def _async_schedule_save(self) -> None: def _async_schedule_save(self) -> None:
"""Schedule saving the area registry.""" """Schedule saving the collection."""
self.store.async_delay_save(self._data_to_save, SAVE_DELAY) self.store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback @callback
def _data_to_save(self) -> dict: def _data_to_save(self) -> SerializedStorageCollection:
"""Return data of area registry to store in a file.""" """Return JSON-compatible date for storing to file."""
return {"items": list(self.data.values())} return {
"items": [
self._serialize_item(item_id, item)
for item_id, item in self.data.items()
]
}
class DictStorageCollection(StorageCollection[dict]):
"""A specialized StorageCollection where the items are untyped dicts."""
def _create_item(self, item_id: str, data: dict) -> dict:
"""Create an item from its validated, serialized representation."""
return {CONF_ID: item_id} | data
def _deserialize_item(self, data: dict) -> dict:
"""Create an item from its validated, serialized representation."""
return data
def _serialize_item(self, item_id: str, item: dict) -> dict:
"""Return the serialized representation of an item."""
return item
class IDLessCollection(YamlCollection): class IDLessCollection(YamlCollection):

View File

@ -82,7 +82,7 @@ class MockObservableCollection(collection.ObservableCollection):
return entity_class.from_storage(config) return entity_class.from_storage(config)
class MockStorageCollection(collection.StorageCollection): class MockStorageCollection(collection.DictStorageCollection):
"""Mock storage collection.""" """Mock storage collection."""
async def _process_create_data(self, data: dict) -> dict: async def _process_create_data(self, data: dict) -> dict:
@ -96,9 +96,9 @@ class MockStorageCollection(collection.StorageCollection):
"""Suggest an ID based on the config.""" """Suggest an ID based on the config."""
return info["name"] return info["name"]
async def _update_data(self, data: dict, update_data: dict) -> dict: async def _update_data(self, item: dict, update_data: dict) -> dict:
"""Return a new updated data object.""" """Return a new updated data object."""
return {**data, **update_data} return {**item, **update_data}
def test_id_manager() -> None: def test_id_manager() -> None: