mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Move thread safety check in entity_registry sooner (#116263)
* Move thread safety check in entity_registry sooner It turns out we have a lot of custom components that are writing to the entity registry using the async APIs from threads. We now catch it at the point async_fire is called. Instread we should check sooner and use async_fire_internal so we catch the unsafe operation before it can corrupt the registry. * coverage * Apply suggestions from code review
This commit is contained in:
parent
603f46184c
commit
8d11a9f21a
@ -819,6 +819,7 @@ class EntityRegistry(BaseRegistry):
|
|||||||
unit_of_measurement=unit_of_measurement,
|
unit_of_measurement=unit_of_measurement,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.hass.verify_event_loop_thread("async_get_or_create")
|
||||||
_validate_item(
|
_validate_item(
|
||||||
self.hass,
|
self.hass,
|
||||||
domain,
|
domain,
|
||||||
@ -879,7 +880,7 @@ class EntityRegistry(BaseRegistry):
|
|||||||
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)
|
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)
|
||||||
self.async_schedule_save()
|
self.async_schedule_save()
|
||||||
|
|
||||||
self.hass.bus.async_fire(
|
self.hass.bus.async_fire_internal(
|
||||||
EVENT_ENTITY_REGISTRY_UPDATED,
|
EVENT_ENTITY_REGISTRY_UPDATED,
|
||||||
_EventEntityRegistryUpdatedData_CreateRemove(
|
_EventEntityRegistryUpdatedData_CreateRemove(
|
||||||
action="create", entity_id=entity_id
|
action="create", entity_id=entity_id
|
||||||
@ -891,6 +892,7 @@ class EntityRegistry(BaseRegistry):
|
|||||||
@callback
|
@callback
|
||||||
def async_remove(self, entity_id: str) -> None:
|
def async_remove(self, entity_id: str) -> None:
|
||||||
"""Remove an entity from registry."""
|
"""Remove an entity from registry."""
|
||||||
|
self.hass.verify_event_loop_thread("async_remove")
|
||||||
entity = self.entities.pop(entity_id)
|
entity = self.entities.pop(entity_id)
|
||||||
config_entry_id = entity.config_entry_id
|
config_entry_id = entity.config_entry_id
|
||||||
key = (entity.domain, entity.platform, entity.unique_id)
|
key = (entity.domain, entity.platform, entity.unique_id)
|
||||||
@ -904,7 +906,7 @@ class EntityRegistry(BaseRegistry):
|
|||||||
platform=entity.platform,
|
platform=entity.platform,
|
||||||
unique_id=entity.unique_id,
|
unique_id=entity.unique_id,
|
||||||
)
|
)
|
||||||
self.hass.bus.async_fire(
|
self.hass.bus.async_fire_internal(
|
||||||
EVENT_ENTITY_REGISTRY_UPDATED,
|
EVENT_ENTITY_REGISTRY_UPDATED,
|
||||||
_EventEntityRegistryUpdatedData_CreateRemove(
|
_EventEntityRegistryUpdatedData_CreateRemove(
|
||||||
action="remove", entity_id=entity_id
|
action="remove", entity_id=entity_id
|
||||||
@ -1085,6 +1087,8 @@ class EntityRegistry(BaseRegistry):
|
|||||||
if not new_values:
|
if not new_values:
|
||||||
return old
|
return old
|
||||||
|
|
||||||
|
self.hass.verify_event_loop_thread("_async_update_entity")
|
||||||
|
|
||||||
new = self.entities[entity_id] = attr.evolve(old, **new_values)
|
new = self.entities[entity_id] = attr.evolve(old, **new_values)
|
||||||
|
|
||||||
self.async_schedule_save()
|
self.async_schedule_save()
|
||||||
@ -1098,7 +1102,7 @@ class EntityRegistry(BaseRegistry):
|
|||||||
if old.entity_id != entity_id:
|
if old.entity_id != entity_id:
|
||||||
data["old_entity_id"] = old.entity_id
|
data["old_entity_id"] = old.entity_id
|
||||||
|
|
||||||
self.hass.bus.async_fire(EVENT_ENTITY_REGISTRY_UPDATED, data)
|
self.hass.bus.async_fire_internal(EVENT_ENTITY_REGISTRY_UPDATED, data)
|
||||||
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Tests for the Entity Registry."""
|
"""Tests for the Entity Registry."""
|
||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from functools import partial
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
@ -1988,3 +1989,46 @@ async def test_entries_for_category(entity_registry: er.EntityRegistry) -> None:
|
|||||||
assert not er.async_entries_for_category(entity_registry, "", "id")
|
assert not er.async_entries_for_category(entity_registry, "", "id")
|
||||||
assert not er.async_entries_for_category(entity_registry, "scope1", "unknown")
|
assert not er.async_entries_for_category(entity_registry, "scope1", "unknown")
|
||||||
assert not er.async_entries_for_category(entity_registry, "scope1", "")
|
assert not er.async_entries_for_category(entity_registry, "scope1", "")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_or_create_thread_safety(
|
||||||
|
hass: HomeAssistant, entity_registry: er.EntityRegistry
|
||||||
|
) -> None:
|
||||||
|
"""Test call async_get_or_create_from a thread."""
|
||||||
|
with pytest.raises(
|
||||||
|
RuntimeError,
|
||||||
|
match="Detected code that calls async_get_or_create from a thread. Please report this issue.",
|
||||||
|
):
|
||||||
|
await hass.async_add_executor_job(
|
||||||
|
entity_registry.async_get_or_create, "light", "hue", "1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_update_entity_thread_safety(
|
||||||
|
hass: HomeAssistant, entity_registry: er.EntityRegistry
|
||||||
|
) -> None:
|
||||||
|
"""Test call async_get_or_create from a thread."""
|
||||||
|
entry = entity_registry.async_get_or_create("light", "hue", "1234")
|
||||||
|
with pytest.raises(
|
||||||
|
RuntimeError,
|
||||||
|
match="Detected code that calls _async_update_entity from a thread. Please report this issue.",
|
||||||
|
):
|
||||||
|
await hass.async_add_executor_job(
|
||||||
|
partial(
|
||||||
|
entity_registry.async_update_entity,
|
||||||
|
entry.entity_id,
|
||||||
|
new_unique_id="5678",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_remove_thread_safety(
|
||||||
|
hass: HomeAssistant, entity_registry: er.EntityRegistry
|
||||||
|
) -> None:
|
||||||
|
"""Test call async_remove from a thread."""
|
||||||
|
entry = entity_registry.async_get_or_create("light", "hue", "1234")
|
||||||
|
with pytest.raises(
|
||||||
|
RuntimeError,
|
||||||
|
match="Detected code that calls async_remove from a thread. Please report this issue.",
|
||||||
|
):
|
||||||
|
await hass.async_add_executor_job(entity_registry.async_remove, entry.entity_id)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user