Prevent accidentally reusing an entity object (#102911)

* Prevent accidentally reusing an entity object

* Fix group reload service

* Revert "Fix group reload service"

* Improve test

* Add tests aserting entity can't be reused
This commit is contained in:
Erik Montnemery 2023-11-03 21:01:38 +01:00 committed by GitHub
parent dca72c598e
commit 0ea0a1ed06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 85 additions and 14 deletions

View File

@ -987,7 +987,7 @@ class Entity(ABC):
parallel_updates: asyncio.Semaphore | None, parallel_updates: asyncio.Semaphore | None,
) -> None: ) -> None:
"""Start adding an entity to a platform.""" """Start adding an entity to a platform."""
if self._platform_state == EntityPlatformState.ADDED: if self._platform_state != EntityPlatformState.NOT_ADDED:
raise HomeAssistantError( raise HomeAssistantError(
f"Entity '{self.entity_id}' cannot be added a second time to an entity" f"Entity '{self.entity_id}' cannot be added a second time to an entity"
" platform" " platform"
@ -1009,7 +1009,7 @@ class Entity(ABC):
def add_to_platform_abort(self) -> None: def add_to_platform_abort(self) -> None:
"""Abort adding an entity to a platform.""" """Abort adding an entity to a platform."""
self._platform_state = EntityPlatformState.NOT_ADDED self._platform_state = EntityPlatformState.REMOVED
self._call_on_remove_callbacks() self._call_on_remove_callbacks()
self.hass = None # type: ignore[assignment] self.hass = None # type: ignore[assignment]
@ -1156,6 +1156,7 @@ class Entity(ABC):
await self.async_remove(force_remove=True) await self.async_remove(force_remove=True)
self.entity_id = registry_entry.entity_id self.entity_id = registry_entry.entity_id
self._platform_state = EntityPlatformState.NOT_ADDED
await self.platform.async_add_entities([self]) await self.platform.async_add_entities([self])
@callback @callback

View File

@ -1528,3 +1528,65 @@ async def test_suggest_report_issue_custom_component(
suggestion = mock_entity._suggest_report_issue() suggestion = mock_entity._suggest_report_issue()
assert suggestion == "create a bug report at https://some_url" assert suggestion == "create a bug report at https://some_url"
async def test_reuse_entity_object_after_abort(hass: HomeAssistant) -> None:
"""Test reuse entity object."""
platform = MockEntityPlatform(hass, domain="test")
ent = entity.Entity()
ent.entity_id = "invalid"
with pytest.raises(HomeAssistantError, match="Invalid entity ID: invalid"):
await platform.async_add_entities([ent])
with pytest.raises(
HomeAssistantError,
match="Entity invalid cannot be added a second time to an entity platform",
):
await platform.async_add_entities([ent])
async def test_reuse_entity_object_after_entity_registry_remove(
hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
"""Test reuse entity object."""
entry = entity_registry.async_get_or_create("test", "test", "5678")
platform = MockEntityPlatform(hass, domain="test", platform_name="test")
ent = entity.Entity()
ent._attr_unique_id = "5678"
await platform.async_add_entities([ent])
assert ent.registry_entry is entry
assert len(hass.states.async_entity_ids()) == 1
entity_registry.async_remove(entry.entity_id)
await hass.async_block_till_done()
assert len(hass.states.async_entity_ids()) == 0
with pytest.raises(
HomeAssistantError,
match="Entity test.test_5678 cannot be added a second time",
):
await platform.async_add_entities([ent])
async def test_reuse_entity_object_after_entity_registry_disabled(
hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
"""Test reuse entity object."""
entry = entity_registry.async_get_or_create("test", "test", "5678")
platform = MockEntityPlatform(hass, domain="test", platform_name="test")
ent = entity.Entity()
ent._attr_unique_id = "5678"
await platform.async_add_entities([ent])
assert ent.registry_entry is entry
assert len(hass.states.async_entity_ids()) == 1
entity_registry.async_update_entity(
entry.entity_id, disabled_by=er.RegistryEntryDisabler.USER
)
await hass.async_block_till_done()
assert len(hass.states.async_entity_ids()) == 0
with pytest.raises(
HomeAssistantError,
match="Entity test.test_5678 cannot be added a second time",
):
await platform.async_add_entities([ent])

View File

@ -565,24 +565,32 @@ async def test_async_remove_with_platform_update_finishes(hass: HomeAssistant) -
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
await component.async_setup({}) await component.async_setup({})
entity1 = MockEntity(name="test_1") entity1 = MockEntity(name="test_1")
entity2 = MockEntity(name="test_1")
async def _delayed_update(*args, **kwargs): async def _delayed_update(*args, **kwargs):
await asyncio.sleep(0.01) update_called.set()
await update_done.wait()
entity1.async_update = _delayed_update entity1.async_update = _delayed_update
entity2.async_update = _delayed_update
# Add, remove, add, remove and make sure no updates # Add, remove, and make sure no updates
# cause the entity to reappear after removal # cause the entity to reappear after removal and
for _ in range(2): # that we can add another entity with the same entity_id
await component.async_add_entities([entity1]) for entity in [entity1, entity2]:
assert len(hass.states.async_entity_ids()) == 1 update_called = asyncio.Event()
entity1.async_write_ha_state() update_done = asyncio.Event()
assert hass.states.get(entity1.entity_id) is not None await component.async_add_entities([entity])
task = asyncio.create_task(entity1.async_update_ha_state(True)) assert hass.states.async_entity_ids() == ["test_domain.test_1"]
await entity1.async_remove() entity.async_write_ha_state()
assert len(hass.states.async_entity_ids()) == 0 assert hass.states.get(entity.entity_id) is not None
task = asyncio.create_task(entity.async_update_ha_state(True))
await update_called.wait()
await entity.async_remove()
assert hass.states.async_entity_ids() == []
update_done.set()
await task await task
assert len(hass.states.async_entity_ids()) == 0 assert hass.states.async_entity_ids() == []
async def test_not_adding_duplicate_entities_with_unique_id( async def test_not_adding_duplicate_entities_with_unique_id(