Handle removed entites in collection.sync_entity_lifecycle (#70759)

* Handle removed entites in collection.sync_entity_lifecycle

* Add comment
This commit is contained in:
Erik Montnemery 2022-04-27 17:05:00 +02:00 committed by GitHub
parent 8a13c6744a
commit c5d69ab1b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 168 additions and 9 deletions

View File

@ -334,7 +334,13 @@ def sync_entity_lifecycle(
ent_reg = entity_registry.async_get(hass)
async def _add_entity(change_set: CollectionChangeSet) -> Entity:
def entity_removed() -> None:
"""Remove entity from entities if it's removed or not added."""
if change_set.item_id in entities:
entities.pop(change_set.item_id)
entities[change_set.item_id] = create_entity(change_set.item)
entities[change_set.item_id].async_on_remove(entity_removed)
return entities[change_set.item_id]
async def _remove_entity(change_set: CollectionChangeSet) -> None:
@ -343,11 +349,16 @@ def sync_entity_lifecycle(
)
if ent_to_remove is not None:
ent_reg.async_remove(ent_to_remove)
else:
elif change_set.item_id in entities:
await entities[change_set.item_id].async_remove(force_remove=True)
entities.pop(change_set.item_id)
# Unconditionally pop the entity from the entity list to avoid racing against
# the entity registry event handled by Entity._async_registry_updated
if change_set.item_id in entities:
entities.pop(change_set.item_id)
async def _update_entity(change_set: CollectionChangeSet) -> None:
if change_set.item_id not in entities:
return
await entities[change_set.item_id].async_update_config(change_set.item) # type: ignore[attr-defined]
_func_map: dict[

View File

@ -759,7 +759,7 @@ class Entity(ABC):
@callback
def async_on_remove(self, func: CALLBACK_TYPE) -> None:
"""Add a function to call when entity removed."""
"""Add a function to call when entity is removed or not added."""
if self._on_remove is None:
self._on_remove = []
self._on_remove.append(func)
@ -788,13 +788,23 @@ class Entity(ABC):
self.parallel_updates = parallel_updates
self._platform_state = EntityPlatformState.ADDED
def _call_on_remove_callbacks(self) -> None:
"""Call callbacks registered by async_on_remove."""
if self._on_remove is None:
return
while self._on_remove:
self._on_remove.pop()()
@callback
def add_to_platform_abort(self) -> None:
"""Abort adding an entity to a platform."""
self._platform_state = EntityPlatformState.NOT_ADDED
self._call_on_remove_callbacks()
self.hass = None # type: ignore[assignment]
self.platform = None
self.parallel_updates = None
self._platform_state = EntityPlatformState.NOT_ADDED
async def add_to_platform_finish(self) -> None:
"""Finish adding an entity to a platform."""
@ -819,9 +829,7 @@ class Entity(ABC):
self._platform_state = EntityPlatformState.REMOVED
if self._on_remove is not None:
while self._on_remove:
self._on_remove.pop()()
self._call_on_remove_callbacks()
await self.async_internal_will_remove_from_hass()
await self.async_will_remove_from_hass()

View File

@ -611,7 +611,7 @@ class EntityPlatform:
self.hass.states.async_reserve(entity.entity_id)
def remove_entity_cb() -> None:
"""Remove entity from entities list."""
"""Remove entity from entities dict."""
self.entities.pop(entity_id)
entity.async_on_remove(remove_entity_cb)

View File

@ -4,7 +4,13 @@ import logging
import pytest
import voluptuous as vol
from homeassistant.helpers import collection, entity, entity_component, storage
from homeassistant.helpers import (
collection,
entity,
entity_component,
entity_registry as er,
storage,
)
from tests.common import flush_store
@ -261,6 +267,140 @@ async def test_attach_entity_component_collection(hass):
assert hass.states.get("test.mock_1") is None
async def test_entity_component_collection_abort(hass):
"""Test aborted entity adding is handled."""
ent_comp = entity_component.EntityComponent(_LOGGER, "test", hass)
coll = collection.ObservableCollection(_LOGGER)
async_update_config_calls = []
async_remove_calls = []
class MockMockEntity(MockEntity):
"""Track calls to async_update_config and async_remove."""
async def async_update_config(self, config):
nonlocal async_update_config_calls
async_update_config_calls.append(None)
await super().async_update_config()
async def async_remove(self, *, force_remove: bool = False):
nonlocal async_remove_calls
async_remove_calls.append(None)
await super().async_remove()
collection.sync_entity_lifecycle(
hass, "test", "test", ent_comp, coll, MockMockEntity
)
entity_registry = er.async_get(hass)
entity_registry.async_get_or_create(
"test",
"test",
"mock_id",
suggested_object_id="mock_1",
disabled_by=er.RegistryEntryDisabler.INTEGRATION,
)
await coll.notify_changes(
[
collection.CollectionChangeSet(
collection.CHANGE_ADDED,
"mock_id",
{"id": "mock_id", "state": "initial", "name": "Mock 1"},
)
],
)
assert hass.states.get("test.mock_1") is None
await coll.notify_changes(
[
collection.CollectionChangeSet(
collection.CHANGE_UPDATED,
"mock_id",
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"},
)
],
)
assert hass.states.get("test.mock_1") is None
assert len(async_update_config_calls) == 0
await coll.notify_changes(
[collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)],
)
assert hass.states.get("test.mock_1") is None
assert len(async_remove_calls) == 0
async def test_entity_component_collection_entity_removed(hass):
"""Test entity removal is handled."""
ent_comp = entity_component.EntityComponent(_LOGGER, "test", hass)
coll = collection.ObservableCollection(_LOGGER)
async_update_config_calls = []
async_remove_calls = []
class MockMockEntity(MockEntity):
"""Track calls to async_update_config and async_remove."""
async def async_update_config(self, config):
nonlocal async_update_config_calls
async_update_config_calls.append(None)
await super().async_update_config()
async def async_remove(self, *, force_remove: bool = False):
nonlocal async_remove_calls
async_remove_calls.append(None)
await super().async_remove()
collection.sync_entity_lifecycle(
hass, "test", "test", ent_comp, coll, MockMockEntity
)
entity_registry = er.async_get(hass)
entity_registry.async_get_or_create(
"test", "test", "mock_id", suggested_object_id="mock_1"
)
await coll.notify_changes(
[
collection.CollectionChangeSet(
collection.CHANGE_ADDED,
"mock_id",
{"id": "mock_id", "state": "initial", "name": "Mock 1"},
)
],
)
assert hass.states.get("test.mock_1").name == "Mock 1"
assert hass.states.get("test.mock_1").state == "initial"
entity_registry.async_remove("test.mock_1")
await hass.async_block_till_done()
assert hass.states.get("test.mock_1") is None
assert len(async_remove_calls) == 1
await coll.notify_changes(
[
collection.CollectionChangeSet(
collection.CHANGE_UPDATED,
"mock_id",
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"},
)
],
)
assert hass.states.get("test.mock_1") is None
assert len(async_update_config_calls) == 0
await coll.notify_changes(
[collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)],
)
assert hass.states.get("test.mock_1") is None
assert len(async_remove_calls) == 1
async def test_storage_collection_websocket(hass, hass_ws_client):
"""Test exposing a storage collection via websockets."""
store = storage.Store(hass, 1, "test-data")