Move thread safety check in entity_registry sooner (#116263)

* Move thread safety check in entity_registry sooner

It turns out we have a lot of custom components that are writing
to the entity registry using the async APIs from threads. We now
catch it at the point async_fire is called. Instread we should check
sooner and use async_fire_internal so we catch the unsafe operation
before it can corrupt the registry.

* coverage

* Apply suggestions from code review
This commit is contained in:
J. Nick Koston 2024-04-27 02:25:19 -05:00 committed by Paulus Schoutsen
parent 603f46184c
commit 8d11a9f21a
2 changed files with 51 additions and 3 deletions

View File

@ -819,6 +819,7 @@ class EntityRegistry(BaseRegistry):
unit_of_measurement=unit_of_measurement,
)
self.hass.verify_event_loop_thread("async_get_or_create")
_validate_item(
self.hass,
domain,
@ -879,7 +880,7 @@ class EntityRegistry(BaseRegistry):
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)
self.async_schedule_save()
self.hass.bus.async_fire(
self.hass.bus.async_fire_internal(
EVENT_ENTITY_REGISTRY_UPDATED,
_EventEntityRegistryUpdatedData_CreateRemove(
action="create", entity_id=entity_id
@ -891,6 +892,7 @@ class EntityRegistry(BaseRegistry):
@callback
def async_remove(self, entity_id: str) -> None:
"""Remove an entity from registry."""
self.hass.verify_event_loop_thread("async_remove")
entity = self.entities.pop(entity_id)
config_entry_id = entity.config_entry_id
key = (entity.domain, entity.platform, entity.unique_id)
@ -904,7 +906,7 @@ class EntityRegistry(BaseRegistry):
platform=entity.platform,
unique_id=entity.unique_id,
)
self.hass.bus.async_fire(
self.hass.bus.async_fire_internal(
EVENT_ENTITY_REGISTRY_UPDATED,
_EventEntityRegistryUpdatedData_CreateRemove(
action="remove", entity_id=entity_id
@ -1085,6 +1087,8 @@ class EntityRegistry(BaseRegistry):
if not new_values:
return old
self.hass.verify_event_loop_thread("_async_update_entity")
new = self.entities[entity_id] = attr.evolve(old, **new_values)
self.async_schedule_save()
@ -1098,7 +1102,7 @@ class EntityRegistry(BaseRegistry):
if old.entity_id != entity_id:
data["old_entity_id"] = old.entity_id
self.hass.bus.async_fire(EVENT_ENTITY_REGISTRY_UPDATED, data)
self.hass.bus.async_fire_internal(EVENT_ENTITY_REGISTRY_UPDATED, data)
return new

View File

@ -1,6 +1,7 @@
"""Tests for the Entity Registry."""
from datetime import timedelta
from functools import partial
from typing import Any
from unittest.mock import patch
@ -1988,3 +1989,46 @@ async def test_entries_for_category(entity_registry: er.EntityRegistry) -> None:
assert not er.async_entries_for_category(entity_registry, "", "id")
assert not er.async_entries_for_category(entity_registry, "scope1", "unknown")
assert not er.async_entries_for_category(entity_registry, "scope1", "")
async def test_get_or_create_thread_safety(
hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
"""Test call async_get_or_create_from a thread."""
with pytest.raises(
RuntimeError,
match="Detected code that calls async_get_or_create from a thread. Please report this issue.",
):
await hass.async_add_executor_job(
entity_registry.async_get_or_create, "light", "hue", "1234"
)
async def test_async_update_entity_thread_safety(
hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
"""Test call async_get_or_create from a thread."""
entry = entity_registry.async_get_or_create("light", "hue", "1234")
with pytest.raises(
RuntimeError,
match="Detected code that calls _async_update_entity from a thread. Please report this issue.",
):
await hass.async_add_executor_job(
partial(
entity_registry.async_update_entity,
entry.entity_id,
new_unique_id="5678",
)
)
async def test_async_remove_thread_safety(
hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
"""Test call async_remove from a thread."""
entry = entity_registry.async_get_or_create("light", "hue", "1234")
with pytest.raises(
RuntimeError,
match="Detected code that calls async_remove from a thread. Please report this issue.",
):
await hass.async_add_executor_job(entity_registry.async_remove, entry.entity_id)