diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 0948e1ef808..7877ca0e613 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -311,6 +311,8 @@ class Entity(ABC): # and removes the need for constant None checks or asserts. _state_info: StateInfo = None # type: ignore[assignment] + __remove_event: asyncio.Event | None = None + # Entity Properties _attr_assumed_state: bool = False _attr_attribution: str | None = None @@ -1022,6 +1024,7 @@ class Entity(ABC): await self.async_added_to_hass() self.async_write_ha_state() + @final async def async_remove(self, *, force_remove: bool = False) -> None: """Remove entity from Home Assistant. @@ -1032,12 +1035,19 @@ class Entity(ABC): If the entity doesn't have a non disabled entry in the entity registry, or if force_remove=True, its state will be removed. """ - # The check for self.platform guards against integrations not using an - # EntityComponent and can be removed in HA Core 2024.1 - if self.platform and self._platform_state != EntityPlatformState.ADDED: - raise HomeAssistantError( - f"Entity '{self.entity_id}' async_remove called twice" - ) + if self.__remove_event is not None: + await self.__remove_event.wait() + return + + self.__remove_event = asyncio.Event() + try: + await self.__async_remove_impl(force_remove) + finally: + self.__remove_event.set() + + @final + async def __async_remove_impl(self, force_remove: bool) -> None: + """Remove entity from Home Assistant.""" self._platform_state = EntityPlatformState.REMOVED @@ -1156,6 +1166,9 @@ class Entity(ABC): await self.async_remove(force_remove=True) self.entity_id = registry_entry.entity_id + + # Clear the remove event to handle entity added again after entity id change + self.__remove_event = None self._platform_state = EntityPlatformState.NOT_ADDED await self.platform.async_add_entities([self]) diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index a3ba5e48641..373dfac0cea 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -621,6 +621,34 @@ async def test_async_remove_ignores_in_flight_polling(hass: HomeAssistant) -> No assert hass.states.get("test.test") is None +async def test_async_remove_twice(hass: HomeAssistant) -> None: + """Test removing an entity twice only cleans up once.""" + result = [] + + class MockEntity(entity.Entity): + def __init__(self) -> None: + self.remove_calls = [] + + async def async_will_remove_from_hass(self): + self.remove_calls.append(None) + + platform = MockEntityPlatform(hass, domain="test") + ent = MockEntity() + ent.hass = hass + ent.entity_id = "test.test" + ent.async_on_remove(lambda: result.append(1)) + await platform.async_add_entities([ent]) + assert hass.states.get("test.test").state == STATE_UNKNOWN + + await ent.async_remove() + assert len(result) == 1 + assert len(ent.remove_calls) == 1 + + await ent.async_remove() + assert len(result) == 1 + assert len(ent.remove_calls) == 1 + + async def test_set_context(hass: HomeAssistant) -> None: """Test setting context.""" context = Context() @@ -1590,3 +1618,51 @@ async def test_reuse_entity_object_after_entity_registry_disabled( match="Entity 'test.test_5678' cannot be added a second time", ): await platform.async_add_entities([ent]) + + +async def test_change_entity_id( + hass: HomeAssistant, entity_registry: er.EntityRegistry +) -> None: + """Test changing entity id.""" + result = [] + + entry = entity_registry.async_get_or_create( + "test", "test_platform", "5678", suggested_object_id="test" + ) + assert entry.entity_id == "test.test" + + class MockEntity(entity.Entity): + _attr_unique_id = "5678" + + def __init__(self) -> None: + self.added_calls = [] + self.remove_calls = [] + + async def async_added_to_hass(self): + self.added_calls.append(None) + self.async_on_remove(lambda: result.append(1)) + + async def async_will_remove_from_hass(self): + self.remove_calls.append(None) + + platform = MockEntityPlatform(hass, domain="test") + ent = MockEntity() + await platform.async_add_entities([ent]) + assert hass.states.get("test.test").state == STATE_UNKNOWN + assert len(ent.added_calls) == 1 + + entry = entity_registry.async_update_entity( + entry.entity_id, new_entity_id="test.test2" + ) + await hass.async_block_till_done() + + assert len(result) == 1 + assert len(ent.added_calls) == 2 + assert len(ent.remove_calls) == 1 + + entity_registry.async_update_entity(entry.entity_id, new_entity_id="test.test3") + await hass.async_block_till_done() + + assert len(result) == 2 + assert len(ent.added_calls) == 3 + assert len(ent.remove_calls) == 2