Optimize storage collection entity operations with asyncio.gather (#48352)

This commit is contained in:
J. Nick Koston 2021-04-03 23:35:33 -10:00 committed by GitHub
parent c1e788e665
commit 3bc583607f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,8 +4,9 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio import asyncio
from dataclasses import dataclass from dataclasses import dataclass
from itertools import groupby
import logging import logging
from typing import Any, Awaitable, Callable, Iterable, Optional, cast from typing import Any, Awaitable, Callable, Coroutine, Iterable, Optional, cast
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
@ -54,6 +55,8 @@ ChangeListener = Callable[
Awaitable[None], Awaitable[None],
] ]
ChangeSetListener = Callable[[Iterable[CollectionChangeSet]], Awaitable[None]]
class CollectionError(HomeAssistantError): class CollectionError(HomeAssistantError):
"""Base class for collection related errors.""" """Base class for collection related errors."""
@ -105,6 +108,7 @@ class ObservableCollection(ABC):
self.id_manager = id_manager or IDManager() self.id_manager = id_manager or IDManager()
self.data: dict[str, dict] = {} self.data: dict[str, dict] = {}
self.listeners: list[ChangeListener] = [] self.listeners: list[ChangeListener] = []
self.change_set_listeners: list[ChangeSetListener] = []
self.id_manager.add_collection(self.data) self.id_manager.add_collection(self.data)
@ -121,6 +125,14 @@ class ObservableCollection(ABC):
""" """
self.listeners.append(listener) self.listeners.append(listener)
@callback
def async_add_change_set_listener(self, listener: ChangeSetListener) -> None:
"""Add a listener for a full change set.
Will be called with [(change_type, item_id, updated_config), ...]
"""
self.change_set_listeners.append(listener)
async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None: async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None:
"""Notify listeners of a change.""" """Notify listeners of a change."""
await asyncio.gather( await asyncio.gather(
@ -128,7 +140,11 @@ class ObservableCollection(ABC):
listener(change_set.change_type, change_set.item_id, change_set.item) listener(change_set.change_type, change_set.item_id, change_set.item)
for listener in self.listeners for listener in self.listeners
for change_set in change_sets for change_set in change_sets
] ],
*[
change_set_listener(change_sets)
for change_set_listener in self.change_set_listeners
],
) )
@ -311,29 +327,55 @@ def sync_entity_lifecycle(
) -> None: ) -> None:
"""Map a collection to an entity component.""" """Map a collection to an entity component."""
entities = {} entities = {}
ent_reg = entity_registry.async_get(hass)
async def _collection_changed(change_type: str, item_id: str, config: dict) -> None: async def _add_entity(change_set: CollectionChangeSet) -> Entity:
"""Handle a collection change.""" entities[change_set.item_id] = create_entity(change_set.item)
if change_type == CHANGE_ADDED: return entities[change_set.item_id]
entity = create_entity(config)
await entity_component.async_add_entities([entity])
entities[item_id] = entity
return
if change_type == CHANGE_REMOVED: async def _remove_entity(change_set: CollectionChangeSet) -> None:
ent_reg = await entity_registry.async_get_registry(hass) ent_to_remove = ent_reg.async_get_entity_id(
ent_to_remove = ent_reg.async_get_entity_id(domain, platform, item_id) domain, platform, change_set.item_id
)
if ent_to_remove is not None: if ent_to_remove is not None:
ent_reg.async_remove(ent_to_remove) ent_reg.async_remove(ent_to_remove)
else: else:
await entities[item_id].async_remove(force_remove=True) await entities[change_set.item_id].async_remove(force_remove=True)
entities.pop(item_id) entities.pop(change_set.item_id)
return
# CHANGE_UPDATED async def _update_entity(change_set: CollectionChangeSet) -> None:
await entities[item_id].async_update_config(config) # type: ignore await entities[change_set.item_id].async_update_config(change_set.item) # type: ignore
collection.async_add_listener(_collection_changed) _func_map: dict[
str, Callable[[CollectionChangeSet], Coroutine[Any, Any, Entity | None]]
] = {
CHANGE_ADDED: _add_entity,
CHANGE_REMOVED: _remove_entity,
CHANGE_UPDATED: _update_entity,
}
async def _collection_changed(change_sets: Iterable[CollectionChangeSet]) -> None:
"""Handle a collection change."""
# Create a new bucket every time we have a different change type
# to ensure operations happen in order. We only group
# the same change type.
for _, grouped in groupby(
change_sets, lambda change_set: change_set.change_type
):
new_entities = [
entity
for entity in await asyncio.gather(
*[
_func_map[change_set.change_type](change_set)
for change_set in grouped
]
)
if entity is not None
]
if new_entities:
await entity_component.async_add_entities(new_entities)
collection.async_add_change_set_listener(_collection_changed)
class StorageCollectionWebsocket: class StorageCollectionWebsocket: