Allow removing an entity more than once (#102904)

This commit is contained in:
Erik Montnemery 2023-11-08 12:50:40 +01:00 committed by GitHub
parent 44fe704f49
commit d913508607
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 95 additions and 6 deletions

View File

@ -311,6 +311,8 @@ class Entity(ABC):
# and removes the need for constant None checks or asserts.
_state_info: StateInfo = None # type: ignore[assignment]
__remove_event: asyncio.Event | None = None
# Entity Properties
_attr_assumed_state: bool = False
_attr_attribution: str | None = None
@ -1022,6 +1024,7 @@ class Entity(ABC):
await self.async_added_to_hass()
self.async_write_ha_state()
@final
async def async_remove(self, *, force_remove: bool = False) -> None:
"""Remove entity from Home Assistant.
@ -1032,12 +1035,19 @@ class Entity(ABC):
If the entity doesn't have a non disabled entry in the entity registry,
or if force_remove=True, its state will be removed.
"""
# The check for self.platform guards against integrations not using an
# EntityComponent and can be removed in HA Core 2024.1
if self.platform and self._platform_state != EntityPlatformState.ADDED:
raise HomeAssistantError(
f"Entity '{self.entity_id}' async_remove called twice"
)
if self.__remove_event is not None:
await self.__remove_event.wait()
return
self.__remove_event = asyncio.Event()
try:
await self.__async_remove_impl(force_remove)
finally:
self.__remove_event.set()
@final
async def __async_remove_impl(self, force_remove: bool) -> None:
"""Remove entity from Home Assistant."""
self._platform_state = EntityPlatformState.REMOVED
@ -1156,6 +1166,9 @@ class Entity(ABC):
await self.async_remove(force_remove=True)
self.entity_id = registry_entry.entity_id
# Clear the remove event to handle entity added again after entity id change
self.__remove_event = None
self._platform_state = EntityPlatformState.NOT_ADDED
await self.platform.async_add_entities([self])

View File

@ -621,6 +621,34 @@ async def test_async_remove_ignores_in_flight_polling(hass: HomeAssistant) -> No
assert hass.states.get("test.test") is None
async def test_async_remove_twice(hass: HomeAssistant) -> None:
"""Test removing an entity twice only cleans up once."""
result = []
class MockEntity(entity.Entity):
def __init__(self) -> None:
self.remove_calls = []
async def async_will_remove_from_hass(self):
self.remove_calls.append(None)
platform = MockEntityPlatform(hass, domain="test")
ent = MockEntity()
ent.hass = hass
ent.entity_id = "test.test"
ent.async_on_remove(lambda: result.append(1))
await platform.async_add_entities([ent])
assert hass.states.get("test.test").state == STATE_UNKNOWN
await ent.async_remove()
assert len(result) == 1
assert len(ent.remove_calls) == 1
await ent.async_remove()
assert len(result) == 1
assert len(ent.remove_calls) == 1
async def test_set_context(hass: HomeAssistant) -> None:
"""Test setting context."""
context = Context()
@ -1590,3 +1618,51 @@ async def test_reuse_entity_object_after_entity_registry_disabled(
match="Entity 'test.test_5678' cannot be added a second time",
):
await platform.async_add_entities([ent])
async def test_change_entity_id(
hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
"""Test changing entity id."""
result = []
entry = entity_registry.async_get_or_create(
"test", "test_platform", "5678", suggested_object_id="test"
)
assert entry.entity_id == "test.test"
class MockEntity(entity.Entity):
_attr_unique_id = "5678"
def __init__(self) -> None:
self.added_calls = []
self.remove_calls = []
async def async_added_to_hass(self):
self.added_calls.append(None)
self.async_on_remove(lambda: result.append(1))
async def async_will_remove_from_hass(self):
self.remove_calls.append(None)
platform = MockEntityPlatform(hass, domain="test")
ent = MockEntity()
await platform.async_add_entities([ent])
assert hass.states.get("test.test").state == STATE_UNKNOWN
assert len(ent.added_calls) == 1
entry = entity_registry.async_update_entity(
entry.entity_id, new_entity_id="test.test2"
)
await hass.async_block_till_done()
assert len(result) == 1
assert len(ent.added_calls) == 2
assert len(ent.remove_calls) == 1
entity_registry.async_update_entity(entry.entity_id, new_entity_id="test.test3")
await hass.async_block_till_done()
assert len(result) == 2
assert len(ent.added_calls) == 3
assert len(ent.remove_calls) == 2