Optimize storage collection entity operations with asyncio.gather (#48352)

This commit is contained in:
J. Nick Koston 2021-04-03 23:35:33 -10:00 committed by GitHub
parent c1e788e665
commit 3bc583607f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,8 +4,9 @@ from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from dataclasses import dataclass
from itertools import groupby
import logging
from typing import Any, Awaitable, Callable, Iterable, Optional, cast
from typing import Any, Awaitable, Callable, Coroutine, Iterable, Optional, cast
import voluptuous as vol
from voluptuous.humanize import humanize_error
@ -54,6 +55,8 @@ ChangeListener = Callable[
Awaitable[None],
]
ChangeSetListener = Callable[[Iterable[CollectionChangeSet]], Awaitable[None]]
class CollectionError(HomeAssistantError):
"""Base class for collection related errors."""
@ -105,6 +108,7 @@ class ObservableCollection(ABC):
self.id_manager = id_manager or IDManager()
self.data: dict[str, dict] = {}
self.listeners: list[ChangeListener] = []
self.change_set_listeners: list[ChangeSetListener] = []
self.id_manager.add_collection(self.data)
@ -121,6 +125,14 @@ class ObservableCollection(ABC):
"""
self.listeners.append(listener)
@callback
def async_add_change_set_listener(self, listener: ChangeSetListener) -> None:
"""Add a listener for a full change set.
Will be called with [(change_type, item_id, updated_config), ...]
"""
self.change_set_listeners.append(listener)
async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None:
"""Notify listeners of a change."""
await asyncio.gather(
@ -128,7 +140,11 @@ class ObservableCollection(ABC):
listener(change_set.change_type, change_set.item_id, change_set.item)
for listener in self.listeners
for change_set in change_sets
]
],
*[
change_set_listener(change_sets)
for change_set_listener in self.change_set_listeners
],
)
@ -311,29 +327,55 @@ def sync_entity_lifecycle(
) -> None:
"""Map a collection to an entity component."""
entities = {}
ent_reg = entity_registry.async_get(hass)
async def _collection_changed(change_type: str, item_id: str, config: dict) -> None:
async def _add_entity(change_set: CollectionChangeSet) -> Entity:
entities[change_set.item_id] = create_entity(change_set.item)
return entities[change_set.item_id]
async def _remove_entity(change_set: CollectionChangeSet) -> None:
ent_to_remove = ent_reg.async_get_entity_id(
domain, platform, change_set.item_id
)
if ent_to_remove is not None:
ent_reg.async_remove(ent_to_remove)
else:
await entities[change_set.item_id].async_remove(force_remove=True)
entities.pop(change_set.item_id)
async def _update_entity(change_set: CollectionChangeSet) -> None:
await entities[change_set.item_id].async_update_config(change_set.item) # type: ignore
_func_map: dict[
str, Callable[[CollectionChangeSet], Coroutine[Any, Any, Entity | None]]
] = {
CHANGE_ADDED: _add_entity,
CHANGE_REMOVED: _remove_entity,
CHANGE_UPDATED: _update_entity,
}
async def _collection_changed(change_sets: Iterable[CollectionChangeSet]) -> None:
"""Handle a collection change."""
if change_type == CHANGE_ADDED:
entity = create_entity(config)
await entity_component.async_add_entities([entity])
entities[item_id] = entity
return
# Create a new bucket every time we have a different change type
# to ensure operations happen in order. We only group
# the same change type.
for _, grouped in groupby(
change_sets, lambda change_set: change_set.change_type
):
new_entities = [
entity
for entity in await asyncio.gather(
*[
_func_map[change_set.change_type](change_set)
for change_set in grouped
]
)
if entity is not None
]
if new_entities:
await entity_component.async_add_entities(new_entities)
if change_type == CHANGE_REMOVED:
ent_reg = await entity_registry.async_get_registry(hass)
ent_to_remove = ent_reg.async_get_entity_id(domain, platform, item_id)
if ent_to_remove is not None:
ent_reg.async_remove(ent_to_remove)
else:
await entities[item_id].async_remove(force_remove=True)
entities.pop(item_id)
return
# CHANGE_UPDATED
await entities[item_id].async_update_config(config) # type: ignore
collection.async_add_listener(_collection_changed)
collection.async_add_change_set_listener(_collection_changed)
class StorageCollectionWebsocket: