Gather all collection listeners and changes at the same time (#42497)

This commit is contained in:
J. Nick Koston 2020-10-29 04:06:55 -05:00 committed by GitHub
parent 6b29648cfc
commit c8f00a7b38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 67 additions and 32 deletions

View File

@ -1,8 +1,9 @@
"""Helper to deal with YAML + storage.""" """Helper to deal with YAML + storage."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio import asyncio
from dataclasses import dataclass
import logging import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional, cast from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, cast
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
@ -26,6 +27,20 @@ CHANGE_UPDATED = "updated"
CHANGE_REMOVED = "removed" CHANGE_REMOVED = "removed"
@dataclass
class CollectionChangeSet:
"""Class to represent a change set.
change_type: One of CHANGE_*
item_id: The id of the item
item: The item
"""
change_type: str
item_id: str
item: Any
ChangeListener = Callable[ ChangeListener = Callable[
[ [
# Change type # Change type
@ -105,11 +120,14 @@ class ObservableCollection(ABC):
""" """
self.listeners.append(listener) self.listeners.append(listener)
async def notify_change(self, change_type: str, item_id: str, item: dict) -> None: async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None:
"""Notify listeners of a change.""" """Notify listeners of a change."""
self.logger.debug("%s %s: %s", change_type, item_id, item)
await asyncio.gather( await asyncio.gather(
*[listener(change_type, item_id, item) for listener in self.listeners] *[
listener(change_set.change_type, change_set.item_id, change_set.item)
for listener in self.listeners
for change_set in change_sets
]
) )
@ -118,9 +136,10 @@ class YamlCollection(ObservableCollection):
async def async_load(self, data: List[dict]) -> None: async def async_load(self, data: List[dict]) -> None:
"""Load the YAML collection. Overrides existing data.""" """Load the YAML collection. Overrides existing data."""
old_ids = set(self.data) old_ids = set(self.data)
tasks = [] change_sets = []
for item in data: for item in data:
item_id = item[CONF_ID] item_id = item[CONF_ID]
@ -135,15 +154,15 @@ class YamlCollection(ObservableCollection):
event = CHANGE_ADDED event = CHANGE_ADDED
self.data[item_id] = item self.data[item_id] = item
tasks.append(self.notify_change(event, item_id, item)) change_sets.append(CollectionChangeSet(event, item_id, item))
for item_id in old_ids: for item_id in old_ids:
tasks.append( change_sets.append(
self.notify_change(CHANGE_REMOVED, item_id, self.data.pop(item_id)) CollectionChangeSet(CHANGE_REMOVED, item_id, self.data.pop(item_id))
) )
if tasks: if change_sets:
await asyncio.gather(*tasks) await self.notify_changes(change_sets)
class StorageCollection(ObservableCollection): class StorageCollection(ObservableCollection):
@ -178,9 +197,9 @@ class StorageCollection(ObservableCollection):
for item in raw_storage["items"]: for item in raw_storage["items"]:
self.data[item[CONF_ID]] = item self.data[item[CONF_ID]] = item
await asyncio.gather( await self.notify_changes(
*[ [
self.notify_change(CHANGE_ADDED, item[CONF_ID], item) CollectionChangeSet(CHANGE_ADDED, item[CONF_ID], item)
for item in raw_storage["items"] for item in raw_storage["items"]
] ]
) )
@ -204,7 +223,9 @@ class StorageCollection(ObservableCollection):
item[CONF_ID] = self.id_manager.generate_id(self._get_suggested_id(item)) item[CONF_ID] = self.id_manager.generate_id(self._get_suggested_id(item))
self.data[item[CONF_ID]] = item self.data[item[CONF_ID]] = item
self._async_schedule_save() self._async_schedule_save()
await self.notify_change(CHANGE_ADDED, item[CONF_ID], item) await self.notify_changes(
[CollectionChangeSet(CHANGE_ADDED, item[CONF_ID], item)]
)
return item return item
async def async_update_item(self, item_id: str, updates: dict) -> dict: async def async_update_item(self, item_id: str, updates: dict) -> dict:
@ -222,7 +243,9 @@ class StorageCollection(ObservableCollection):
self.data[item_id] = updated self.data[item_id] = updated
self._async_schedule_save() self._async_schedule_save()
await self.notify_change(CHANGE_UPDATED, item_id, updated) await self.notify_changes(
[CollectionChangeSet(CHANGE_UPDATED, item_id, updated)]
)
return self.data[item_id] return self.data[item_id]
@ -234,7 +257,7 @@ class StorageCollection(ObservableCollection):
item = self.data.pop(item_id) item = self.data.pop(item_id)
self._async_schedule_save() self._async_schedule_save()
await self.notify_change(CHANGE_REMOVED, item_id, item) await self.notify_changes([CollectionChangeSet(CHANGE_REMOVED, item_id, item)])
@callback @callback
def _async_schedule_save(self) -> None: def _async_schedule_save(self) -> None:
@ -254,9 +277,9 @@ class IDLessCollection(ObservableCollection):
async def async_load(self, data: List[dict]) -> None: async def async_load(self, data: List[dict]) -> None:
"""Load the collection. Overrides existing data.""" """Load the collection. Overrides existing data."""
await asyncio.gather( await self.notify_changes(
*[ [
self.notify_change(CHANGE_REMOVED, item_id, item) CollectionChangeSet(CHANGE_REMOVED, item_id, item)
for item_id, item in list(self.data.items()) for item_id, item in list(self.data.items())
] ]
) )
@ -269,9 +292,9 @@ class IDLessCollection(ObservableCollection):
self.data[item_id] = item self.data[item_id] = item
await asyncio.gather( await self.notify_changes(
*[ [
self.notify_change(CHANGE_ADDED, item_id, item) CollectionChangeSet(CHANGE_ADDED, item_id, item)
for item_id, item in self.data.items() for item_id, item in self.data.items()
] ]
) )

View File

@ -91,7 +91,9 @@ async def test_observable_collection():
assert coll.async_items() == [1] assert coll.async_items() == [1]
changes = track_changes(coll) changes = track_changes(coll)
await coll.notify_change("mock_type", "mock_id", {"mock": "item"}) await coll.notify_changes(
[collection.CollectionChangeSet("mock_type", "mock_id", {"mock": "item"})]
)
assert len(changes) == 1 assert len(changes) == 1
assert changes[0] == ("mock_type", "mock_id", {"mock": "item"}) assert changes[0] == ("mock_type", "mock_id", {"mock": "item"})
@ -226,25 +228,35 @@ async def test_attach_entity_component_collection(hass):
coll = collection.ObservableCollection(_LOGGER) coll = collection.ObservableCollection(_LOGGER)
collection.attach_entity_component_collection(ent_comp, coll, MockEntity) collection.attach_entity_component_collection(ent_comp, coll, MockEntity)
await coll.notify_change( await coll.notify_changes(
collection.CHANGE_ADDED, [
"mock_id", collection.CollectionChangeSet(
{"id": "mock_id", "state": "initial", "name": "Mock 1"}, collection.CHANGE_ADDED,
"mock_id",
{"id": "mock_id", "state": "initial", "name": "Mock 1"},
)
],
) )
assert hass.states.get("test.mock_1").name == "Mock 1" assert hass.states.get("test.mock_1").name == "Mock 1"
assert hass.states.get("test.mock_1").state == "initial" assert hass.states.get("test.mock_1").state == "initial"
await coll.notify_change( await coll.notify_changes(
collection.CHANGE_UPDATED, [
"mock_id", collection.CollectionChangeSet(
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"}, collection.CHANGE_UPDATED,
"mock_id",
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"},
)
],
) )
assert hass.states.get("test.mock_1").name == "Mock 1 updated" assert hass.states.get("test.mock_1").name == "Mock 1 updated"
assert hass.states.get("test.mock_1").state == "second" assert hass.states.get("test.mock_1").state == "second"
await coll.notify_change(collection.CHANGE_REMOVED, "mock_id", None) await coll.notify_changes(
[collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)],
)
assert hass.states.get("test.mock_1") is None assert hass.states.get("test.mock_1") is None