diff --git a/homeassistant/helpers/collection.py b/homeassistant/helpers/collection.py index 9017c60c23f..7a73f90539c 100644 --- a/homeassistant/helpers/collection.py +++ b/homeassistant/helpers/collection.py @@ -334,7 +334,13 @@ def sync_entity_lifecycle( ent_reg = entity_registry.async_get(hass) async def _add_entity(change_set: CollectionChangeSet) -> Entity: + def entity_removed() -> None: + """Remove entity from entities if it's removed or not added.""" + if change_set.item_id in entities: + entities.pop(change_set.item_id) + entities[change_set.item_id] = create_entity(change_set.item) + entities[change_set.item_id].async_on_remove(entity_removed) return entities[change_set.item_id] async def _remove_entity(change_set: CollectionChangeSet) -> None: @@ -343,11 +349,16 @@ def sync_entity_lifecycle( ) if ent_to_remove is not None: ent_reg.async_remove(ent_to_remove) - else: + elif change_set.item_id in entities: await entities[change_set.item_id].async_remove(force_remove=True) - entities.pop(change_set.item_id) + # Unconditionally pop the entity from the entity list to avoid racing against + # the entity registry event handled by Entity._async_registry_updated + if change_set.item_id in entities: + entities.pop(change_set.item_id) async def _update_entity(change_set: CollectionChangeSet) -> None: + if change_set.item_id not in entities: + return await entities[change_set.item_id].async_update_config(change_set.item) # type: ignore[attr-defined] _func_map: dict[ diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index f94b4257d30..791a80f7731 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -759,7 +759,7 @@ class Entity(ABC): @callback def async_on_remove(self, func: CALLBACK_TYPE) -> None: - """Add a function to call when entity removed.""" + """Add a function to call when entity is removed or not added.""" if self._on_remove is None: self._on_remove = [] self._on_remove.append(func) @@ -788,13 +788,23 @@ class Entity(ABC): self.parallel_updates = parallel_updates self._platform_state = EntityPlatformState.ADDED + def _call_on_remove_callbacks(self) -> None: + """Call callbacks registered by async_on_remove.""" + if self._on_remove is None: + return + while self._on_remove: + self._on_remove.pop()() + @callback def add_to_platform_abort(self) -> None: """Abort adding an entity to a platform.""" + + self._platform_state = EntityPlatformState.NOT_ADDED + self._call_on_remove_callbacks() + self.hass = None # type: ignore[assignment] self.platform = None self.parallel_updates = None - self._platform_state = EntityPlatformState.NOT_ADDED async def add_to_platform_finish(self) -> None: """Finish adding an entity to a platform.""" @@ -819,9 +829,7 @@ class Entity(ABC): self._platform_state = EntityPlatformState.REMOVED - if self._on_remove is not None: - while self._on_remove: - self._on_remove.pop()() + self._call_on_remove_callbacks() await self.async_internal_will_remove_from_hass() await self.async_will_remove_from_hass() diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 65cfe706f14..6972dbf7c16 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -611,7 +611,7 @@ class EntityPlatform: self.hass.states.async_reserve(entity.entity_id) def remove_entity_cb() -> None: - """Remove entity from entities list.""" + """Remove entity from entities dict.""" self.entities.pop(entity_id) entity.async_on_remove(remove_entity_cb) diff --git a/tests/helpers/test_collection.py b/tests/helpers/test_collection.py index 11ab0f46ce4..cd4bbba1a3e 100644 --- a/tests/helpers/test_collection.py +++ b/tests/helpers/test_collection.py @@ -4,7 +4,13 @@ import logging import pytest import voluptuous as vol -from homeassistant.helpers import collection, entity, entity_component, storage +from homeassistant.helpers import ( + collection, + entity, + entity_component, + entity_registry as er, + storage, +) from tests.common import flush_store @@ -261,6 +267,140 @@ async def test_attach_entity_component_collection(hass): assert hass.states.get("test.mock_1") is None +async def test_entity_component_collection_abort(hass): + """Test aborted entity adding is handled.""" + ent_comp = entity_component.EntityComponent(_LOGGER, "test", hass) + coll = collection.ObservableCollection(_LOGGER) + + async_update_config_calls = [] + async_remove_calls = [] + + class MockMockEntity(MockEntity): + """Track calls to async_update_config and async_remove.""" + + async def async_update_config(self, config): + nonlocal async_update_config_calls + async_update_config_calls.append(None) + await super().async_update_config() + + async def async_remove(self, *, force_remove: bool = False): + nonlocal async_remove_calls + async_remove_calls.append(None) + await super().async_remove() + + collection.sync_entity_lifecycle( + hass, "test", "test", ent_comp, coll, MockMockEntity + ) + entity_registry = er.async_get(hass) + entity_registry.async_get_or_create( + "test", + "test", + "mock_id", + suggested_object_id="mock_1", + disabled_by=er.RegistryEntryDisabler.INTEGRATION, + ) + + await coll.notify_changes( + [ + collection.CollectionChangeSet( + collection.CHANGE_ADDED, + "mock_id", + {"id": "mock_id", "state": "initial", "name": "Mock 1"}, + ) + ], + ) + + assert hass.states.get("test.mock_1") is None + + await coll.notify_changes( + [ + collection.CollectionChangeSet( + collection.CHANGE_UPDATED, + "mock_id", + {"id": "mock_id", "state": "second", "name": "Mock 1 updated"}, + ) + ], + ) + + assert hass.states.get("test.mock_1") is None + assert len(async_update_config_calls) == 0 + + await coll.notify_changes( + [collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)], + ) + + assert hass.states.get("test.mock_1") is None + assert len(async_remove_calls) == 0 + + +async def test_entity_component_collection_entity_removed(hass): + """Test entity removal is handled.""" + ent_comp = entity_component.EntityComponent(_LOGGER, "test", hass) + coll = collection.ObservableCollection(_LOGGER) + + async_update_config_calls = [] + async_remove_calls = [] + + class MockMockEntity(MockEntity): + """Track calls to async_update_config and async_remove.""" + + async def async_update_config(self, config): + nonlocal async_update_config_calls + async_update_config_calls.append(None) + await super().async_update_config() + + async def async_remove(self, *, force_remove: bool = False): + nonlocal async_remove_calls + async_remove_calls.append(None) + await super().async_remove() + + collection.sync_entity_lifecycle( + hass, "test", "test", ent_comp, coll, MockMockEntity + ) + entity_registry = er.async_get(hass) + entity_registry.async_get_or_create( + "test", "test", "mock_id", suggested_object_id="mock_1" + ) + + await coll.notify_changes( + [ + collection.CollectionChangeSet( + collection.CHANGE_ADDED, + "mock_id", + {"id": "mock_id", "state": "initial", "name": "Mock 1"}, + ) + ], + ) + + assert hass.states.get("test.mock_1").name == "Mock 1" + assert hass.states.get("test.mock_1").state == "initial" + + entity_registry.async_remove("test.mock_1") + await hass.async_block_till_done() + assert hass.states.get("test.mock_1") is None + assert len(async_remove_calls) == 1 + + await coll.notify_changes( + [ + collection.CollectionChangeSet( + collection.CHANGE_UPDATED, + "mock_id", + {"id": "mock_id", "state": "second", "name": "Mock 1 updated"}, + ) + ], + ) + + assert hass.states.get("test.mock_1") is None + assert len(async_update_config_calls) == 0 + + await coll.notify_changes( + [collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)], + ) + + assert hass.states.get("test.mock_1") is None + assert len(async_remove_calls) == 1 + + async def test_storage_collection_websocket(hass, hass_ws_client): """Test exposing a storage collection via websockets.""" store = storage.Store(hass, 1, "test-data")