Rename collection.CollectionChangeSet to collection.CollectionChange (#119532)

This commit is contained in:
Erik Montnemery 2024-06-14 08:54:37 +02:00 committed by GitHub
parent 9082dc2a79
commit 003f216820
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 41 additions and 45 deletions

View File

@ -109,7 +109,7 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
self.async_write_ha_state() self.async_write_ha_state()
async def _pipelines_updated( async def _pipelines_updated(
self, change_sets: Iterable[collection.CollectionChangeSet] self, change_set: Iterable[collection.CollectionChange]
) -> None: ) -> None:
"""Handle pipeline update.""" """Handle pipeline update."""
self._update_options() self._update_options()

View File

@ -39,8 +39,8 @@ _EntityT = TypeVar("_EntityT", bound=Entity, default=Entity)
@dataclass(slots=True) @dataclass(slots=True)
class CollectionChangeSet: class CollectionChange:
"""Class to represent a change set. """Class to represent an item in a change set.
change_type: One of CHANGE_* change_type: One of CHANGE_*
item_id: The id of the item item_id: The id of the item
@ -64,7 +64,7 @@ type ChangeListener = Callable[
Awaitable[None], Awaitable[None],
] ]
type ChangeSetListener = Callable[[Iterable[CollectionChangeSet]], Awaitable[None]] type ChangeSetListener = Callable[[Iterable[CollectionChange]], Awaitable[None]]
class CollectionError(HomeAssistantError): class CollectionError(HomeAssistantError):
@ -163,16 +163,16 @@ class ObservableCollection[_ItemT](ABC):
self.change_set_listeners.append(listener) self.change_set_listeners.append(listener)
return partial(self.change_set_listeners.remove, listener) return partial(self.change_set_listeners.remove, listener)
async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None: async def notify_changes(self, change_set: Iterable[CollectionChange]) -> None:
"""Notify listeners of a change.""" """Notify listeners of a change."""
await asyncio.gather( await asyncio.gather(
*( *(
listener(change_set.change_type, change_set.item_id, change_set.item) listener(change.change_type, change.item_id, change.item)
for listener in self.listeners for listener in self.listeners
for change_set in change_sets for change in change_set
), ),
*( *(
change_set_listener(change_sets) change_set_listener(change_set)
for change_set_listener in self.change_set_listeners for change_set_listener in self.change_set_listeners
), ),
) )
@ -201,7 +201,7 @@ class YamlCollection(ObservableCollection[dict]):
"""Load the YAML collection. Overrides existing data.""" """Load the YAML collection. Overrides existing data."""
old_ids = set(self.data) old_ids = set(self.data)
change_sets = [] change_set = []
for item in data: for item in data:
item_id = item[CONF_ID] item_id = item[CONF_ID]
@ -216,15 +216,15 @@ class YamlCollection(ObservableCollection[dict]):
event = CHANGE_ADDED event = CHANGE_ADDED
self.data[item_id] = item self.data[item_id] = item
change_sets.append(CollectionChangeSet(event, item_id, item)) change_set.append(CollectionChange(event, item_id, item))
change_sets.extend( change_set.extend(
CollectionChangeSet(CHANGE_REMOVED, item_id, self.data.pop(item_id)) CollectionChange(CHANGE_REMOVED, item_id, self.data.pop(item_id))
for item_id in old_ids for item_id in old_ids
) )
if change_sets: if change_set:
await self.notify_changes(change_sets) await self.notify_changes(change_set)
class SerializedStorageCollection(TypedDict): class SerializedStorageCollection(TypedDict):
@ -273,7 +273,7 @@ class StorageCollection[_ItemT, _StoreT: SerializedStorageCollection](
await self.notify_changes( await self.notify_changes(
[ [
CollectionChangeSet(CHANGE_ADDED, item[CONF_ID], item) CollectionChange(CHANGE_ADDED, item[CONF_ID], item)
for item in raw_storage["items"] for item in raw_storage["items"]
] ]
) )
@ -313,7 +313,7 @@ class StorageCollection[_ItemT, _StoreT: SerializedStorageCollection](
item = self._create_item(item_id, validated_data) item = self._create_item(item_id, validated_data)
self.data[item_id] = item self.data[item_id] = item
self._async_schedule_save() self._async_schedule_save()
await self.notify_changes([CollectionChangeSet(CHANGE_ADDED, item_id, item)]) await self.notify_changes([CollectionChange(CHANGE_ADDED, item_id, item)])
return item return item
async def async_update_item(self, item_id: str, updates: dict) -> _ItemT: async def async_update_item(self, item_id: str, updates: dict) -> _ItemT:
@ -331,9 +331,7 @@ class StorageCollection[_ItemT, _StoreT: SerializedStorageCollection](
self.data[item_id] = updated self.data[item_id] = updated
self._async_schedule_save() self._async_schedule_save()
await self.notify_changes( await self.notify_changes([CollectionChange(CHANGE_UPDATED, item_id, updated)])
[CollectionChangeSet(CHANGE_UPDATED, item_id, updated)]
)
return self.data[item_id] return self.data[item_id]
@ -345,7 +343,7 @@ class StorageCollection[_ItemT, _StoreT: SerializedStorageCollection](
item = self.data.pop(item_id) item = self.data.pop(item_id)
self._async_schedule_save() self._async_schedule_save()
await self.notify_changes([CollectionChangeSet(CHANGE_REMOVED, item_id, item)]) await self.notify_changes([CollectionChange(CHANGE_REMOVED, item_id, item)])
@callback @callback
def _async_schedule_save(self) -> None: def _async_schedule_save(self) -> None:
@ -398,7 +396,7 @@ class IDLessCollection(YamlCollection):
"""Load the collection. Overrides existing data.""" """Load the collection. Overrides existing data."""
await self.notify_changes( await self.notify_changes(
[ [
CollectionChangeSet(CHANGE_REMOVED, item_id, item) CollectionChange(CHANGE_REMOVED, item_id, item)
for item_id, item in list(self.data.items()) for item_id, item in list(self.data.items())
] ]
) )
@ -413,7 +411,7 @@ class IDLessCollection(YamlCollection):
await self.notify_changes( await self.notify_changes(
[ [
CollectionChangeSet(CHANGE_ADDED, item_id, item) CollectionChange(CHANGE_ADDED, item_id, item)
for item_id, item in self.data.items() for item_id, item in self.data.items()
] ]
) )
@ -444,14 +442,14 @@ class _CollectionLifeCycle(Generic[_EntityT]):
self.entities.pop(item_id, None) self.entities.pop(item_id, None)
@callback @callback
def _add_entity(self, change_set: CollectionChangeSet) -> CollectionEntity: def _add_entity(self, change_set: CollectionChange) -> CollectionEntity:
item_id = change_set.item_id item_id = change_set.item_id
entity = self.collection.create_entity(self.entity_class, change_set.item) entity = self.collection.create_entity(self.entity_class, change_set.item)
self.entities[item_id] = entity self.entities[item_id] = entity
entity.async_on_remove(partial(self._entity_removed, item_id)) entity.async_on_remove(partial(self._entity_removed, item_id))
return entity return entity
async def _remove_entity(self, change_set: CollectionChangeSet) -> None: async def _remove_entity(self, change_set: CollectionChange) -> None:
item_id = change_set.item_id item_id = change_set.item_id
ent_reg = self.ent_reg ent_reg = self.ent_reg
entities = self.entities entities = self.entities
@ -464,29 +462,27 @@ class _CollectionLifeCycle(Generic[_EntityT]):
# the entity registry event handled by Entity._async_registry_updated # the entity registry event handled by Entity._async_registry_updated
entities.pop(item_id, None) entities.pop(item_id, None)
async def _update_entity(self, change_set: CollectionChangeSet) -> None: async def _update_entity(self, change_set: CollectionChange) -> None:
if entity := self.entities.get(change_set.item_id): if entity := self.entities.get(change_set.item_id):
await entity.async_update_config(change_set.item) await entity.async_update_config(change_set.item)
async def _collection_changed( async def _collection_changed(self, change_set: Iterable[CollectionChange]) -> None:
self, change_sets: Iterable[CollectionChangeSet]
) -> None:
"""Handle a collection change.""" """Handle a collection change."""
# Create a new bucket every time we have a different change type # Create a new bucket every time we have a different change type
# to ensure operations happen in order. We only group # to ensure operations happen in order. We only group
# the same change type. # the same change type.
new_entities: list[CollectionEntity] = [] new_entities: list[CollectionEntity] = []
coros: list[Coroutine[Any, Any, CollectionEntity | None]] = [] coros: list[Coroutine[Any, Any, CollectionEntity | None]] = []
grouped: Iterable[CollectionChangeSet] grouped: Iterable[CollectionChange]
for _, grouped in groupby(change_sets, _GROUP_BY_KEY): for _, grouped in groupby(change_set, _GROUP_BY_KEY):
for change_set in grouped: for change in grouped:
change_type = change_set.change_type change_type = change.change_type
if change_type == CHANGE_ADDED: if change_type == CHANGE_ADDED:
new_entities.append(self._add_entity(change_set)) new_entities.append(self._add_entity(change))
elif change_type == CHANGE_REMOVED: elif change_type == CHANGE_REMOVED:
coros.append(self._remove_entity(change_set)) coros.append(self._remove_entity(change))
elif change_type == CHANGE_UPDATED: elif change_type == CHANGE_UPDATED:
coros.append(self._update_entity(change_set)) coros.append(self._update_entity(change))
if coros: if coros:
await asyncio.gather(*coros) await asyncio.gather(*coros)

View File

@ -124,7 +124,7 @@ async def test_observable_collection() -> None:
changes = track_changes(coll) changes = track_changes(coll)
await coll.notify_changes( await coll.notify_changes(
[collection.CollectionChangeSet("mock_type", "mock_id", {"mock": "item"})] [collection.CollectionChange("mock_type", "mock_id", {"mock": "item"})]
) )
assert len(changes) == 1 assert len(changes) == 1
assert changes[0] == ("mock_type", "mock_id", {"mock": "item"}) assert changes[0] == ("mock_type", "mock_id", {"mock": "item"})
@ -263,7 +263,7 @@ async def test_attach_entity_component_collection(hass: HomeAssistant) -> None:
await coll.notify_changes( await coll.notify_changes(
[ [
collection.CollectionChangeSet( collection.CollectionChange(
collection.CHANGE_ADDED, collection.CHANGE_ADDED,
"mock_id", "mock_id",
{"id": "mock_id", "state": "initial", "name": "Mock 1"}, {"id": "mock_id", "state": "initial", "name": "Mock 1"},
@ -276,7 +276,7 @@ async def test_attach_entity_component_collection(hass: HomeAssistant) -> None:
await coll.notify_changes( await coll.notify_changes(
[ [
collection.CollectionChangeSet( collection.CollectionChange(
collection.CHANGE_UPDATED, collection.CHANGE_UPDATED,
"mock_id", "mock_id",
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"}, {"id": "mock_id", "state": "second", "name": "Mock 1 updated"},
@ -288,7 +288,7 @@ async def test_attach_entity_component_collection(hass: HomeAssistant) -> None:
assert hass.states.get("test.mock_1").state == "second" assert hass.states.get("test.mock_1").state == "second"
await coll.notify_changes( await coll.notify_changes(
[collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)], [collection.CollectionChange(collection.CHANGE_REMOVED, "mock_id", None)],
) )
assert hass.states.get("test.mock_1") is None assert hass.states.get("test.mock_1") is None
@ -331,7 +331,7 @@ async def test_entity_component_collection_abort(
await coll.notify_changes( await coll.notify_changes(
[ [
collection.CollectionChangeSet( collection.CollectionChange(
collection.CHANGE_ADDED, collection.CHANGE_ADDED,
"mock_id", "mock_id",
{"id": "mock_id", "state": "initial", "name": "Mock 1"}, {"id": "mock_id", "state": "initial", "name": "Mock 1"},
@ -343,7 +343,7 @@ async def test_entity_component_collection_abort(
await coll.notify_changes( await coll.notify_changes(
[ [
collection.CollectionChangeSet( collection.CollectionChange(
collection.CHANGE_UPDATED, collection.CHANGE_UPDATED,
"mock_id", "mock_id",
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"}, {"id": "mock_id", "state": "second", "name": "Mock 1 updated"},
@ -355,7 +355,7 @@ async def test_entity_component_collection_abort(
assert len(async_update_config_calls) == 0 assert len(async_update_config_calls) == 0
await coll.notify_changes( await coll.notify_changes(
[collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)], [collection.CollectionChange(collection.CHANGE_REMOVED, "mock_id", None)],
) )
assert hass.states.get("test.mock_1") is None assert hass.states.get("test.mock_1") is None
@ -395,7 +395,7 @@ async def test_entity_component_collection_entity_removed(
await coll.notify_changes( await coll.notify_changes(
[ [
collection.CollectionChangeSet( collection.CollectionChange(
collection.CHANGE_ADDED, collection.CHANGE_ADDED,
"mock_id", "mock_id",
{"id": "mock_id", "state": "initial", "name": "Mock 1"}, {"id": "mock_id", "state": "initial", "name": "Mock 1"},
@ -413,7 +413,7 @@ async def test_entity_component_collection_entity_removed(
await coll.notify_changes( await coll.notify_changes(
[ [
collection.CollectionChangeSet( collection.CollectionChange(
collection.CHANGE_UPDATED, collection.CHANGE_UPDATED,
"mock_id", "mock_id",
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"}, {"id": "mock_id", "state": "second", "name": "Mock 1 updated"},
@ -425,7 +425,7 @@ async def test_entity_component_collection_entity_removed(
assert len(async_update_config_calls) == 0 assert len(async_update_config_calls) == 0
await coll.notify_changes( await coll.notify_changes(
[collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)], [collection.CollectionChange(collection.CHANGE_REMOVED, "mock_id", None)],
) )
assert hass.states.get("test.mock_1") is None assert hass.states.get("test.mock_1") is None