mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Handle removed entites in collection.sync_entity_lifecycle (#70759)
* Handle removed entites in collection.sync_entity_lifecycle * Add comment
This commit is contained in:
parent
8a13c6744a
commit
c5d69ab1b2
@ -334,7 +334,13 @@ def sync_entity_lifecycle(
|
||||
ent_reg = entity_registry.async_get(hass)
|
||||
|
||||
async def _add_entity(change_set: CollectionChangeSet) -> Entity:
|
||||
def entity_removed() -> None:
|
||||
"""Remove entity from entities if it's removed or not added."""
|
||||
if change_set.item_id in entities:
|
||||
entities.pop(change_set.item_id)
|
||||
|
||||
entities[change_set.item_id] = create_entity(change_set.item)
|
||||
entities[change_set.item_id].async_on_remove(entity_removed)
|
||||
return entities[change_set.item_id]
|
||||
|
||||
async def _remove_entity(change_set: CollectionChangeSet) -> None:
|
||||
@ -343,11 +349,16 @@ def sync_entity_lifecycle(
|
||||
)
|
||||
if ent_to_remove is not None:
|
||||
ent_reg.async_remove(ent_to_remove)
|
||||
else:
|
||||
elif change_set.item_id in entities:
|
||||
await entities[change_set.item_id].async_remove(force_remove=True)
|
||||
entities.pop(change_set.item_id)
|
||||
# Unconditionally pop the entity from the entity list to avoid racing against
|
||||
# the entity registry event handled by Entity._async_registry_updated
|
||||
if change_set.item_id in entities:
|
||||
entities.pop(change_set.item_id)
|
||||
|
||||
async def _update_entity(change_set: CollectionChangeSet) -> None:
|
||||
if change_set.item_id not in entities:
|
||||
return
|
||||
await entities[change_set.item_id].async_update_config(change_set.item) # type: ignore[attr-defined]
|
||||
|
||||
_func_map: dict[
|
||||
|
@ -759,7 +759,7 @@ class Entity(ABC):
|
||||
|
||||
@callback
|
||||
def async_on_remove(self, func: CALLBACK_TYPE) -> None:
|
||||
"""Add a function to call when entity removed."""
|
||||
"""Add a function to call when entity is removed or not added."""
|
||||
if self._on_remove is None:
|
||||
self._on_remove = []
|
||||
self._on_remove.append(func)
|
||||
@ -788,13 +788,23 @@ class Entity(ABC):
|
||||
self.parallel_updates = parallel_updates
|
||||
self._platform_state = EntityPlatformState.ADDED
|
||||
|
||||
def _call_on_remove_callbacks(self) -> None:
|
||||
"""Call callbacks registered by async_on_remove."""
|
||||
if self._on_remove is None:
|
||||
return
|
||||
while self._on_remove:
|
||||
self._on_remove.pop()()
|
||||
|
||||
@callback
|
||||
def add_to_platform_abort(self) -> None:
|
||||
"""Abort adding an entity to a platform."""
|
||||
|
||||
self._platform_state = EntityPlatformState.NOT_ADDED
|
||||
self._call_on_remove_callbacks()
|
||||
|
||||
self.hass = None # type: ignore[assignment]
|
||||
self.platform = None
|
||||
self.parallel_updates = None
|
||||
self._platform_state = EntityPlatformState.NOT_ADDED
|
||||
|
||||
async def add_to_platform_finish(self) -> None:
|
||||
"""Finish adding an entity to a platform."""
|
||||
@ -819,9 +829,7 @@ class Entity(ABC):
|
||||
|
||||
self._platform_state = EntityPlatformState.REMOVED
|
||||
|
||||
if self._on_remove is not None:
|
||||
while self._on_remove:
|
||||
self._on_remove.pop()()
|
||||
self._call_on_remove_callbacks()
|
||||
|
||||
await self.async_internal_will_remove_from_hass()
|
||||
await self.async_will_remove_from_hass()
|
||||
|
@ -611,7 +611,7 @@ class EntityPlatform:
|
||||
self.hass.states.async_reserve(entity.entity_id)
|
||||
|
||||
def remove_entity_cb() -> None:
|
||||
"""Remove entity from entities list."""
|
||||
"""Remove entity from entities dict."""
|
||||
self.entities.pop(entity_id)
|
||||
|
||||
entity.async_on_remove(remove_entity_cb)
|
||||
|
@ -4,7 +4,13 @@ import logging
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.helpers import collection, entity, entity_component, storage
|
||||
from homeassistant.helpers import (
|
||||
collection,
|
||||
entity,
|
||||
entity_component,
|
||||
entity_registry as er,
|
||||
storage,
|
||||
)
|
||||
|
||||
from tests.common import flush_store
|
||||
|
||||
@ -261,6 +267,140 @@ async def test_attach_entity_component_collection(hass):
|
||||
assert hass.states.get("test.mock_1") is None
|
||||
|
||||
|
||||
async def test_entity_component_collection_abort(hass):
|
||||
"""Test aborted entity adding is handled."""
|
||||
ent_comp = entity_component.EntityComponent(_LOGGER, "test", hass)
|
||||
coll = collection.ObservableCollection(_LOGGER)
|
||||
|
||||
async_update_config_calls = []
|
||||
async_remove_calls = []
|
||||
|
||||
class MockMockEntity(MockEntity):
|
||||
"""Track calls to async_update_config and async_remove."""
|
||||
|
||||
async def async_update_config(self, config):
|
||||
nonlocal async_update_config_calls
|
||||
async_update_config_calls.append(None)
|
||||
await super().async_update_config()
|
||||
|
||||
async def async_remove(self, *, force_remove: bool = False):
|
||||
nonlocal async_remove_calls
|
||||
async_remove_calls.append(None)
|
||||
await super().async_remove()
|
||||
|
||||
collection.sync_entity_lifecycle(
|
||||
hass, "test", "test", ent_comp, coll, MockMockEntity
|
||||
)
|
||||
entity_registry = er.async_get(hass)
|
||||
entity_registry.async_get_or_create(
|
||||
"test",
|
||||
"test",
|
||||
"mock_id",
|
||||
suggested_object_id="mock_1",
|
||||
disabled_by=er.RegistryEntryDisabler.INTEGRATION,
|
||||
)
|
||||
|
||||
await coll.notify_changes(
|
||||
[
|
||||
collection.CollectionChangeSet(
|
||||
collection.CHANGE_ADDED,
|
||||
"mock_id",
|
||||
{"id": "mock_id", "state": "initial", "name": "Mock 1"},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
assert hass.states.get("test.mock_1") is None
|
||||
|
||||
await coll.notify_changes(
|
||||
[
|
||||
collection.CollectionChangeSet(
|
||||
collection.CHANGE_UPDATED,
|
||||
"mock_id",
|
||||
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
assert hass.states.get("test.mock_1") is None
|
||||
assert len(async_update_config_calls) == 0
|
||||
|
||||
await coll.notify_changes(
|
||||
[collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)],
|
||||
)
|
||||
|
||||
assert hass.states.get("test.mock_1") is None
|
||||
assert len(async_remove_calls) == 0
|
||||
|
||||
|
||||
async def test_entity_component_collection_entity_removed(hass):
|
||||
"""Test entity removal is handled."""
|
||||
ent_comp = entity_component.EntityComponent(_LOGGER, "test", hass)
|
||||
coll = collection.ObservableCollection(_LOGGER)
|
||||
|
||||
async_update_config_calls = []
|
||||
async_remove_calls = []
|
||||
|
||||
class MockMockEntity(MockEntity):
|
||||
"""Track calls to async_update_config and async_remove."""
|
||||
|
||||
async def async_update_config(self, config):
|
||||
nonlocal async_update_config_calls
|
||||
async_update_config_calls.append(None)
|
||||
await super().async_update_config()
|
||||
|
||||
async def async_remove(self, *, force_remove: bool = False):
|
||||
nonlocal async_remove_calls
|
||||
async_remove_calls.append(None)
|
||||
await super().async_remove()
|
||||
|
||||
collection.sync_entity_lifecycle(
|
||||
hass, "test", "test", ent_comp, coll, MockMockEntity
|
||||
)
|
||||
entity_registry = er.async_get(hass)
|
||||
entity_registry.async_get_or_create(
|
||||
"test", "test", "mock_id", suggested_object_id="mock_1"
|
||||
)
|
||||
|
||||
await coll.notify_changes(
|
||||
[
|
||||
collection.CollectionChangeSet(
|
||||
collection.CHANGE_ADDED,
|
||||
"mock_id",
|
||||
{"id": "mock_id", "state": "initial", "name": "Mock 1"},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
assert hass.states.get("test.mock_1").name == "Mock 1"
|
||||
assert hass.states.get("test.mock_1").state == "initial"
|
||||
|
||||
entity_registry.async_remove("test.mock_1")
|
||||
await hass.async_block_till_done()
|
||||
assert hass.states.get("test.mock_1") is None
|
||||
assert len(async_remove_calls) == 1
|
||||
|
||||
await coll.notify_changes(
|
||||
[
|
||||
collection.CollectionChangeSet(
|
||||
collection.CHANGE_UPDATED,
|
||||
"mock_id",
|
||||
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
assert hass.states.get("test.mock_1") is None
|
||||
assert len(async_update_config_calls) == 0
|
||||
|
||||
await coll.notify_changes(
|
||||
[collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)],
|
||||
)
|
||||
|
||||
assert hass.states.get("test.mock_1") is None
|
||||
assert len(async_remove_calls) == 1
|
||||
|
||||
|
||||
async def test_storage_collection_websocket(hass, hass_ws_client):
|
||||
"""Test exposing a storage collection via websockets."""
|
||||
store = storage.Store(hass, 1, "test-data")
|
||||
|
Loading…
x
Reference in New Issue
Block a user