diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 207016aee86..0948e1ef808 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -987,7 +987,7 @@ class Entity(ABC): parallel_updates: asyncio.Semaphore | None, ) -> None: """Start adding an entity to a platform.""" - if self._platform_state == EntityPlatformState.ADDED: + if self._platform_state != EntityPlatformState.NOT_ADDED: raise HomeAssistantError( f"Entity '{self.entity_id}' cannot be added a second time to an entity" " platform" @@ -1009,7 +1009,7 @@ class Entity(ABC): def add_to_platform_abort(self) -> None: """Abort adding an entity to a platform.""" - self._platform_state = EntityPlatformState.NOT_ADDED + self._platform_state = EntityPlatformState.REMOVED self._call_on_remove_callbacks() self.hass = None # type: ignore[assignment] @@ -1156,6 +1156,7 @@ class Entity(ABC): await self.async_remove(force_remove=True) self.entity_id = registry_entry.entity_id + self._platform_state = EntityPlatformState.NOT_ADDED await self.platform.async_add_entities([self]) @callback diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index cf76083fe7a..26a4e48eb55 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -1528,3 +1528,65 @@ async def test_suggest_report_issue_custom_component( suggestion = mock_entity._suggest_report_issue() assert suggestion == "create a bug report at https://some_url" + + +async def test_reuse_entity_object_after_abort(hass: HomeAssistant) -> None: + """Test reuse entity object.""" + platform = MockEntityPlatform(hass, domain="test") + ent = entity.Entity() + ent.entity_id = "invalid" + with pytest.raises(HomeAssistantError, match="Invalid entity ID: invalid"): + await platform.async_add_entities([ent]) + with pytest.raises( + HomeAssistantError, + match="Entity invalid cannot be added a second time to an entity platform", + ): + await platform.async_add_entities([ent]) + + +async def test_reuse_entity_object_after_entity_registry_remove( + hass: HomeAssistant, entity_registry: er.EntityRegistry +) -> None: + """Test reuse entity object.""" + entry = entity_registry.async_get_or_create("test", "test", "5678") + platform = MockEntityPlatform(hass, domain="test", platform_name="test") + ent = entity.Entity() + ent._attr_unique_id = "5678" + await platform.async_add_entities([ent]) + assert ent.registry_entry is entry + assert len(hass.states.async_entity_ids()) == 1 + + entity_registry.async_remove(entry.entity_id) + await hass.async_block_till_done() + assert len(hass.states.async_entity_ids()) == 0 + + with pytest.raises( + HomeAssistantError, + match="Entity test.test_5678 cannot be added a second time", + ): + await platform.async_add_entities([ent]) + + +async def test_reuse_entity_object_after_entity_registry_disabled( + hass: HomeAssistant, entity_registry: er.EntityRegistry +) -> None: + """Test reuse entity object.""" + entry = entity_registry.async_get_or_create("test", "test", "5678") + platform = MockEntityPlatform(hass, domain="test", platform_name="test") + ent = entity.Entity() + ent._attr_unique_id = "5678" + await platform.async_add_entities([ent]) + assert ent.registry_entry is entry + assert len(hass.states.async_entity_ids()) == 1 + + entity_registry.async_update_entity( + entry.entity_id, disabled_by=er.RegistryEntryDisabler.USER + ) + await hass.async_block_till_done() + assert len(hass.states.async_entity_ids()) == 0 + + with pytest.raises( + HomeAssistantError, + match="Entity test.test_5678 cannot be added a second time", + ): + await platform.async_add_entities([ent]) diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index 7ccbd5e0f28..af8fbf59049 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -565,24 +565,32 @@ async def test_async_remove_with_platform_update_finishes(hass: HomeAssistant) - component = EntityComponent(_LOGGER, DOMAIN, hass) await component.async_setup({}) entity1 = MockEntity(name="test_1") + entity2 = MockEntity(name="test_1") async def _delayed_update(*args, **kwargs): - await asyncio.sleep(0.01) + update_called.set() + await update_done.wait() entity1.async_update = _delayed_update + entity2.async_update = _delayed_update - # Add, remove, add, remove and make sure no updates - # cause the entity to reappear after removal - for _ in range(2): - await component.async_add_entities([entity1]) - assert len(hass.states.async_entity_ids()) == 1 - entity1.async_write_ha_state() - assert hass.states.get(entity1.entity_id) is not None - task = asyncio.create_task(entity1.async_update_ha_state(True)) - await entity1.async_remove() - assert len(hass.states.async_entity_ids()) == 0 + # Add, remove, and make sure no updates + # cause the entity to reappear after removal and + # that we can add another entity with the same entity_id + for entity in [entity1, entity2]: + update_called = asyncio.Event() + update_done = asyncio.Event() + await component.async_add_entities([entity]) + assert hass.states.async_entity_ids() == ["test_domain.test_1"] + entity.async_write_ha_state() + assert hass.states.get(entity.entity_id) is not None + task = asyncio.create_task(entity.async_update_ha_state(True)) + await update_called.wait() + await entity.async_remove() + assert hass.states.async_entity_ids() == [] + update_done.set() await task - assert len(hass.states.async_entity_ids()) == 0 + assert hass.states.async_entity_ids() == [] async def test_not_adding_duplicate_entities_with_unique_id(