Move thread safety check in entity_registry sooner (#116263)

* Move thread safety check in entity_registry sooner

It turns out we have a lot of custom components that are writing
to the entity registry using the async APIs from threads. We now
catch it at the point async_fire is called. Instread we should check
sooner and use async_fire_internal so we catch the unsafe operation
before it can corrupt the registry.

* coverage

* Apply suggestions from code review
This commit is contained in:
J. Nick Koston 2024-04-27 02:25:19 -05:00 committed by GitHub
parent 4a79e750a1
commit 244433aeca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 51 additions and 3 deletions

View File

@ -819,6 +819,7 @@ class EntityRegistry(BaseRegistry):
unit_of_measurement=unit_of_measurement, unit_of_measurement=unit_of_measurement,
) )
self.hass.verify_event_loop_thread("async_get_or_create")
_validate_item( _validate_item(
self.hass, self.hass,
domain, domain,
@ -879,7 +880,7 @@ class EntityRegistry(BaseRegistry):
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id) _LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)
self.async_schedule_save() self.async_schedule_save()
self.hass.bus.async_fire( self.hass.bus.async_fire_internal(
EVENT_ENTITY_REGISTRY_UPDATED, EVENT_ENTITY_REGISTRY_UPDATED,
_EventEntityRegistryUpdatedData_CreateRemove( _EventEntityRegistryUpdatedData_CreateRemove(
action="create", entity_id=entity_id action="create", entity_id=entity_id
@ -891,6 +892,7 @@ class EntityRegistry(BaseRegistry):
@callback @callback
def async_remove(self, entity_id: str) -> None: def async_remove(self, entity_id: str) -> None:
"""Remove an entity from registry.""" """Remove an entity from registry."""
self.hass.verify_event_loop_thread("async_remove")
entity = self.entities.pop(entity_id) entity = self.entities.pop(entity_id)
config_entry_id = entity.config_entry_id config_entry_id = entity.config_entry_id
key = (entity.domain, entity.platform, entity.unique_id) key = (entity.domain, entity.platform, entity.unique_id)
@ -904,7 +906,7 @@ class EntityRegistry(BaseRegistry):
platform=entity.platform, platform=entity.platform,
unique_id=entity.unique_id, unique_id=entity.unique_id,
) )
self.hass.bus.async_fire( self.hass.bus.async_fire_internal(
EVENT_ENTITY_REGISTRY_UPDATED, EVENT_ENTITY_REGISTRY_UPDATED,
_EventEntityRegistryUpdatedData_CreateRemove( _EventEntityRegistryUpdatedData_CreateRemove(
action="remove", entity_id=entity_id action="remove", entity_id=entity_id
@ -1085,6 +1087,8 @@ class EntityRegistry(BaseRegistry):
if not new_values: if not new_values:
return old return old
self.hass.verify_event_loop_thread("_async_update_entity")
new = self.entities[entity_id] = attr.evolve(old, **new_values) new = self.entities[entity_id] = attr.evolve(old, **new_values)
self.async_schedule_save() self.async_schedule_save()
@ -1098,7 +1102,7 @@ class EntityRegistry(BaseRegistry):
if old.entity_id != entity_id: if old.entity_id != entity_id:
data["old_entity_id"] = old.entity_id data["old_entity_id"] = old.entity_id
self.hass.bus.async_fire(EVENT_ENTITY_REGISTRY_UPDATED, data) self.hass.bus.async_fire_internal(EVENT_ENTITY_REGISTRY_UPDATED, data)
return new return new

View File

@ -1,6 +1,7 @@
"""Tests for the Entity Registry.""" """Tests for the Entity Registry."""
from datetime import timedelta from datetime import timedelta
from functools import partial
from typing import Any from typing import Any
from unittest.mock import patch from unittest.mock import patch
@ -1988,3 +1989,46 @@ async def test_entries_for_category(entity_registry: er.EntityRegistry) -> None:
assert not er.async_entries_for_category(entity_registry, "", "id") assert not er.async_entries_for_category(entity_registry, "", "id")
assert not er.async_entries_for_category(entity_registry, "scope1", "unknown") assert not er.async_entries_for_category(entity_registry, "scope1", "unknown")
assert not er.async_entries_for_category(entity_registry, "scope1", "") assert not er.async_entries_for_category(entity_registry, "scope1", "")
async def test_get_or_create_thread_safety(
hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
"""Test call async_get_or_create_from a thread."""
with pytest.raises(
RuntimeError,
match="Detected code that calls async_get_or_create from a thread. Please report this issue.",
):
await hass.async_add_executor_job(
entity_registry.async_get_or_create, "light", "hue", "1234"
)
async def test_async_update_entity_thread_safety(
hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
"""Test call async_get_or_create from a thread."""
entry = entity_registry.async_get_or_create("light", "hue", "1234")
with pytest.raises(
RuntimeError,
match="Detected code that calls _async_update_entity from a thread. Please report this issue.",
):
await hass.async_add_executor_job(
partial(
entity_registry.async_update_entity,
entry.entity_id,
new_unique_id="5678",
)
)
async def test_async_remove_thread_safety(
hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
"""Test call async_remove from a thread."""
entry = entity_registry.async_get_or_create("light", "hue", "1234")
with pytest.raises(
RuntimeError,
match="Detected code that calls async_remove from a thread. Please report this issue.",
):
await hass.async_add_executor_job(entity_registry.async_remove, entry.entity_id)