diff --git a/homeassistant/block_async_io.py b/homeassistant/block_async_io.py index 2dc94fa456a..5b8ba535b5a 100644 --- a/homeassistant/block_async_io.py +++ b/homeassistant/block_async_io.py @@ -1,7 +1,9 @@ """Block blocking calls being done in asyncio.""" import builtins +from collections.abc import Callable from contextlib import suppress +from dataclasses import dataclass import glob from http.client import HTTPConnection import importlib @@ -46,53 +48,131 @@ def _check_sleep_call_allowed(mapped_args: dict[str, Any]) -> bool: return False +@dataclass(slots=True, frozen=True) +class BlockingCall: + """Class to hold information about a blocking call.""" + + original_func: Callable + object: object + function: str + check_allowed: Callable[[dict[str, Any]], bool] | None + strict: bool + strict_core: bool + skip_for_tests: bool + + +_BLOCKING_CALLS: tuple[BlockingCall, ...] = ( + BlockingCall( + original_func=HTTPConnection.putrequest, + object=HTTPConnection, + function="putrequest", + check_allowed=None, + strict=True, + strict_core=True, + skip_for_tests=False, + ), + BlockingCall( + original_func=time.sleep, + object=time, + function="sleep", + check_allowed=_check_sleep_call_allowed, + strict=True, + strict_core=True, + skip_for_tests=False, + ), + BlockingCall( + original_func=glob.glob, + object=glob, + function="glob", + check_allowed=None, + strict=False, + strict_core=False, + skip_for_tests=False, + ), + BlockingCall( + original_func=glob.iglob, + object=glob, + function="iglob", + check_allowed=None, + strict=False, + strict_core=False, + skip_for_tests=False, + ), + BlockingCall( + original_func=os.walk, + object=os, + function="walk", + check_allowed=None, + strict=False, + strict_core=False, + skip_for_tests=False, + ), + BlockingCall( + original_func=os.listdir, + object=os, + function="listdir", + check_allowed=None, + strict=False, + strict_core=False, + skip_for_tests=True, + ), + BlockingCall( + original_func=os.scandir, + object=os, + function="scandir", + check_allowed=None, + strict=False, + strict_core=False, + skip_for_tests=True, + ), + BlockingCall( + original_func=builtins.open, + object=builtins, + function="open", + check_allowed=_check_file_allowed, + strict=False, + strict_core=False, + skip_for_tests=True, + ), + BlockingCall( + original_func=importlib.import_module, + object=importlib, + function="import_module", + check_allowed=_check_import_call_allowed, + strict=False, + strict_core=False, + skip_for_tests=True, + ), +) + + +@dataclass(slots=True) +class BlockedCalls: + """Class to track which calls are blocked.""" + + calls: set[BlockingCall] + + +_BLOCKED_CALLS = BlockedCalls(set()) + + def enable() -> None: """Enable the detection of blocking calls in the event loop.""" + calls = _BLOCKED_CALLS.calls + if calls: + raise RuntimeError("Blocking call detection is already enabled") + loop_thread_id = threading.get_ident() - # Prevent urllib3 and requests doing I/O in event loop - HTTPConnection.putrequest = protect_loop( # type: ignore[method-assign] - HTTPConnection.putrequest, loop_thread_id=loop_thread_id - ) + for blocking_call in _BLOCKING_CALLS: + if _IN_TESTS and blocking_call.skip_for_tests: + continue - # Prevent sleeping in event loop. - time.sleep = protect_loop( - time.sleep, - check_allowed=_check_sleep_call_allowed, - loop_thread_id=loop_thread_id, - ) - - glob.glob = protect_loop( - glob.glob, strict_core=False, strict=False, loop_thread_id=loop_thread_id - ) - glob.iglob = protect_loop( - glob.iglob, strict_core=False, strict=False, loop_thread_id=loop_thread_id - ) - os.walk = protect_loop( - os.walk, strict_core=False, strict=False, loop_thread_id=loop_thread_id - ) - - if not _IN_TESTS: - # Prevent files being opened inside the event loop - os.listdir = protect_loop( # type: ignore[assignment] - os.listdir, strict_core=False, strict=False, loop_thread_id=loop_thread_id - ) - os.scandir = protect_loop( # type: ignore[assignment] - os.scandir, strict_core=False, strict=False, loop_thread_id=loop_thread_id - ) - - builtins.open = protect_loop( # type: ignore[assignment] - builtins.open, - strict_core=False, - strict=False, - check_allowed=_check_file_allowed, - loop_thread_id=loop_thread_id, - ) - # unittest uses `importlib.import_module` to do mocking - # so we cannot protect it if we are running tests - importlib.import_module = protect_loop( - importlib.import_module, - strict_core=False, - strict=False, - check_allowed=_check_import_call_allowed, + protected_function = protect_loop( + blocking_call.original_func, + strict=blocking_call.strict, + strict_core=blocking_call.strict_core, + check_allowed=blocking_call.check_allowed, loop_thread_id=loop_thread_id, ) + setattr(blocking_call.object, blocking_call.function, protected_function) + calls.add(blocking_call) diff --git a/tests/conftest.py b/tests/conftest.py index 1d0ad3d47b3..0bef1a7b06a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,6 +35,8 @@ import requests_mock from syrupy.assertion import SnapshotAssertion from typing_extensions import AsyncGenerator, Generator +from homeassistant import block_async_io + # Setup patching if dt_util time functions before any other Home Assistant imports from . import patch_time # noqa: F401, isort:skip @@ -1814,3 +1816,15 @@ def service_calls(hass: HomeAssistant) -> Generator[None, None, list[ServiceCall def snapshot(snapshot: SnapshotAssertion) -> SnapshotAssertion: """Return snapshot assertion fixture with the Home Assistant extension.""" return snapshot.use_extension(HomeAssistantSnapshotExtension) + + +@pytest.fixture +def disable_block_async_io() -> Generator[Any, Any, None]: + """Fixture to disable the loop protection from block_async_io.""" + yield + calls = block_async_io._BLOCKED_CALLS.calls + for blocking_call in calls: + setattr( + blocking_call.object, blocking_call.function, blocking_call.original_func + ) + calls.clear() diff --git a/tests/test_block_async_io.py b/tests/test_block_async_io.py index d011bdccdbe..d823f8c6912 100644 --- a/tests/test_block_async_io.py +++ b/tests/test_block_async_io.py @@ -17,6 +17,11 @@ from homeassistant.core import HomeAssistant from .common import extract_stack_to_frame +@pytest.fixture(autouse=True) +def disable_block_async_io(disable_block_async_io): + """Disable the loop protection from block_async_io after each test.""" + + async def test_protect_loop_debugger_sleep(caplog: pytest.LogCaptureFixture) -> None: """Test time.sleep injected by the debugger is not reported.""" block_async_io.enable() @@ -214,13 +219,25 @@ async def test_protect_loop_open(caplog: pytest.LogCaptureFixture) -> None: async def test_protect_open(caplog: pytest.LogCaptureFixture) -> None: """Test opening a file in the event loop logs.""" - block_async_io.enable() + with patch.object(block_async_io, "_IN_TESTS", False): + block_async_io.enable() with contextlib.suppress(FileNotFoundError): open("/config/data_not_exist", encoding="utf8").close() assert "Detected blocking call to open with args" in caplog.text +async def test_enable_multiple_times(caplog: pytest.LogCaptureFixture) -> None: + """Test trying to enable multiple times.""" + with patch.object(block_async_io, "_IN_TESTS", False): + block_async_io.enable() + + with pytest.raises( + RuntimeError, match="Blocking call detection is already enabled" + ): + block_async_io.enable() + + @pytest.mark.parametrize( "path", [ @@ -231,7 +248,8 @@ async def test_protect_open(caplog: pytest.LogCaptureFixture) -> None: ) async def test_protect_open_path(path: Any, caplog: pytest.LogCaptureFixture) -> None: """Test opening a file by path in the event loop logs.""" - block_async_io.enable() + with patch.object(block_async_io, "_IN_TESTS", False): + block_async_io.enable() with contextlib.suppress(FileNotFoundError): open(path, encoding="utf8").close() @@ -242,7 +260,8 @@ async def test_protect_loop_glob( hass: HomeAssistant, caplog: pytest.LogCaptureFixture ) -> None: """Test glob calls in the loop are logged.""" - block_async_io.enable() + with patch.object(block_async_io, "_IN_TESTS", False): + block_async_io.enable() glob.glob("/dev/null") assert "Detected blocking call to glob with args" in caplog.text caplog.clear() @@ -254,7 +273,8 @@ async def test_protect_loop_iglob( hass: HomeAssistant, caplog: pytest.LogCaptureFixture ) -> None: """Test iglob calls in the loop are logged.""" - block_async_io.enable() + with patch.object(block_async_io, "_IN_TESTS", False): + block_async_io.enable() glob.iglob("/dev/null") assert "Detected blocking call to iglob with args" in caplog.text caplog.clear() @@ -266,7 +286,8 @@ async def test_protect_loop_scandir( hass: HomeAssistant, caplog: pytest.LogCaptureFixture ) -> None: """Test glob calls in the loop are logged.""" - block_async_io.enable() + with patch.object(block_async_io, "_IN_TESTS", False): + block_async_io.enable() with contextlib.suppress(FileNotFoundError): os.scandir("/path/that/does/not/exists") assert "Detected blocking call to scandir with args" in caplog.text @@ -280,7 +301,8 @@ async def test_protect_loop_listdir( hass: HomeAssistant, caplog: pytest.LogCaptureFixture ) -> None: """Test listdir calls in the loop are logged.""" - block_async_io.enable() + with patch.object(block_async_io, "_IN_TESTS", False): + block_async_io.enable() with contextlib.suppress(FileNotFoundError): os.listdir("/path/that/does/not/exists") assert "Detected blocking call to listdir with args" in caplog.text @@ -293,8 +315,9 @@ async def test_protect_loop_listdir( async def test_protect_loop_walk( hass: HomeAssistant, caplog: pytest.LogCaptureFixture ) -> None: - """Test glob calls in the loop are logged.""" - block_async_io.enable() + """Test os.walk calls in the loop are logged.""" + with patch.object(block_async_io, "_IN_TESTS", False): + block_async_io.enable() with contextlib.suppress(FileNotFoundError): os.walk("/path/that/does/not/exists") assert "Detected blocking call to walk with args" in caplog.text @@ -302,3 +325,13 @@ async def test_protect_loop_walk( with contextlib.suppress(FileNotFoundError): await hass.async_add_executor_job(os.walk, "/path/that/does/not/exists") assert "Detected blocking call to walk with args" not in caplog.text + + +async def test_open_calls_ignored_in_tests(caplog: pytest.LogCaptureFixture) -> None: + """Test opening a file in tests is ignored.""" + assert block_async_io._IN_TESTS + block_async_io.enable() + with contextlib.suppress(FileNotFoundError): + open("/config/data_not_exist", encoding="utf8").close() + + assert "Detected blocking call to open with args" not in caplog.text diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index 9e04421a58a..225720fb604 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -55,6 +55,11 @@ async def apply_stop_hass(stop_hass: None) -> None: """Make sure all hass are stopped.""" +@pytest.fixture(autouse=True) +def disable_block_async_io(disable_block_async_io): + """Disable the loop protection from block_async_io after each test.""" + + @pytest.fixture(scope="module", autouse=True) def mock_http_start_stop() -> Generator[None]: """Mock HTTP start and stop."""