Move thread safety check in category_registry sooner (#117050)

This commit is contained in:
J. Nick Koston 2024-05-07 19:55:43 -05:00 committed by GitHub
parent 3b51bf266a
commit 8401b05d40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 59 additions and 3 deletions

View File

@ -98,6 +98,7 @@ class CategoryRegistry(BaseRegistry[CategoryRegistryStoreData]):
icon: str | None = None, icon: str | None = None,
) -> CategoryEntry: ) -> CategoryEntry:
"""Create a new category.""" """Create a new category."""
self.hass.verify_event_loop_thread("async_create")
self._async_ensure_name_is_available(scope, name) self._async_ensure_name_is_available(scope, name)
category = CategoryEntry( category = CategoryEntry(
icon=icon, icon=icon,
@ -110,7 +111,7 @@ class CategoryRegistry(BaseRegistry[CategoryRegistryStoreData]):
self.categories[scope][category.category_id] = category self.categories[scope][category.category_id] = category
self.async_schedule_save() self.async_schedule_save()
self.hass.bus.async_fire( self.hass.bus.async_fire_internal(
EVENT_CATEGORY_REGISTRY_UPDATED, EVENT_CATEGORY_REGISTRY_UPDATED,
EventCategoryRegistryUpdatedData( EventCategoryRegistryUpdatedData(
action="create", scope=scope, category_id=category.category_id action="create", scope=scope, category_id=category.category_id
@ -121,8 +122,9 @@ class CategoryRegistry(BaseRegistry[CategoryRegistryStoreData]):
@callback @callback
def async_delete(self, *, scope: str, category_id: str) -> None: def async_delete(self, *, scope: str, category_id: str) -> None:
"""Delete category.""" """Delete category."""
self.hass.verify_event_loop_thread("async_delete")
del self.categories[scope][category_id] del self.categories[scope][category_id]
self.hass.bus.async_fire( self.hass.bus.async_fire_internal(
EVENT_CATEGORY_REGISTRY_UPDATED, EVENT_CATEGORY_REGISTRY_UPDATED,
EventCategoryRegistryUpdatedData( EventCategoryRegistryUpdatedData(
action="remove", action="remove",
@ -155,10 +157,11 @@ class CategoryRegistry(BaseRegistry[CategoryRegistryStoreData]):
if not changes: if not changes:
return old return old
self.hass.verify_event_loop_thread("async_update")
new = self.categories[scope][category_id] = dataclasses.replace(old, **changes) # type: ignore[arg-type] new = self.categories[scope][category_id] = dataclasses.replace(old, **changes) # type: ignore[arg-type]
self.async_schedule_save() self.async_schedule_save()
self.hass.bus.async_fire( self.hass.bus.async_fire_internal(
EVENT_CATEGORY_REGISTRY_UPDATED, EVENT_CATEGORY_REGISTRY_UPDATED,
EventCategoryRegistryUpdatedData( EventCategoryRegistryUpdatedData(
action="update", scope=scope, category_id=category_id action="update", scope=scope, category_id=category_id

View File

@ -1,5 +1,6 @@
"""Tests for the category registry.""" """Tests for the category registry."""
from functools import partial
import re import re
from typing import Any from typing import Any
@ -394,3 +395,55 @@ async def test_loading_categories_from_storage(
assert category3.category_id == "uuid3" assert category3.category_id == "uuid3"
assert category3.name == "Grocery stores" assert category3.name == "Grocery stores"
assert category3.icon == "mdi:store" assert category3.icon == "mdi:store"
async def test_async_create_thread_safety(
hass: HomeAssistant, category_registry: cr.CategoryRegistry
) -> None:
"""Test async_create raises when called from wrong thread."""
with pytest.raises(
RuntimeError,
match="Detected code that calls async_create from a thread. Please report this issue.",
):
await hass.async_add_executor_job(
partial(category_registry.async_create, name="any", scope="any")
)
async def test_async_delete_thread_safety(
hass: HomeAssistant, category_registry: cr.CategoryRegistry
) -> None:
"""Test async_delete raises when called from wrong thread."""
any_category = category_registry.async_create(name="any", scope="any")
with pytest.raises(
RuntimeError,
match="Detected code that calls async_delete from a thread. Please report this issue.",
):
await hass.async_add_executor_job(
partial(
category_registry.async_delete,
scope="any",
category_id=any_category.category_id,
)
)
async def test_async_update_thread_safety(
hass: HomeAssistant, category_registry: cr.CategoryRegistry
) -> None:
"""Test async_update raises when called from wrong thread."""
any_category = category_registry.async_create(name="any", scope="any")
with pytest.raises(
RuntimeError,
match="Detected code that calls async_update from a thread. Please report this issue.",
):
await hass.async_add_executor_job(
partial(
category_registry.async_update,
scope="any",
category_id=any_category.category_id,
name="new name",
)
)