From 9c145b5faafcf374af0d6e325affa597bc3968ff Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 20 Feb 2024 20:48:31 -0600 Subject: [PATCH] Fix race in removing entities from the registry (#110978) --- homeassistant/helpers/entity.py | 29 ++++++++++- homeassistant/helpers/event.py | 25 +++++++--- tests/helpers/test_entity.py | 88 +++++++++++++++++++++++++++++++++ 3 files changed, 134 insertions(+), 8 deletions(-) diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 32aa97ab8fe..f7497e77a94 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -496,6 +496,9 @@ class Entity( # Entry in the entity registry registry_entry: er.RegistryEntry | None = None + # If the entity is removed from the entity registry + _removed_from_registry: bool = False + # The device entry for this entity device_entry: dr.DeviceEntry | None = None @@ -1361,6 +1364,17 @@ class Entity( not force_remove and self.registry_entry and not self.registry_entry.disabled + # Check if entity is still in the entity registry + # by checking self._removed_from_registry + # + # Because self.registry_entry is unset in a task, + # its possible that the entity has been removed but + # the task has not yet been executed. + # + # self._removed_from_registry is set to True in a + # callback which does not have the same issue. + # + and not self._removed_from_registry ): # Set the entity's state will to unavailable + ATTR_RESTORED: True self.registry_entry.write_unavailable_state(self.hass) @@ -1430,10 +1444,23 @@ class Entity( if self.platform: self.hass.data[DATA_ENTITY_SOURCE].pop(self.entity_id) - async def _async_registry_updated( + @callback + def _async_registry_updated( self, event: EventType[er.EventEntityRegistryUpdatedData] ) -> None: """Handle entity registry update.""" + action = event.data["action"] + is_remove = action == "remove" + self._removed_from_registry = is_remove + if action == "update" or is_remove: + self.hass.async_create_task( + self._async_process_registry_update_or_remove(event) + ) + + async def _async_process_registry_update_or_remove( + self, event: EventType[er.EventEntityRegistryUpdatedData] + ) -> None: + """Handle entity registry update or remove.""" data = event.data if data["action"] == "remove": await self.async_removed_from_registry() diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index f9c2e47dc96..35493dbcacb 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -328,6 +328,7 @@ def _async_track_state_change_event( _async_dispatch_entity_id_event, _async_state_change_filter, action, + False, ) @@ -378,8 +379,16 @@ def _async_track_event( bool, ], action: Callable[[EventType[_TypedDictT]], None], + run_immediately: bool, ) -> CALLBACK_TYPE: - """Track an event by a specific key.""" + """Track an event by a specific key. + + This function is intended for internal use only. + + The dispatcher_callable, filter_callable, event_type, and run_immediately + must always be the same for the listener_key as the first call to this + function will set the listener_key in hass.data. + """ if not keys: return _remove_empty_listener @@ -388,10 +397,8 @@ def _async_track_event( hass_data = hass.data - callbacks: dict[ - str, list[HassJob[[EventType[_TypedDictT]], Any]] - ] | None = hass_data.get(callbacks_key) - if not callbacks: + callbacks: dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]] | None + if not (callbacks := hass_data.get(callbacks_key)): callbacks = hass_data[callbacks_key] = {} if listeners_key not in hass_data: @@ -399,13 +406,13 @@ def _async_track_event( event_type, ft.partial(dispatcher_callable, hass, callbacks), event_filter=ft.partial(filter_callable, hass, callbacks), + run_immediately=run_immediately, ) job = HassJob(action, f"track {event_type} event {keys}") for key in keys: - callback_list = callbacks.get(key) - if callback_list: + if callback_list := callbacks.get(key): callback_list.append(job) else: callbacks[key] = [job] @@ -473,6 +480,7 @@ def async_track_entity_registry_updated_event( _async_dispatch_old_entity_id_or_entity_id_event, _async_entity_registry_updated_filter, action, + True, ) @@ -529,6 +537,7 @@ def async_track_device_registry_updated_event( _async_dispatch_device_id_event, _async_device_registry_updated_filter, action, + True, ) @@ -590,6 +599,7 @@ def _async_track_state_added_domain( _async_dispatch_domain_event, _async_domain_added_filter, action, + False, ) @@ -622,6 +632,7 @@ def async_track_state_removed_domain( _async_dispatch_domain_event, _async_domain_removed_filter, action, + False, ) diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index 19600506ae2..4de38cc814d 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -2468,3 +2468,91 @@ async def test_entity_report_deprecated_supported_features_values( "is using deprecated supported features values which will be removed" not in caplog.text ) + + +async def test_remove_entity_registry( + hass: HomeAssistant, entity_registry: er.EntityRegistry +) -> None: + """Test removing an entity from the registry.""" + result = [] + + entry = entity_registry.async_get_or_create( + "test", "test_platform", "5678", suggested_object_id="test" + ) + assert entry.entity_id == "test.test" + + class MockEntity(entity.Entity): + _attr_unique_id = "5678" + + def __init__(self) -> None: + self.added_calls = [] + self.remove_calls = [] + + async def async_added_to_hass(self): + self.added_calls.append(None) + self.async_on_remove(lambda: result.append(1)) + + async def async_will_remove_from_hass(self): + self.remove_calls.append(None) + + platform = MockEntityPlatform(hass, domain="test") + ent = MockEntity() + await platform.async_add_entities([ent]) + assert hass.states.get("test.test").state == STATE_UNKNOWN + assert len(ent.added_calls) == 1 + + entry = entity_registry.async_remove(entry.entity_id) + await hass.async_block_till_done() + + assert len(result) == 1 + assert len(ent.added_calls) == 1 + assert len(ent.remove_calls) == 1 + + assert hass.states.get("test.test") is None + + +async def test_reset_right_after_remove_entity_registry( + hass: HomeAssistant, entity_registry: er.EntityRegistry +) -> None: + """Test resetting the platform right after removing an entity from the registry. + + A reset commonly happens during a reload. + """ + result = [] + + entry = entity_registry.async_get_or_create( + "test", "test_platform", "5678", suggested_object_id="test" + ) + assert entry.entity_id == "test.test" + + class MockEntity(entity.Entity): + _attr_unique_id = "5678" + + def __init__(self) -> None: + self.added_calls = [] + self.remove_calls = [] + + async def async_added_to_hass(self): + self.added_calls.append(None) + self.async_on_remove(lambda: result.append(1)) + + async def async_will_remove_from_hass(self): + self.remove_calls.append(None) + + platform = MockEntityPlatform(hass, domain="test") + ent = MockEntity() + await platform.async_add_entities([ent]) + assert hass.states.get("test.test").state == STATE_UNKNOWN + assert len(ent.added_calls) == 1 + + entry = entity_registry.async_remove(entry.entity_id) + + # Reset the platform immediately after removing the entity from the registry + await platform.async_reset() + await hass.async_block_till_done() + + assert len(result) == 1 + assert len(ent.added_calls) == 1 + assert len(ent.remove_calls) == 1 + + assert hass.states.get("test.test") is None