Gather all collection listeners and changes at the same time (#42497)

This commit is contained in:
J. Nick Koston 2020-10-29 04:06:55 -05:00 committed by GitHub
parent 6b29648cfc
commit c8f00a7b38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 67 additions and 32 deletions

View File

@ -1,8 +1,9 @@
"""Helper to deal with YAML + storage."""
from abc import ABC, abstractmethod
import asyncio
from dataclasses import dataclass
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional, cast
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, cast
import voluptuous as vol
from voluptuous.humanize import humanize_error
@ -26,6 +27,20 @@ CHANGE_UPDATED = "updated"
CHANGE_REMOVED = "removed"
@dataclass
class CollectionChangeSet:
"""Class to represent a change set.
change_type: One of CHANGE_*
item_id: The id of the item
item: The item
"""
change_type: str
item_id: str
item: Any
ChangeListener = Callable[
[
# Change type
@ -105,11 +120,14 @@ class ObservableCollection(ABC):
"""
self.listeners.append(listener)
async def notify_change(self, change_type: str, item_id: str, item: dict) -> None:
async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None:
"""Notify listeners of a change."""
self.logger.debug("%s %s: %s", change_type, item_id, item)
await asyncio.gather(
*[listener(change_type, item_id, item) for listener in self.listeners]
*[
listener(change_set.change_type, change_set.item_id, change_set.item)
for listener in self.listeners
for change_set in change_sets
]
)
@ -118,9 +136,10 @@ class YamlCollection(ObservableCollection):
async def async_load(self, data: List[dict]) -> None:
"""Load the YAML collection. Overrides existing data."""
old_ids = set(self.data)
tasks = []
change_sets = []
for item in data:
item_id = item[CONF_ID]
@ -135,15 +154,15 @@ class YamlCollection(ObservableCollection):
event = CHANGE_ADDED
self.data[item_id] = item
tasks.append(self.notify_change(event, item_id, item))
change_sets.append(CollectionChangeSet(event, item_id, item))
for item_id in old_ids:
tasks.append(
self.notify_change(CHANGE_REMOVED, item_id, self.data.pop(item_id))
change_sets.append(
CollectionChangeSet(CHANGE_REMOVED, item_id, self.data.pop(item_id))
)
if tasks:
await asyncio.gather(*tasks)
if change_sets:
await self.notify_changes(change_sets)
class StorageCollection(ObservableCollection):
@ -178,9 +197,9 @@ class StorageCollection(ObservableCollection):
for item in raw_storage["items"]:
self.data[item[CONF_ID]] = item
await asyncio.gather(
*[
self.notify_change(CHANGE_ADDED, item[CONF_ID], item)
await self.notify_changes(
[
CollectionChangeSet(CHANGE_ADDED, item[CONF_ID], item)
for item in raw_storage["items"]
]
)
@ -204,7 +223,9 @@ class StorageCollection(ObservableCollection):
item[CONF_ID] = self.id_manager.generate_id(self._get_suggested_id(item))
self.data[item[CONF_ID]] = item
self._async_schedule_save()
await self.notify_change(CHANGE_ADDED, item[CONF_ID], item)
await self.notify_changes(
[CollectionChangeSet(CHANGE_ADDED, item[CONF_ID], item)]
)
return item
async def async_update_item(self, item_id: str, updates: dict) -> dict:
@ -222,7 +243,9 @@ class StorageCollection(ObservableCollection):
self.data[item_id] = updated
self._async_schedule_save()
await self.notify_change(CHANGE_UPDATED, item_id, updated)
await self.notify_changes(
[CollectionChangeSet(CHANGE_UPDATED, item_id, updated)]
)
return self.data[item_id]
@ -234,7 +257,7 @@ class StorageCollection(ObservableCollection):
item = self.data.pop(item_id)
self._async_schedule_save()
await self.notify_change(CHANGE_REMOVED, item_id, item)
await self.notify_changes([CollectionChangeSet(CHANGE_REMOVED, item_id, item)])
@callback
def _async_schedule_save(self) -> None:
@ -254,9 +277,9 @@ class IDLessCollection(ObservableCollection):
async def async_load(self, data: List[dict]) -> None:
"""Load the collection. Overrides existing data."""
await asyncio.gather(
*[
self.notify_change(CHANGE_REMOVED, item_id, item)
await self.notify_changes(
[
CollectionChangeSet(CHANGE_REMOVED, item_id, item)
for item_id, item in list(self.data.items())
]
)
@ -269,9 +292,9 @@ class IDLessCollection(ObservableCollection):
self.data[item_id] = item
await asyncio.gather(
*[
self.notify_change(CHANGE_ADDED, item_id, item)
await self.notify_changes(
[
CollectionChangeSet(CHANGE_ADDED, item_id, item)
for item_id, item in self.data.items()
]
)

View File

@ -91,7 +91,9 @@ async def test_observable_collection():
assert coll.async_items() == [1]
changes = track_changes(coll)
await coll.notify_change("mock_type", "mock_id", {"mock": "item"})
await coll.notify_changes(
[collection.CollectionChangeSet("mock_type", "mock_id", {"mock": "item"})]
)
assert len(changes) == 1
assert changes[0] == ("mock_type", "mock_id", {"mock": "item"})
@ -226,25 +228,35 @@ async def test_attach_entity_component_collection(hass):
coll = collection.ObservableCollection(_LOGGER)
collection.attach_entity_component_collection(ent_comp, coll, MockEntity)
await coll.notify_change(
collection.CHANGE_ADDED,
"mock_id",
{"id": "mock_id", "state": "initial", "name": "Mock 1"},
await coll.notify_changes(
[
collection.CollectionChangeSet(
collection.CHANGE_ADDED,
"mock_id",
{"id": "mock_id", "state": "initial", "name": "Mock 1"},
)
],
)
assert hass.states.get("test.mock_1").name == "Mock 1"
assert hass.states.get("test.mock_1").state == "initial"
await coll.notify_change(
collection.CHANGE_UPDATED,
"mock_id",
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"},
await coll.notify_changes(
[
collection.CollectionChangeSet(
collection.CHANGE_UPDATED,
"mock_id",
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"},
)
],
)
assert hass.states.get("test.mock_1").name == "Mock 1 updated"
assert hass.states.get("test.mock_1").state == "second"
await coll.notify_change(collection.CHANGE_REMOVED, "mock_id", None)
await coll.notify_changes(
[collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)],
)
assert hass.states.get("test.mock_1") is None