mirror of
https://github.com/home-assistant/core.git
synced 2025-07-10 06:47:09 +00:00
Handle removed entites in collection.sync_entity_lifecycle (#70759)
* Handle removed entites in collection.sync_entity_lifecycle * Add comment
This commit is contained in:
parent
8a13c6744a
commit
c5d69ab1b2
@ -334,7 +334,13 @@ def sync_entity_lifecycle(
|
|||||||
ent_reg = entity_registry.async_get(hass)
|
ent_reg = entity_registry.async_get(hass)
|
||||||
|
|
||||||
async def _add_entity(change_set: CollectionChangeSet) -> Entity:
|
async def _add_entity(change_set: CollectionChangeSet) -> Entity:
|
||||||
|
def entity_removed() -> None:
|
||||||
|
"""Remove entity from entities if it's removed or not added."""
|
||||||
|
if change_set.item_id in entities:
|
||||||
|
entities.pop(change_set.item_id)
|
||||||
|
|
||||||
entities[change_set.item_id] = create_entity(change_set.item)
|
entities[change_set.item_id] = create_entity(change_set.item)
|
||||||
|
entities[change_set.item_id].async_on_remove(entity_removed)
|
||||||
return entities[change_set.item_id]
|
return entities[change_set.item_id]
|
||||||
|
|
||||||
async def _remove_entity(change_set: CollectionChangeSet) -> None:
|
async def _remove_entity(change_set: CollectionChangeSet) -> None:
|
||||||
@ -343,11 +349,16 @@ def sync_entity_lifecycle(
|
|||||||
)
|
)
|
||||||
if ent_to_remove is not None:
|
if ent_to_remove is not None:
|
||||||
ent_reg.async_remove(ent_to_remove)
|
ent_reg.async_remove(ent_to_remove)
|
||||||
else:
|
elif change_set.item_id in entities:
|
||||||
await entities[change_set.item_id].async_remove(force_remove=True)
|
await entities[change_set.item_id].async_remove(force_remove=True)
|
||||||
entities.pop(change_set.item_id)
|
# Unconditionally pop the entity from the entity list to avoid racing against
|
||||||
|
# the entity registry event handled by Entity._async_registry_updated
|
||||||
|
if change_set.item_id in entities:
|
||||||
|
entities.pop(change_set.item_id)
|
||||||
|
|
||||||
async def _update_entity(change_set: CollectionChangeSet) -> None:
|
async def _update_entity(change_set: CollectionChangeSet) -> None:
|
||||||
|
if change_set.item_id not in entities:
|
||||||
|
return
|
||||||
await entities[change_set.item_id].async_update_config(change_set.item) # type: ignore[attr-defined]
|
await entities[change_set.item_id].async_update_config(change_set.item) # type: ignore[attr-defined]
|
||||||
|
|
||||||
_func_map: dict[
|
_func_map: dict[
|
||||||
|
@ -759,7 +759,7 @@ class Entity(ABC):
|
|||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_on_remove(self, func: CALLBACK_TYPE) -> None:
|
def async_on_remove(self, func: CALLBACK_TYPE) -> None:
|
||||||
"""Add a function to call when entity removed."""
|
"""Add a function to call when entity is removed or not added."""
|
||||||
if self._on_remove is None:
|
if self._on_remove is None:
|
||||||
self._on_remove = []
|
self._on_remove = []
|
||||||
self._on_remove.append(func)
|
self._on_remove.append(func)
|
||||||
@ -788,13 +788,23 @@ class Entity(ABC):
|
|||||||
self.parallel_updates = parallel_updates
|
self.parallel_updates = parallel_updates
|
||||||
self._platform_state = EntityPlatformState.ADDED
|
self._platform_state = EntityPlatformState.ADDED
|
||||||
|
|
||||||
|
def _call_on_remove_callbacks(self) -> None:
|
||||||
|
"""Call callbacks registered by async_on_remove."""
|
||||||
|
if self._on_remove is None:
|
||||||
|
return
|
||||||
|
while self._on_remove:
|
||||||
|
self._on_remove.pop()()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def add_to_platform_abort(self) -> None:
|
def add_to_platform_abort(self) -> None:
|
||||||
"""Abort adding an entity to a platform."""
|
"""Abort adding an entity to a platform."""
|
||||||
|
|
||||||
|
self._platform_state = EntityPlatformState.NOT_ADDED
|
||||||
|
self._call_on_remove_callbacks()
|
||||||
|
|
||||||
self.hass = None # type: ignore[assignment]
|
self.hass = None # type: ignore[assignment]
|
||||||
self.platform = None
|
self.platform = None
|
||||||
self.parallel_updates = None
|
self.parallel_updates = None
|
||||||
self._platform_state = EntityPlatformState.NOT_ADDED
|
|
||||||
|
|
||||||
async def add_to_platform_finish(self) -> None:
|
async def add_to_platform_finish(self) -> None:
|
||||||
"""Finish adding an entity to a platform."""
|
"""Finish adding an entity to a platform."""
|
||||||
@ -819,9 +829,7 @@ class Entity(ABC):
|
|||||||
|
|
||||||
self._platform_state = EntityPlatformState.REMOVED
|
self._platform_state = EntityPlatformState.REMOVED
|
||||||
|
|
||||||
if self._on_remove is not None:
|
self._call_on_remove_callbacks()
|
||||||
while self._on_remove:
|
|
||||||
self._on_remove.pop()()
|
|
||||||
|
|
||||||
await self.async_internal_will_remove_from_hass()
|
await self.async_internal_will_remove_from_hass()
|
||||||
await self.async_will_remove_from_hass()
|
await self.async_will_remove_from_hass()
|
||||||
|
@ -611,7 +611,7 @@ class EntityPlatform:
|
|||||||
self.hass.states.async_reserve(entity.entity_id)
|
self.hass.states.async_reserve(entity.entity_id)
|
||||||
|
|
||||||
def remove_entity_cb() -> None:
|
def remove_entity_cb() -> None:
|
||||||
"""Remove entity from entities list."""
|
"""Remove entity from entities dict."""
|
||||||
self.entities.pop(entity_id)
|
self.entities.pop(entity_id)
|
||||||
|
|
||||||
entity.async_on_remove(remove_entity_cb)
|
entity.async_on_remove(remove_entity_cb)
|
||||||
|
@ -4,7 +4,13 @@ import logging
|
|||||||
import pytest
|
import pytest
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.helpers import collection, entity, entity_component, storage
|
from homeassistant.helpers import (
|
||||||
|
collection,
|
||||||
|
entity,
|
||||||
|
entity_component,
|
||||||
|
entity_registry as er,
|
||||||
|
storage,
|
||||||
|
)
|
||||||
|
|
||||||
from tests.common import flush_store
|
from tests.common import flush_store
|
||||||
|
|
||||||
@ -261,6 +267,140 @@ async def test_attach_entity_component_collection(hass):
|
|||||||
assert hass.states.get("test.mock_1") is None
|
assert hass.states.get("test.mock_1") is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_entity_component_collection_abort(hass):
|
||||||
|
"""Test aborted entity adding is handled."""
|
||||||
|
ent_comp = entity_component.EntityComponent(_LOGGER, "test", hass)
|
||||||
|
coll = collection.ObservableCollection(_LOGGER)
|
||||||
|
|
||||||
|
async_update_config_calls = []
|
||||||
|
async_remove_calls = []
|
||||||
|
|
||||||
|
class MockMockEntity(MockEntity):
|
||||||
|
"""Track calls to async_update_config and async_remove."""
|
||||||
|
|
||||||
|
async def async_update_config(self, config):
|
||||||
|
nonlocal async_update_config_calls
|
||||||
|
async_update_config_calls.append(None)
|
||||||
|
await super().async_update_config()
|
||||||
|
|
||||||
|
async def async_remove(self, *, force_remove: bool = False):
|
||||||
|
nonlocal async_remove_calls
|
||||||
|
async_remove_calls.append(None)
|
||||||
|
await super().async_remove()
|
||||||
|
|
||||||
|
collection.sync_entity_lifecycle(
|
||||||
|
hass, "test", "test", ent_comp, coll, MockMockEntity
|
||||||
|
)
|
||||||
|
entity_registry = er.async_get(hass)
|
||||||
|
entity_registry.async_get_or_create(
|
||||||
|
"test",
|
||||||
|
"test",
|
||||||
|
"mock_id",
|
||||||
|
suggested_object_id="mock_1",
|
||||||
|
disabled_by=er.RegistryEntryDisabler.INTEGRATION,
|
||||||
|
)
|
||||||
|
|
||||||
|
await coll.notify_changes(
|
||||||
|
[
|
||||||
|
collection.CollectionChangeSet(
|
||||||
|
collection.CHANGE_ADDED,
|
||||||
|
"mock_id",
|
||||||
|
{"id": "mock_id", "state": "initial", "name": "Mock 1"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert hass.states.get("test.mock_1") is None
|
||||||
|
|
||||||
|
await coll.notify_changes(
|
||||||
|
[
|
||||||
|
collection.CollectionChangeSet(
|
||||||
|
collection.CHANGE_UPDATED,
|
||||||
|
"mock_id",
|
||||||
|
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert hass.states.get("test.mock_1") is None
|
||||||
|
assert len(async_update_config_calls) == 0
|
||||||
|
|
||||||
|
await coll.notify_changes(
|
||||||
|
[collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert hass.states.get("test.mock_1") is None
|
||||||
|
assert len(async_remove_calls) == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_entity_component_collection_entity_removed(hass):
|
||||||
|
"""Test entity removal is handled."""
|
||||||
|
ent_comp = entity_component.EntityComponent(_LOGGER, "test", hass)
|
||||||
|
coll = collection.ObservableCollection(_LOGGER)
|
||||||
|
|
||||||
|
async_update_config_calls = []
|
||||||
|
async_remove_calls = []
|
||||||
|
|
||||||
|
class MockMockEntity(MockEntity):
|
||||||
|
"""Track calls to async_update_config and async_remove."""
|
||||||
|
|
||||||
|
async def async_update_config(self, config):
|
||||||
|
nonlocal async_update_config_calls
|
||||||
|
async_update_config_calls.append(None)
|
||||||
|
await super().async_update_config()
|
||||||
|
|
||||||
|
async def async_remove(self, *, force_remove: bool = False):
|
||||||
|
nonlocal async_remove_calls
|
||||||
|
async_remove_calls.append(None)
|
||||||
|
await super().async_remove()
|
||||||
|
|
||||||
|
collection.sync_entity_lifecycle(
|
||||||
|
hass, "test", "test", ent_comp, coll, MockMockEntity
|
||||||
|
)
|
||||||
|
entity_registry = er.async_get(hass)
|
||||||
|
entity_registry.async_get_or_create(
|
||||||
|
"test", "test", "mock_id", suggested_object_id="mock_1"
|
||||||
|
)
|
||||||
|
|
||||||
|
await coll.notify_changes(
|
||||||
|
[
|
||||||
|
collection.CollectionChangeSet(
|
||||||
|
collection.CHANGE_ADDED,
|
||||||
|
"mock_id",
|
||||||
|
{"id": "mock_id", "state": "initial", "name": "Mock 1"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert hass.states.get("test.mock_1").name == "Mock 1"
|
||||||
|
assert hass.states.get("test.mock_1").state == "initial"
|
||||||
|
|
||||||
|
entity_registry.async_remove("test.mock_1")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert hass.states.get("test.mock_1") is None
|
||||||
|
assert len(async_remove_calls) == 1
|
||||||
|
|
||||||
|
await coll.notify_changes(
|
||||||
|
[
|
||||||
|
collection.CollectionChangeSet(
|
||||||
|
collection.CHANGE_UPDATED,
|
||||||
|
"mock_id",
|
||||||
|
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert hass.states.get("test.mock_1") is None
|
||||||
|
assert len(async_update_config_calls) == 0
|
||||||
|
|
||||||
|
await coll.notify_changes(
|
||||||
|
[collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert hass.states.get("test.mock_1") is None
|
||||||
|
assert len(async_remove_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_storage_collection_websocket(hass, hass_ws_client):
|
async def test_storage_collection_websocket(hass, hass_ws_client):
|
||||||
"""Test exposing a storage collection via websockets."""
|
"""Test exposing a storage collection via websockets."""
|
||||||
store = storage.Store(hass, 1, "test-data")
|
store = storage.Store(hass, 1, "test-data")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user