From 475c20d5296cc67373af3bc7787260a169e941ec Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 5 Jun 2024 22:41:55 -0500 Subject: [PATCH] Always do thread safety check when writing state (#118886) * Always do thread safety check when writing state Refactor the 3 most common places where the thread safety check for the event loop to be inline to make the check fast enough that we can keep it long term. While code review catches most of the thread safety issues in core, some of them still make it through, and new ones keep getting added. Its not possible to catch them all with manual code review, so its worth the tiny overhead to check each time. Previously the check was limited to custom components because they were the most common source of thread safety issues. * Always do thread safety check when writing state Refactor the 3 most common places where the thread safety check for the event loop to be inline to make the check fast enough that we can keep it long term. While code review catches most of the thread safety issues in core, some of them still make it through, and new ones keep getting added. Its not possible to catch them all with manual code review, so its worth the tiny overhead to check each time. Previously the check was limited to custom components because they were the most common source of thread safety issues. * async_fire is more common than expected with ccs * fix mock * fix hass mocking --- homeassistant/core.py | 35 +++++++------------ homeassistant/helpers/entity.py | 9 +++-- homeassistant/helpers/frame.py | 13 +++++++ tests/common.py | 2 +- tests/components/zha/test_cluster_handlers.py | 2 ++ tests/helpers/test_entity.py | 6 ++-- 6 files changed, 34 insertions(+), 33 deletions(-) diff --git a/homeassistant/core.py b/homeassistant/core.py index ad04c6d1366..d0e80ad8bd1 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -434,25 +434,17 @@ class HomeAssistant: self.import_executor = InterruptibleThreadPoolExecutor( max_workers=1, thread_name_prefix="ImportExecutor" ) - self._loop_thread_id = getattr( + self.loop_thread_id = getattr( self.loop, "_thread_ident", getattr(self.loop, "_thread_id") ) def verify_event_loop_thread(self, what: str) -> None: """Report and raise if we are not running in the event loop thread.""" - if self._loop_thread_id != threading.get_ident(): + if self.loop_thread_id != threading.get_ident(): + # frame is a circular import, so we import it here from .helpers import frame # pylint: disable=import-outside-toplevel - # frame is a circular import, so we import it here - frame.report( - f"calls {what} from a thread other than the event loop, " - "which may cause Home Assistant to crash or data to corrupt. " - "For more information, see " - "https://developers.home-assistant.io/docs/asyncio_thread_safety/" - f"#{what.replace('.', '')}", - error_if_core=True, - error_if_integration=True, - ) + frame.report_non_thread_safe_operation(what) @property def _active_tasks(self) -> set[asyncio.Future[Any]]: @@ -793,16 +785,10 @@ class HomeAssistant: target: target to call. """ - # We turned on asyncio debug in April 2024 in the dev containers - # in the hope of catching some of the issues that have been - # reported. It will take a while to get all the issues fixed in - # custom components. - # - # In 2025.5 we should guard the `verify_event_loop_thread` - # check with a check for the `hass.config.debug` flag being set as - # long term we don't want to be checking this in production - # environments since it is a performance hit. - self.verify_event_loop_thread("hass.async_create_task") + if self.loop_thread_id != threading.get_ident(): + from .helpers import frame # pylint: disable=import-outside-toplevel + + frame.report_non_thread_safe_operation("hass.async_create_task") return self.async_create_task_internal(target, name, eager_start) @callback @@ -1497,7 +1483,10 @@ class EventBus: This method must be run in the event loop. """ _verify_event_type_length_or_raise(event_type) - self._hass.verify_event_loop_thread("hass.bus.async_fire") + if self._hass.loop_thread_id != threading.get_ident(): + from .helpers import frame # pylint: disable=import-outside-toplevel + + frame.report_non_thread_safe_operation("hass.bus.async_fire") return self.async_fire_internal( event_type, event_data, origin, context, time_fired ) diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index ee544883a68..9a2bb4b6fca 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -14,6 +14,7 @@ import logging import math from operator import attrgetter import sys +import threading import time from types import FunctionType from typing import TYPE_CHECKING, Any, Final, Literal, NotRequired, TypedDict, final @@ -63,6 +64,7 @@ from .event import ( async_track_device_registry_updated_event, async_track_entity_registry_updated_event, ) +from .frame import report_non_thread_safe_operation from .typing import UNDEFINED, StateType, UndefinedType timer = time.time @@ -512,7 +514,6 @@ class Entity( # While not purely typed, it makes typehinting more useful for us # and removes the need for constant None checks or asserts. _state_info: StateInfo = None # type: ignore[assignment] - _is_custom_component: bool = False __capabilities_updated_at: deque[float] __capabilities_updated_at_reported: bool = False @@ -995,8 +996,8 @@ class Entity( def async_write_ha_state(self) -> None: """Write the state to the state machine.""" self._async_verify_state_writable() - if self._is_custom_component or self.hass.config.debug: - self.hass.verify_event_loop_thread("async_write_ha_state") + if self.hass.loop_thread_id != threading.get_ident(): + report_non_thread_safe_operation("async_write_ha_state") self._async_write_ha_state() def _stringify_state(self, available: bool) -> str: @@ -1440,8 +1441,6 @@ class Entity( "domain": self.platform.platform_name, "custom_component": is_custom_component, } - self._is_custom_component = is_custom_component - if self.platform.config_entry: entity_info["config_entry"] = self.platform.config_entry.entry_id diff --git a/homeassistant/helpers/frame.py b/homeassistant/helpers/frame.py index e8ba6ba0c07..8a30c26886e 100644 --- a/homeassistant/helpers/frame.py +++ b/homeassistant/helpers/frame.py @@ -218,3 +218,16 @@ def warn_use[_CallableT: Callable](func: _CallableT, what: str) -> _CallableT: report(what) return cast(_CallableT, report_use) + + +def report_non_thread_safe_operation(what: str) -> None: + """Report a non-thread safe operation.""" + report( + f"calls {what} from a thread other than the event loop, " + "which may cause Home Assistant to crash or data to corrupt. " + "For more information, see " + "https://developers.home-assistant.io/docs/asyncio_thread_safety/" + f"#{what.replace('.', '')}", + error_if_core=True, + error_if_integration=True, + ) diff --git a/tests/common.py b/tests/common.py index b1110297d2f..88d7a86fcf4 100644 --- a/tests/common.py +++ b/tests/common.py @@ -174,7 +174,7 @@ def get_test_home_assistant() -> Generator[HomeAssistant, None, None]: """Run event loop.""" loop._thread_ident = threading.get_ident() - hass._loop_thread_id = loop._thread_ident + hass.loop_thread_id = loop._thread_ident loop.run_forever() loop_stop_event.set() diff --git a/tests/components/zha/test_cluster_handlers.py b/tests/components/zha/test_cluster_handlers.py index cc9fb8d1918..d09883c38e3 100644 --- a/tests/components/zha/test_cluster_handlers.py +++ b/tests/components/zha/test_cluster_handlers.py @@ -3,6 +3,7 @@ from collections.abc import Callable import logging import math +import threading from types import NoneType from unittest import mock from unittest.mock import AsyncMock, patch @@ -86,6 +87,7 @@ def endpoint(zigpy_coordinator_device): type(endpoint_mock.device).skip_configuration = mock.PropertyMock( return_value=False ) + endpoint_mock.device.hass.loop_thread_id = threading.get_ident() endpoint_mock.id = 1 return endpoint_mock diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index a80674e0f76..c8da7a118aa 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -2617,13 +2617,12 @@ async def test_async_write_ha_state_thread_safety(hass: HomeAssistant) -> None: assert not hass.states.get(ent2.entity_id) -async def test_async_write_ha_state_thread_safety_custom_component( +async def test_async_write_ha_state_thread_safety_always( hass: HomeAssistant, ) -> None: - """Test async_write_ha_state thread safe for custom components.""" + """Test async_write_ha_state thread safe check.""" ent = entity.Entity() - ent._is_custom_component = True ent.entity_id = "test.any" ent.hass = hass ent.platform = MockEntityPlatform(hass, domain="test") @@ -2631,7 +2630,6 @@ async def test_async_write_ha_state_thread_safety_custom_component( assert hass.states.get(ent.entity_id) ent2 = entity.Entity() - ent2._is_custom_component = True ent2.entity_id = "test.any2" ent2.hass = hass ent2.platform = MockEntityPlatform(hass, domain="test")