diff --git a/homeassistant/block_async_io.py b/homeassistant/block_async_io.py index a2c187fc537..5d2570fe311 100644 --- a/homeassistant/block_async_io.py +++ b/homeassistant/block_async_io.py @@ -4,6 +4,7 @@ from contextlib import suppress from http.client import HTTPConnection import importlib import sys +import threading import time from typing import Any @@ -25,7 +26,7 @@ def _check_sleep_call_allowed(mapped_args: dict[str, Any]) -> bool: # I/O and we are trying to avoid blocking calls. # # frame[0] is us - # frame[1] is check_loop + # frame[1] is raise_for_blocking_call # frame[2] is protected_loop_func # frame[3] is the offender with suppress(ValueError): @@ -35,14 +36,18 @@ def _check_sleep_call_allowed(mapped_args: dict[str, Any]) -> bool: def enable() -> None: """Enable the detection of blocking calls in the event loop.""" + loop_thread_id = threading.get_ident() # Prevent urllib3 and requests doing I/O in event loop HTTPConnection.putrequest = protect_loop( # type: ignore[method-assign] - HTTPConnection.putrequest + HTTPConnection.putrequest, loop_thread_id=loop_thread_id ) # Prevent sleeping in event loop. Non-strict since 2022.02 time.sleep = protect_loop( - time.sleep, strict=False, check_allowed=_check_sleep_call_allowed + time.sleep, + strict=False, + check_allowed=_check_sleep_call_allowed, + loop_thread_id=loop_thread_id, ) # Currently disabled. pytz doing I/O when getting timezone. @@ -57,4 +62,5 @@ def enable() -> None: strict_core=False, strict=False, check_allowed=_check_import_call_allowed, + loop_thread_id=loop_thread_id, ) diff --git a/homeassistant/components/recorder/pool.py b/homeassistant/components/recorder/pool.py index cfad189e823..7bf08a459d7 100644 --- a/homeassistant/components/recorder/pool.py +++ b/homeassistant/components/recorder/pool.py @@ -1,5 +1,6 @@ """A pool for sqlite connections.""" +import asyncio import logging import threading import traceback @@ -14,7 +15,7 @@ from sqlalchemy.pool import ( ) from homeassistant.helpers.frame import report -from homeassistant.util.loop import check_loop +from homeassistant.util.loop import raise_for_blocking_call _LOGGER = logging.getLogger(__name__) @@ -86,15 +87,22 @@ class RecorderPool(SingletonThreadPool, NullPool): if threading.get_ident() in self.recorder_and_worker_thread_ids: super().dispose() - def _do_get(self) -> ConnectionPoolEntry: + def _do_get(self) -> ConnectionPoolEntry: # type: ignore[return] if threading.get_ident() in self.recorder_and_worker_thread_ids: return super()._do_get() - check_loop( + try: + asyncio.get_running_loop() + except RuntimeError: + # Not in an event loop but not in the recorder or worker thread + # which is allowed but discouraged since its much slower + return self._do_get_db_connection_protected() + # In the event loop, raise an exception + raise_for_blocking_call( self._do_get_db_connection_protected, strict=True, advise_msg=ADVISE_MSG, ) - return self._do_get_db_connection_protected() + # raise_for_blocking_call will raise an exception def _do_get_db_connection_protected(self) -> ConnectionPoolEntry: report( diff --git a/homeassistant/util/loop.py b/homeassistant/util/loop.py index f8fe5c701f3..071eb42149b 100644 --- a/homeassistant/util/loop.py +++ b/homeassistant/util/loop.py @@ -2,12 +2,12 @@ from __future__ import annotations -from asyncio import get_running_loop from collections.abc import Callable from contextlib import suppress import functools import linecache import logging +import threading from typing import Any, ParamSpec, TypeVar from homeassistant.core import HomeAssistant, async_get_hass @@ -31,7 +31,7 @@ def _get_line_from_cache(filename: str, lineno: int) -> str: return (linecache.getline(filename, lineno) or "?").strip() -def check_loop( +def raise_for_blocking_call( func: Callable[..., Any], check_allowed: Callable[[dict[str, Any]], bool] | None = None, strict: bool = True, @@ -44,15 +44,6 @@ def check_loop( The default advisory message is 'Use `await hass.async_add_executor_job()' Set `advise_msg` to an alternate message if the solution differs. """ - try: - get_running_loop() - in_loop = True - except RuntimeError: - in_loop = False - - if not in_loop: - return - if check_allowed is not None and check_allowed(mapped_args): return @@ -125,6 +116,7 @@ def check_loop( def protect_loop( func: Callable[_P, _R], + loop_thread_id: int, strict: bool = True, strict_core: bool = True, check_allowed: Callable[[dict[str, Any]], bool] | None = None, @@ -133,14 +125,15 @@ def protect_loop( @functools.wraps(func) def protected_loop_func(*args: _P.args, **kwargs: _P.kwargs) -> _R: - check_loop( - func, - strict=strict, - strict_core=strict_core, - check_allowed=check_allowed, - args=args, - kwargs=kwargs, - ) + if threading.get_ident() == loop_thread_id: + raise_for_blocking_call( + func, + strict=strict, + strict_core=strict_core, + check_allowed=check_allowed, + args=args, + kwargs=kwargs, + ) return func(*args, **kwargs) return protected_loop_func diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index 71705c060a2..88fbf8f388a 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -159,14 +159,18 @@ async def test_shutdown_before_startup_finishes( await recorder_helper.async_wait_recorder(hass) instance = get_instance(hass) - session = await hass.async_add_executor_job(instance.get_session) + session = await instance.async_add_executor_job(instance.get_session) with patch.object(instance, "engine"): hass.bus.async_fire(EVENT_HOMEASSISTANT_FINAL_WRITE) await hass.async_block_till_done() await hass.async_stop() - run_info = await hass.async_add_executor_job(run_information_with_session, session) + def _run_information_with_session(): + instance.recorder_and_worker_thread_ids.add(threading.get_ident()) + return run_information_with_session(session) + + run_info = await instance.async_add_executor_job(_run_information_with_session) assert run_info.run_id == 1 assert run_info.start is not None @@ -1693,7 +1697,8 @@ async def test_database_corruption_while_running( await hass.async_block_till_done() caplog.clear() - original_start_time = get_instance(hass).recorder_runs_manager.recording_start + instance = get_instance(hass) + original_start_time = instance.recorder_runs_manager.recording_start hass.states.async_set("test.lost", "on", {}) @@ -1737,11 +1742,11 @@ async def test_database_corruption_while_running( assert db_states[0].event_id is None return db_states[0].to_native() - state = await hass.async_add_executor_job(_get_last_state) + state = await instance.async_add_executor_job(_get_last_state) assert state.entity_id == "test.two" assert state.state == "on" - new_start_time = get_instance(hass).recorder_runs_manager.recording_start + new_start_time = instance.recorder_runs_manager.recording_start assert original_start_time < new_start_time hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) @@ -1850,7 +1855,7 @@ async def test_database_lock_and_unlock( assert instance.unlock_database() await task - db_events = await hass.async_add_executor_job(_get_db_events) + db_events = await instance.async_add_executor_job(_get_db_events) assert len(db_events) == 1 diff --git a/tests/components/recorder/test_statistics_v23_migration.py b/tests/components/recorder/test_statistics_v23_migration.py index 28c7613e761..ac48f0d0994 100644 --- a/tests/components/recorder/test_statistics_v23_migration.py +++ b/tests/components/recorder/test_statistics_v23_migration.py @@ -9,12 +9,13 @@ import importlib import json from pathlib import Path import sys +import threading from unittest.mock import patch import pytest from homeassistant.components import recorder -from homeassistant.components.recorder import SQLITE_URL_PREFIX +from homeassistant.components.recorder import SQLITE_URL_PREFIX, get_instance from homeassistant.components.recorder.util import session_scope from homeassistant.helpers import recorder as recorder_helper from homeassistant.setup import setup_component @@ -176,6 +177,7 @@ def test_delete_duplicates(caplog: pytest.LogCaptureFixture, tmp_path: Path) -> ): recorder_helper.async_initialize_recorder(hass) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident()) wait_recording_done(hass) wait_recording_done(hass) @@ -358,6 +360,7 @@ def test_delete_duplicates_many( ): recorder_helper.async_initialize_recorder(hass) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident()) wait_recording_done(hass) wait_recording_done(hass) @@ -517,6 +520,7 @@ def test_delete_duplicates_non_identical( ): recorder_helper.async_initialize_recorder(hass) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident()) wait_recording_done(hass) wait_recording_done(hass) @@ -631,6 +635,7 @@ def test_delete_duplicates_short_term( ): recorder_helper.async_initialize_recorder(hass) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident()) wait_recording_done(hass) wait_recording_done(hass) diff --git a/tests/components/recorder/test_util.py b/tests/components/recorder/test_util.py index f6fba72bd5d..db411f83c91 100644 --- a/tests/components/recorder/test_util.py +++ b/tests/components/recorder/test_util.py @@ -4,6 +4,7 @@ from datetime import UTC, datetime, timedelta import os from pathlib import Path import sqlite3 +import threading from unittest.mock import MagicMock, Mock, patch import pytest @@ -843,9 +844,7 @@ async def test_periodic_db_cleanups( assert str(text_obj) == "PRAGMA wal_checkpoint(TRUNCATE);" -@patch("homeassistant.components.recorder.pool.check_loop") async def test_write_lock_db( - skip_check_loop, async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant, tmp_path: Path, @@ -864,6 +863,7 @@ async def test_write_lock_db( with instance.engine.connect() as connection: connection.execute(text("DROP TABLE events;")) + instance.recorder_and_worker_thread_ids.add(threading.get_ident()) with util.write_lock_db_sqlite(instance), pytest.raises(OperationalError): # Database should be locked now, try writing SQL command # This needs to be called in another thread since @@ -872,7 +872,7 @@ async def test_write_lock_db( # in the same thread as the one holding the lock since it # would be allowed to proceed as the goal is to prevent # all the other threads from accessing the database - await hass.async_add_executor_job(_drop_table) + await instance.async_add_executor_job(_drop_table) def test_is_second_sunday() -> None: diff --git a/tests/components/sensor/test_recorder_missing_stats.py b/tests/components/sensor/test_recorder_missing_stats.py index 88c98e6589f..d770c459426 100644 --- a/tests/components/sensor/test_recorder_missing_stats.py +++ b/tests/components/sensor/test_recorder_missing_stats.py @@ -2,11 +2,13 @@ from datetime import datetime, timedelta from pathlib import Path +import threading from unittest.mock import patch from freezegun.api import FrozenDateTimeFactory import pytest +from homeassistant.components.recorder import get_instance from homeassistant.components.recorder.history import get_significant_states from homeassistant.components.recorder.statistics import ( get_latest_short_term_statistics_with_session, @@ -57,6 +59,7 @@ def test_compile_missing_statistics( recorder_helper.async_initialize_recorder(hass) setup_component(hass, "sensor", {}) setup_component(hass, "recorder", {"recorder": config}) + get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident()) hass.start() wait_recording_done(hass) wait_recording_done(hass) @@ -98,6 +101,7 @@ def test_compile_missing_statistics( setup_component(hass, "sensor", {}) hass.states.set("sensor.test1", "0", POWER_SENSOR_ATTRIBUTES) setup_component(hass, "recorder", {"recorder": config}) + get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident()) hass.start() wait_recording_done(hass) wait_recording_done(hass) diff --git a/tests/util/test_loop.py b/tests/util/test_loop.py index 8b4465bef2b..c3cfb3d0f06 100644 --- a/tests/util/test_loop.py +++ b/tests/util/test_loop.py @@ -1,9 +1,11 @@ """Tests for async util methods from Python source.""" +import threading from unittest.mock import Mock, patch import pytest +from homeassistant.core import HomeAssistant from homeassistant.util import loop as haloop from tests.common import extract_stack_to_frame @@ -13,22 +15,24 @@ def banned_function(): """Mock banned function.""" -async def test_check_loop_async() -> None: - """Test check_loop detects when called from event loop without integration context.""" +async def test_raise_for_blocking_call_async() -> None: + """Test raise_for_blocking_call detects when called from event loop without integration context.""" with pytest.raises(RuntimeError): - haloop.check_loop(banned_function) + haloop.raise_for_blocking_call(banned_function) -async def test_check_loop_async_non_strict_core( +async def test_raise_for_blocking_call_async_non_strict_core( caplog: pytest.LogCaptureFixture, ) -> None: - """Test non_strict_core check_loop detects from event loop without integration context.""" - haloop.check_loop(banned_function, strict_core=False) + """Test non_strict_core raise_for_blocking_call detects from event loop without integration context.""" + haloop.raise_for_blocking_call(banned_function, strict_core=False) assert "Detected blocking call to banned_function" in caplog.text -async def test_check_loop_async_integration(caplog: pytest.LogCaptureFixture) -> None: - """Test check_loop detects and raises when called from event loop from integration context.""" +async def test_raise_for_blocking_call_async_integration( + caplog: pytest.LogCaptureFixture, +) -> None: + """Test raise_for_blocking_call detects and raises when called from event loop from integration context.""" frames = extract_stack_to_frame( [ Mock( @@ -67,7 +71,7 @@ async def test_check_loop_async_integration(caplog: pytest.LogCaptureFixture) -> return_value=frames, ), ): - haloop.check_loop(banned_function) + haloop.raise_for_blocking_call(banned_function) assert ( "Detected blocking call to banned_function inside the event loop by integration" " 'hue' at homeassistant/components/hue/light.py, line 23: self.light.is_on " @@ -77,10 +81,10 @@ async def test_check_loop_async_integration(caplog: pytest.LogCaptureFixture) -> ) -async def test_check_loop_async_integration_non_strict( +async def test_raise_for_blocking_call_async_integration_non_strict( caplog: pytest.LogCaptureFixture, ) -> None: - """Test check_loop detects when called from event loop from integration context.""" + """Test raise_for_blocking_call detects when called from event loop from integration context.""" frames = extract_stack_to_frame( [ Mock( @@ -118,7 +122,7 @@ async def test_check_loop_async_integration_non_strict( return_value=frames, ), ): - haloop.check_loop(banned_function, strict=False) + haloop.raise_for_blocking_call(banned_function, strict=False) assert ( "Detected blocking call to banned_function inside the event loop by integration" " 'hue' at homeassistant/components/hue/light.py, line 23: self.light.is_on " @@ -128,8 +132,10 @@ async def test_check_loop_async_integration_non_strict( ) -async def test_check_loop_async_custom(caplog: pytest.LogCaptureFixture) -> None: - """Test check_loop detects when called from event loop with custom component context.""" +async def test_raise_for_blocking_call_async_custom( + caplog: pytest.LogCaptureFixture, +) -> None: + """Test raise_for_blocking_call detects when called from event loop with custom component context.""" frames = extract_stack_to_frame( [ Mock( @@ -168,7 +174,7 @@ async def test_check_loop_async_custom(caplog: pytest.LogCaptureFixture) -> None return_value=frames, ), ): - haloop.check_loop(banned_function) + haloop.raise_for_blocking_call(banned_function) assert ( "Detected blocking call to banned_function inside the event loop by custom " "integration 'hue' at custom_components/hue/light.py, line 23: self.light.is_on" @@ -178,18 +184,23 @@ async def test_check_loop_async_custom(caplog: pytest.LogCaptureFixture) -> None ) in caplog.text -def test_check_loop_sync(caplog: pytest.LogCaptureFixture) -> None: - """Test check_loop does nothing when called from thread.""" - haloop.check_loop(banned_function) +async def test_raise_for_blocking_call_sync( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test raise_for_blocking_call does nothing when called from thread.""" + func = haloop.protect_loop(banned_function, threading.get_ident()) + await hass.async_add_executor_job(func) assert "Detected blocking call inside the event loop" not in caplog.text -def test_protect_loop_sync() -> None: - """Test protect_loop calls check_loop.""" +async def test_protect_loop_async() -> None: + """Test protect_loop calls raise_for_blocking_call.""" func = Mock() - with patch("homeassistant.util.loop.check_loop") as mock_check_loop: - haloop.protect_loop(func)(1, test=2) - mock_check_loop.assert_called_once_with( + with patch( + "homeassistant.util.loop.raise_for_blocking_call" + ) as mock_raise_for_blocking_call: + haloop.protect_loop(func, threading.get_ident())(1, test=2) + mock_raise_for_blocking_call.assert_called_once_with( func, strict=True, args=(1,),