diff --git a/homeassistant/components/assist_pipeline/select.py b/homeassistant/components/assist_pipeline/select.py index 43ed003f65d..5d011424e6e 100644 --- a/homeassistant/components/assist_pipeline/select.py +++ b/homeassistant/components/assist_pipeline/select.py @@ -109,7 +109,7 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity): self.async_write_ha_state() async def _pipelines_updated( - self, change_sets: Iterable[collection.CollectionChangeSet] + self, change_set: Iterable[collection.CollectionChange] ) -> None: """Handle pipeline update.""" self._update_options() diff --git a/homeassistant/helpers/collection.py b/homeassistant/helpers/collection.py index 1b63d95864a..4691bc804fd 100644 --- a/homeassistant/helpers/collection.py +++ b/homeassistant/helpers/collection.py @@ -39,8 +39,8 @@ _EntityT = TypeVar("_EntityT", bound=Entity, default=Entity) @dataclass(slots=True) -class CollectionChangeSet: - """Class to represent a change set. +class CollectionChange: + """Class to represent an item in a change set. change_type: One of CHANGE_* item_id: The id of the item @@ -64,7 +64,7 @@ type ChangeListener = Callable[ Awaitable[None], ] -type ChangeSetListener = Callable[[Iterable[CollectionChangeSet]], Awaitable[None]] +type ChangeSetListener = Callable[[Iterable[CollectionChange]], Awaitable[None]] class CollectionError(HomeAssistantError): @@ -163,16 +163,16 @@ class ObservableCollection[_ItemT](ABC): self.change_set_listeners.append(listener) return partial(self.change_set_listeners.remove, listener) - async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None: + async def notify_changes(self, change_set: Iterable[CollectionChange]) -> None: """Notify listeners of a change.""" await asyncio.gather( *( - listener(change_set.change_type, change_set.item_id, change_set.item) + listener(change.change_type, change.item_id, change.item) for listener in self.listeners - for change_set in change_sets + for change in change_set ), *( - change_set_listener(change_sets) + change_set_listener(change_set) for change_set_listener in self.change_set_listeners ), ) @@ -201,7 +201,7 @@ class YamlCollection(ObservableCollection[dict]): """Load the YAML collection. Overrides existing data.""" old_ids = set(self.data) - change_sets = [] + change_set = [] for item in data: item_id = item[CONF_ID] @@ -216,15 +216,15 @@ class YamlCollection(ObservableCollection[dict]): event = CHANGE_ADDED self.data[item_id] = item - change_sets.append(CollectionChangeSet(event, item_id, item)) + change_set.append(CollectionChange(event, item_id, item)) - change_sets.extend( - CollectionChangeSet(CHANGE_REMOVED, item_id, self.data.pop(item_id)) + change_set.extend( + CollectionChange(CHANGE_REMOVED, item_id, self.data.pop(item_id)) for item_id in old_ids ) - if change_sets: - await self.notify_changes(change_sets) + if change_set: + await self.notify_changes(change_set) class SerializedStorageCollection(TypedDict): @@ -273,7 +273,7 @@ class StorageCollection[_ItemT, _StoreT: SerializedStorageCollection]( await self.notify_changes( [ - CollectionChangeSet(CHANGE_ADDED, item[CONF_ID], item) + CollectionChange(CHANGE_ADDED, item[CONF_ID], item) for item in raw_storage["items"] ] ) @@ -313,7 +313,7 @@ class StorageCollection[_ItemT, _StoreT: SerializedStorageCollection]( item = self._create_item(item_id, validated_data) self.data[item_id] = item self._async_schedule_save() - await self.notify_changes([CollectionChangeSet(CHANGE_ADDED, item_id, item)]) + await self.notify_changes([CollectionChange(CHANGE_ADDED, item_id, item)]) return item async def async_update_item(self, item_id: str, updates: dict) -> _ItemT: @@ -331,9 +331,7 @@ class StorageCollection[_ItemT, _StoreT: SerializedStorageCollection]( self.data[item_id] = updated self._async_schedule_save() - await self.notify_changes( - [CollectionChangeSet(CHANGE_UPDATED, item_id, updated)] - ) + await self.notify_changes([CollectionChange(CHANGE_UPDATED, item_id, updated)]) return self.data[item_id] @@ -345,7 +343,7 @@ class StorageCollection[_ItemT, _StoreT: SerializedStorageCollection]( item = self.data.pop(item_id) self._async_schedule_save() - await self.notify_changes([CollectionChangeSet(CHANGE_REMOVED, item_id, item)]) + await self.notify_changes([CollectionChange(CHANGE_REMOVED, item_id, item)]) @callback def _async_schedule_save(self) -> None: @@ -398,7 +396,7 @@ class IDLessCollection(YamlCollection): """Load the collection. Overrides existing data.""" await self.notify_changes( [ - CollectionChangeSet(CHANGE_REMOVED, item_id, item) + CollectionChange(CHANGE_REMOVED, item_id, item) for item_id, item in list(self.data.items()) ] ) @@ -413,7 +411,7 @@ class IDLessCollection(YamlCollection): await self.notify_changes( [ - CollectionChangeSet(CHANGE_ADDED, item_id, item) + CollectionChange(CHANGE_ADDED, item_id, item) for item_id, item in self.data.items() ] ) @@ -444,14 +442,14 @@ class _CollectionLifeCycle(Generic[_EntityT]): self.entities.pop(item_id, None) @callback - def _add_entity(self, change_set: CollectionChangeSet) -> CollectionEntity: + def _add_entity(self, change_set: CollectionChange) -> CollectionEntity: item_id = change_set.item_id entity = self.collection.create_entity(self.entity_class, change_set.item) self.entities[item_id] = entity entity.async_on_remove(partial(self._entity_removed, item_id)) return entity - async def _remove_entity(self, change_set: CollectionChangeSet) -> None: + async def _remove_entity(self, change_set: CollectionChange) -> None: item_id = change_set.item_id ent_reg = self.ent_reg entities = self.entities @@ -464,29 +462,27 @@ class _CollectionLifeCycle(Generic[_EntityT]): # the entity registry event handled by Entity._async_registry_updated entities.pop(item_id, None) - async def _update_entity(self, change_set: CollectionChangeSet) -> None: + async def _update_entity(self, change_set: CollectionChange) -> None: if entity := self.entities.get(change_set.item_id): await entity.async_update_config(change_set.item) - async def _collection_changed( - self, change_sets: Iterable[CollectionChangeSet] - ) -> None: + async def _collection_changed(self, change_set: Iterable[CollectionChange]) -> None: """Handle a collection change.""" # Create a new bucket every time we have a different change type # to ensure operations happen in order. We only group # the same change type. new_entities: list[CollectionEntity] = [] coros: list[Coroutine[Any, Any, CollectionEntity | None]] = [] - grouped: Iterable[CollectionChangeSet] - for _, grouped in groupby(change_sets, _GROUP_BY_KEY): - for change_set in grouped: - change_type = change_set.change_type + grouped: Iterable[CollectionChange] + for _, grouped in groupby(change_set, _GROUP_BY_KEY): + for change in grouped: + change_type = change.change_type if change_type == CHANGE_ADDED: - new_entities.append(self._add_entity(change_set)) + new_entities.append(self._add_entity(change)) elif change_type == CHANGE_REMOVED: - coros.append(self._remove_entity(change_set)) + coros.append(self._remove_entity(change)) elif change_type == CHANGE_UPDATED: - coros.append(self._update_entity(change_set)) + coros.append(self._update_entity(change)) if coros: await asyncio.gather(*coros) diff --git a/tests/helpers/test_collection.py b/tests/helpers/test_collection.py index 4be372efe9c..dc9ac21e246 100644 --- a/tests/helpers/test_collection.py +++ b/tests/helpers/test_collection.py @@ -124,7 +124,7 @@ async def test_observable_collection() -> None: changes = track_changes(coll) await coll.notify_changes( - [collection.CollectionChangeSet("mock_type", "mock_id", {"mock": "item"})] + [collection.CollectionChange("mock_type", "mock_id", {"mock": "item"})] ) assert len(changes) == 1 assert changes[0] == ("mock_type", "mock_id", {"mock": "item"}) @@ -263,7 +263,7 @@ async def test_attach_entity_component_collection(hass: HomeAssistant) -> None: await coll.notify_changes( [ - collection.CollectionChangeSet( + collection.CollectionChange( collection.CHANGE_ADDED, "mock_id", {"id": "mock_id", "state": "initial", "name": "Mock 1"}, @@ -276,7 +276,7 @@ async def test_attach_entity_component_collection(hass: HomeAssistant) -> None: await coll.notify_changes( [ - collection.CollectionChangeSet( + collection.CollectionChange( collection.CHANGE_UPDATED, "mock_id", {"id": "mock_id", "state": "second", "name": "Mock 1 updated"}, @@ -288,7 +288,7 @@ async def test_attach_entity_component_collection(hass: HomeAssistant) -> None: assert hass.states.get("test.mock_1").state == "second" await coll.notify_changes( - [collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)], + [collection.CollectionChange(collection.CHANGE_REMOVED, "mock_id", None)], ) assert hass.states.get("test.mock_1") is None @@ -331,7 +331,7 @@ async def test_entity_component_collection_abort( await coll.notify_changes( [ - collection.CollectionChangeSet( + collection.CollectionChange( collection.CHANGE_ADDED, "mock_id", {"id": "mock_id", "state": "initial", "name": "Mock 1"}, @@ -343,7 +343,7 @@ async def test_entity_component_collection_abort( await coll.notify_changes( [ - collection.CollectionChangeSet( + collection.CollectionChange( collection.CHANGE_UPDATED, "mock_id", {"id": "mock_id", "state": "second", "name": "Mock 1 updated"}, @@ -355,7 +355,7 @@ async def test_entity_component_collection_abort( assert len(async_update_config_calls) == 0 await coll.notify_changes( - [collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)], + [collection.CollectionChange(collection.CHANGE_REMOVED, "mock_id", None)], ) assert hass.states.get("test.mock_1") is None @@ -395,7 +395,7 @@ async def test_entity_component_collection_entity_removed( await coll.notify_changes( [ - collection.CollectionChangeSet( + collection.CollectionChange( collection.CHANGE_ADDED, "mock_id", {"id": "mock_id", "state": "initial", "name": "Mock 1"}, @@ -413,7 +413,7 @@ async def test_entity_component_collection_entity_removed( await coll.notify_changes( [ - collection.CollectionChangeSet( + collection.CollectionChange( collection.CHANGE_UPDATED, "mock_id", {"id": "mock_id", "state": "second", "name": "Mock 1 updated"}, @@ -425,7 +425,7 @@ async def test_entity_component_collection_entity_removed( assert len(async_update_config_calls) == 0 await coll.notify_changes( - [collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)], + [collection.CollectionChange(collection.CHANGE_REMOVED, "mock_id", None)], ) assert hass.states.get("test.mock_1") is None