diff --git a/homeassistant/helpers/collection.py b/homeassistant/helpers/collection.py index 9e7c6061987..6733b1d3dbd 100644 --- a/homeassistant/helpers/collection.py +++ b/homeassistant/helpers/collection.py @@ -1,8 +1,9 @@ """Helper to deal with YAML + storage.""" from abc import ABC, abstractmethod import asyncio +from dataclasses import dataclass import logging -from typing import Any, Awaitable, Callable, Dict, List, Optional, cast +from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, cast import voluptuous as vol from voluptuous.humanize import humanize_error @@ -26,6 +27,20 @@ CHANGE_UPDATED = "updated" CHANGE_REMOVED = "removed" +@dataclass +class CollectionChangeSet: + """Class to represent a change set. + + change_type: One of CHANGE_* + item_id: The id of the item + item: The item + """ + + change_type: str + item_id: str + item: Any + + ChangeListener = Callable[ [ # Change type @@ -105,11 +120,14 @@ class ObservableCollection(ABC): """ self.listeners.append(listener) - async def notify_change(self, change_type: str, item_id: str, item: dict) -> None: + async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None: """Notify listeners of a change.""" - self.logger.debug("%s %s: %s", change_type, item_id, item) await asyncio.gather( - *[listener(change_type, item_id, item) for listener in self.listeners] + *[ + listener(change_set.change_type, change_set.item_id, change_set.item) + for listener in self.listeners + for change_set in change_sets + ] ) @@ -118,9 +136,10 @@ class YamlCollection(ObservableCollection): async def async_load(self, data: List[dict]) -> None: """Load the YAML collection. Overrides existing data.""" + old_ids = set(self.data) - tasks = [] + change_sets = [] for item in data: item_id = item[CONF_ID] @@ -135,15 +154,15 @@ class YamlCollection(ObservableCollection): event = CHANGE_ADDED self.data[item_id] = item - tasks.append(self.notify_change(event, item_id, item)) + change_sets.append(CollectionChangeSet(event, item_id, item)) for item_id in old_ids: - tasks.append( - self.notify_change(CHANGE_REMOVED, item_id, self.data.pop(item_id)) + change_sets.append( + CollectionChangeSet(CHANGE_REMOVED, item_id, self.data.pop(item_id)) ) - if tasks: - await asyncio.gather(*tasks) + if change_sets: + await self.notify_changes(change_sets) class StorageCollection(ObservableCollection): @@ -178,9 +197,9 @@ class StorageCollection(ObservableCollection): for item in raw_storage["items"]: self.data[item[CONF_ID]] = item - await asyncio.gather( - *[ - self.notify_change(CHANGE_ADDED, item[CONF_ID], item) + await self.notify_changes( + [ + CollectionChangeSet(CHANGE_ADDED, item[CONF_ID], item) for item in raw_storage["items"] ] ) @@ -204,7 +223,9 @@ class StorageCollection(ObservableCollection): item[CONF_ID] = self.id_manager.generate_id(self._get_suggested_id(item)) self.data[item[CONF_ID]] = item self._async_schedule_save() - await self.notify_change(CHANGE_ADDED, item[CONF_ID], item) + await self.notify_changes( + [CollectionChangeSet(CHANGE_ADDED, item[CONF_ID], item)] + ) return item async def async_update_item(self, item_id: str, updates: dict) -> dict: @@ -222,7 +243,9 @@ class StorageCollection(ObservableCollection): self.data[item_id] = updated self._async_schedule_save() - await self.notify_change(CHANGE_UPDATED, item_id, updated) + await self.notify_changes( + [CollectionChangeSet(CHANGE_UPDATED, item_id, updated)] + ) return self.data[item_id] @@ -234,7 +257,7 @@ class StorageCollection(ObservableCollection): item = self.data.pop(item_id) self._async_schedule_save() - await self.notify_change(CHANGE_REMOVED, item_id, item) + await self.notify_changes([CollectionChangeSet(CHANGE_REMOVED, item_id, item)]) @callback def _async_schedule_save(self) -> None: @@ -254,9 +277,9 @@ class IDLessCollection(ObservableCollection): async def async_load(self, data: List[dict]) -> None: """Load the collection. Overrides existing data.""" - await asyncio.gather( - *[ - self.notify_change(CHANGE_REMOVED, item_id, item) + await self.notify_changes( + [ + CollectionChangeSet(CHANGE_REMOVED, item_id, item) for item_id, item in list(self.data.items()) ] ) @@ -269,9 +292,9 @@ class IDLessCollection(ObservableCollection): self.data[item_id] = item - await asyncio.gather( - *[ - self.notify_change(CHANGE_ADDED, item_id, item) + await self.notify_changes( + [ + CollectionChangeSet(CHANGE_ADDED, item_id, item) for item_id, item in self.data.items() ] ) diff --git a/tests/helpers/test_collection.py b/tests/helpers/test_collection.py index 11f1534defb..d5a8526b6da 100644 --- a/tests/helpers/test_collection.py +++ b/tests/helpers/test_collection.py @@ -91,7 +91,9 @@ async def test_observable_collection(): assert coll.async_items() == [1] changes = track_changes(coll) - await coll.notify_change("mock_type", "mock_id", {"mock": "item"}) + await coll.notify_changes( + [collection.CollectionChangeSet("mock_type", "mock_id", {"mock": "item"})] + ) assert len(changes) == 1 assert changes[0] == ("mock_type", "mock_id", {"mock": "item"}) @@ -226,25 +228,35 @@ async def test_attach_entity_component_collection(hass): coll = collection.ObservableCollection(_LOGGER) collection.attach_entity_component_collection(ent_comp, coll, MockEntity) - await coll.notify_change( - collection.CHANGE_ADDED, - "mock_id", - {"id": "mock_id", "state": "initial", "name": "Mock 1"}, + await coll.notify_changes( + [ + collection.CollectionChangeSet( + collection.CHANGE_ADDED, + "mock_id", + {"id": "mock_id", "state": "initial", "name": "Mock 1"}, + ) + ], ) assert hass.states.get("test.mock_1").name == "Mock 1" assert hass.states.get("test.mock_1").state == "initial" - await coll.notify_change( - collection.CHANGE_UPDATED, - "mock_id", - {"id": "mock_id", "state": "second", "name": "Mock 1 updated"}, + await coll.notify_changes( + [ + collection.CollectionChangeSet( + collection.CHANGE_UPDATED, + "mock_id", + {"id": "mock_id", "state": "second", "name": "Mock 1 updated"}, + ) + ], ) assert hass.states.get("test.mock_1").name == "Mock 1 updated" assert hass.states.get("test.mock_1").state == "second" - await coll.notify_change(collection.CHANGE_REMOVED, "mock_id", None) + await coll.notify_changes( + [collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)], + ) assert hass.states.get("test.mock_1") is None