mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Fix race in removing entities from the registry (#110978)
This commit is contained in:
parent
94e372a345
commit
9c145b5faa
@ -496,6 +496,9 @@ class Entity(
|
|||||||
# Entry in the entity registry
|
# Entry in the entity registry
|
||||||
registry_entry: er.RegistryEntry | None = None
|
registry_entry: er.RegistryEntry | None = None
|
||||||
|
|
||||||
|
# If the entity is removed from the entity registry
|
||||||
|
_removed_from_registry: bool = False
|
||||||
|
|
||||||
# The device entry for this entity
|
# The device entry for this entity
|
||||||
device_entry: dr.DeviceEntry | None = None
|
device_entry: dr.DeviceEntry | None = None
|
||||||
|
|
||||||
@ -1361,6 +1364,17 @@ class Entity(
|
|||||||
not force_remove
|
not force_remove
|
||||||
and self.registry_entry
|
and self.registry_entry
|
||||||
and not self.registry_entry.disabled
|
and not self.registry_entry.disabled
|
||||||
|
# Check if entity is still in the entity registry
|
||||||
|
# by checking self._removed_from_registry
|
||||||
|
#
|
||||||
|
# Because self.registry_entry is unset in a task,
|
||||||
|
# its possible that the entity has been removed but
|
||||||
|
# the task has not yet been executed.
|
||||||
|
#
|
||||||
|
# self._removed_from_registry is set to True in a
|
||||||
|
# callback which does not have the same issue.
|
||||||
|
#
|
||||||
|
and not self._removed_from_registry
|
||||||
):
|
):
|
||||||
# Set the entity's state will to unavailable + ATTR_RESTORED: True
|
# Set the entity's state will to unavailable + ATTR_RESTORED: True
|
||||||
self.registry_entry.write_unavailable_state(self.hass)
|
self.registry_entry.write_unavailable_state(self.hass)
|
||||||
@ -1430,10 +1444,23 @@ class Entity(
|
|||||||
if self.platform:
|
if self.platform:
|
||||||
self.hass.data[DATA_ENTITY_SOURCE].pop(self.entity_id)
|
self.hass.data[DATA_ENTITY_SOURCE].pop(self.entity_id)
|
||||||
|
|
||||||
async def _async_registry_updated(
|
@callback
|
||||||
|
def _async_registry_updated(
|
||||||
self, event: EventType[er.EventEntityRegistryUpdatedData]
|
self, event: EventType[er.EventEntityRegistryUpdatedData]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle entity registry update."""
|
"""Handle entity registry update."""
|
||||||
|
action = event.data["action"]
|
||||||
|
is_remove = action == "remove"
|
||||||
|
self._removed_from_registry = is_remove
|
||||||
|
if action == "update" or is_remove:
|
||||||
|
self.hass.async_create_task(
|
||||||
|
self._async_process_registry_update_or_remove(event)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _async_process_registry_update_or_remove(
|
||||||
|
self, event: EventType[er.EventEntityRegistryUpdatedData]
|
||||||
|
) -> None:
|
||||||
|
"""Handle entity registry update or remove."""
|
||||||
data = event.data
|
data = event.data
|
||||||
if data["action"] == "remove":
|
if data["action"] == "remove":
|
||||||
await self.async_removed_from_registry()
|
await self.async_removed_from_registry()
|
||||||
|
@ -328,6 +328,7 @@ def _async_track_state_change_event(
|
|||||||
_async_dispatch_entity_id_event,
|
_async_dispatch_entity_id_event,
|
||||||
_async_state_change_filter,
|
_async_state_change_filter,
|
||||||
action,
|
action,
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -378,8 +379,16 @@ def _async_track_event(
|
|||||||
bool,
|
bool,
|
||||||
],
|
],
|
||||||
action: Callable[[EventType[_TypedDictT]], None],
|
action: Callable[[EventType[_TypedDictT]], None],
|
||||||
|
run_immediately: bool,
|
||||||
) -> CALLBACK_TYPE:
|
) -> CALLBACK_TYPE:
|
||||||
"""Track an event by a specific key."""
|
"""Track an event by a specific key.
|
||||||
|
|
||||||
|
This function is intended for internal use only.
|
||||||
|
|
||||||
|
The dispatcher_callable, filter_callable, event_type, and run_immediately
|
||||||
|
must always be the same for the listener_key as the first call to this
|
||||||
|
function will set the listener_key in hass.data.
|
||||||
|
"""
|
||||||
if not keys:
|
if not keys:
|
||||||
return _remove_empty_listener
|
return _remove_empty_listener
|
||||||
|
|
||||||
@ -388,10 +397,8 @@ def _async_track_event(
|
|||||||
|
|
||||||
hass_data = hass.data
|
hass_data = hass.data
|
||||||
|
|
||||||
callbacks: dict[
|
callbacks: dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]] | None
|
||||||
str, list[HassJob[[EventType[_TypedDictT]], Any]]
|
if not (callbacks := hass_data.get(callbacks_key)):
|
||||||
] | None = hass_data.get(callbacks_key)
|
|
||||||
if not callbacks:
|
|
||||||
callbacks = hass_data[callbacks_key] = {}
|
callbacks = hass_data[callbacks_key] = {}
|
||||||
|
|
||||||
if listeners_key not in hass_data:
|
if listeners_key not in hass_data:
|
||||||
@ -399,13 +406,13 @@ def _async_track_event(
|
|||||||
event_type,
|
event_type,
|
||||||
ft.partial(dispatcher_callable, hass, callbacks),
|
ft.partial(dispatcher_callable, hass, callbacks),
|
||||||
event_filter=ft.partial(filter_callable, hass, callbacks),
|
event_filter=ft.partial(filter_callable, hass, callbacks),
|
||||||
|
run_immediately=run_immediately,
|
||||||
)
|
)
|
||||||
|
|
||||||
job = HassJob(action, f"track {event_type} event {keys}")
|
job = HassJob(action, f"track {event_type} event {keys}")
|
||||||
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
callback_list = callbacks.get(key)
|
if callback_list := callbacks.get(key):
|
||||||
if callback_list:
|
|
||||||
callback_list.append(job)
|
callback_list.append(job)
|
||||||
else:
|
else:
|
||||||
callbacks[key] = [job]
|
callbacks[key] = [job]
|
||||||
@ -473,6 +480,7 @@ def async_track_entity_registry_updated_event(
|
|||||||
_async_dispatch_old_entity_id_or_entity_id_event,
|
_async_dispatch_old_entity_id_or_entity_id_event,
|
||||||
_async_entity_registry_updated_filter,
|
_async_entity_registry_updated_filter,
|
||||||
action,
|
action,
|
||||||
|
True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -529,6 +537,7 @@ def async_track_device_registry_updated_event(
|
|||||||
_async_dispatch_device_id_event,
|
_async_dispatch_device_id_event,
|
||||||
_async_device_registry_updated_filter,
|
_async_device_registry_updated_filter,
|
||||||
action,
|
action,
|
||||||
|
True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -590,6 +599,7 @@ def _async_track_state_added_domain(
|
|||||||
_async_dispatch_domain_event,
|
_async_dispatch_domain_event,
|
||||||
_async_domain_added_filter,
|
_async_domain_added_filter,
|
||||||
action,
|
action,
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -622,6 +632,7 @@ def async_track_state_removed_domain(
|
|||||||
_async_dispatch_domain_event,
|
_async_dispatch_domain_event,
|
||||||
_async_domain_removed_filter,
|
_async_domain_removed_filter,
|
||||||
action,
|
action,
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2468,3 +2468,91 @@ async def test_entity_report_deprecated_supported_features_values(
|
|||||||
"is using deprecated supported features values which will be removed"
|
"is using deprecated supported features values which will be removed"
|
||||||
not in caplog.text
|
not in caplog.text
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_remove_entity_registry(
|
||||||
|
hass: HomeAssistant, entity_registry: er.EntityRegistry
|
||||||
|
) -> None:
|
||||||
|
"""Test removing an entity from the registry."""
|
||||||
|
result = []
|
||||||
|
|
||||||
|
entry = entity_registry.async_get_or_create(
|
||||||
|
"test", "test_platform", "5678", suggested_object_id="test"
|
||||||
|
)
|
||||||
|
assert entry.entity_id == "test.test"
|
||||||
|
|
||||||
|
class MockEntity(entity.Entity):
|
||||||
|
_attr_unique_id = "5678"
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.added_calls = []
|
||||||
|
self.remove_calls = []
|
||||||
|
|
||||||
|
async def async_added_to_hass(self):
|
||||||
|
self.added_calls.append(None)
|
||||||
|
self.async_on_remove(lambda: result.append(1))
|
||||||
|
|
||||||
|
async def async_will_remove_from_hass(self):
|
||||||
|
self.remove_calls.append(None)
|
||||||
|
|
||||||
|
platform = MockEntityPlatform(hass, domain="test")
|
||||||
|
ent = MockEntity()
|
||||||
|
await platform.async_add_entities([ent])
|
||||||
|
assert hass.states.get("test.test").state == STATE_UNKNOWN
|
||||||
|
assert len(ent.added_calls) == 1
|
||||||
|
|
||||||
|
entry = entity_registry.async_remove(entry.entity_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert len(ent.added_calls) == 1
|
||||||
|
assert len(ent.remove_calls) == 1
|
||||||
|
|
||||||
|
assert hass.states.get("test.test") is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reset_right_after_remove_entity_registry(
|
||||||
|
hass: HomeAssistant, entity_registry: er.EntityRegistry
|
||||||
|
) -> None:
|
||||||
|
"""Test resetting the platform right after removing an entity from the registry.
|
||||||
|
|
||||||
|
A reset commonly happens during a reload.
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
|
||||||
|
entry = entity_registry.async_get_or_create(
|
||||||
|
"test", "test_platform", "5678", suggested_object_id="test"
|
||||||
|
)
|
||||||
|
assert entry.entity_id == "test.test"
|
||||||
|
|
||||||
|
class MockEntity(entity.Entity):
|
||||||
|
_attr_unique_id = "5678"
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.added_calls = []
|
||||||
|
self.remove_calls = []
|
||||||
|
|
||||||
|
async def async_added_to_hass(self):
|
||||||
|
self.added_calls.append(None)
|
||||||
|
self.async_on_remove(lambda: result.append(1))
|
||||||
|
|
||||||
|
async def async_will_remove_from_hass(self):
|
||||||
|
self.remove_calls.append(None)
|
||||||
|
|
||||||
|
platform = MockEntityPlatform(hass, domain="test")
|
||||||
|
ent = MockEntity()
|
||||||
|
await platform.async_add_entities([ent])
|
||||||
|
assert hass.states.get("test.test").state == STATE_UNKNOWN
|
||||||
|
assert len(ent.added_calls) == 1
|
||||||
|
|
||||||
|
entry = entity_registry.async_remove(entry.entity_id)
|
||||||
|
|
||||||
|
# Reset the platform immediately after removing the entity from the registry
|
||||||
|
await platform.async_reset()
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert len(ent.added_calls) == 1
|
||||||
|
assert len(ent.remove_calls) == 1
|
||||||
|
|
||||||
|
assert hass.states.get("test.test") is None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user