Ensure asyncio blocking checks are undone after tests run (#119542)

* Ensure asyncio blocking checks are undone after tests run

* no reason to ever enable twice

* we are patching objects, make it more generic

* make sure bootstrap unblocks as well

* move disable to tests only

* re-protect

* Update tests/test_block_async_io.py

Co-authored-by: Erik Montnemery <erik@montnemery.com>

* Revert "Update tests/test_block_async_io.py"

This reverts commit 2d46028e21b4095479302629a201c3cfc811b2c2.

* tweak name

* fixture only

* Update tests/conftest.py

* Update tests/conftest.py

* Apply suggestions from code review

---------

Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
J. Nick Koston 2024-06-13 01:52:01 -05:00 committed by GitHub
parent 669569ca49
commit d52ce03aa4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 184 additions and 52 deletions

View File

@ -1,7 +1,9 @@
"""Block blocking calls being done in asyncio.""" """Block blocking calls being done in asyncio."""
import builtins import builtins
from collections.abc import Callable
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass
import glob import glob
from http.client import HTTPConnection from http.client import HTTPConnection
import importlib import importlib
@ -46,53 +48,131 @@ def _check_sleep_call_allowed(mapped_args: dict[str, Any]) -> bool:
return False return False
@dataclass(slots=True, frozen=True)
class BlockingCall:
"""Class to hold information about a blocking call."""
original_func: Callable
object: object
function: str
check_allowed: Callable[[dict[str, Any]], bool] | None
strict: bool
strict_core: bool
skip_for_tests: bool
_BLOCKING_CALLS: tuple[BlockingCall, ...] = (
BlockingCall(
original_func=HTTPConnection.putrequest,
object=HTTPConnection,
function="putrequest",
check_allowed=None,
strict=True,
strict_core=True,
skip_for_tests=False,
),
BlockingCall(
original_func=time.sleep,
object=time,
function="sleep",
check_allowed=_check_sleep_call_allowed,
strict=True,
strict_core=True,
skip_for_tests=False,
),
BlockingCall(
original_func=glob.glob,
object=glob,
function="glob",
check_allowed=None,
strict=False,
strict_core=False,
skip_for_tests=False,
),
BlockingCall(
original_func=glob.iglob,
object=glob,
function="iglob",
check_allowed=None,
strict=False,
strict_core=False,
skip_for_tests=False,
),
BlockingCall(
original_func=os.walk,
object=os,
function="walk",
check_allowed=None,
strict=False,
strict_core=False,
skip_for_tests=False,
),
BlockingCall(
original_func=os.listdir,
object=os,
function="listdir",
check_allowed=None,
strict=False,
strict_core=False,
skip_for_tests=True,
),
BlockingCall(
original_func=os.scandir,
object=os,
function="scandir",
check_allowed=None,
strict=False,
strict_core=False,
skip_for_tests=True,
),
BlockingCall(
original_func=builtins.open,
object=builtins,
function="open",
check_allowed=_check_file_allowed,
strict=False,
strict_core=False,
skip_for_tests=True,
),
BlockingCall(
original_func=importlib.import_module,
object=importlib,
function="import_module",
check_allowed=_check_import_call_allowed,
strict=False,
strict_core=False,
skip_for_tests=True,
),
)
@dataclass(slots=True)
class BlockedCalls:
"""Class to track which calls are blocked."""
calls: set[BlockingCall]
_BLOCKED_CALLS = BlockedCalls(set())
def enable() -> None: def enable() -> None:
"""Enable the detection of blocking calls in the event loop.""" """Enable the detection of blocking calls in the event loop."""
calls = _BLOCKED_CALLS.calls
if calls:
raise RuntimeError("Blocking call detection is already enabled")
loop_thread_id = threading.get_ident() loop_thread_id = threading.get_ident()
# Prevent urllib3 and requests doing I/O in event loop for blocking_call in _BLOCKING_CALLS:
HTTPConnection.putrequest = protect_loop( # type: ignore[method-assign] if _IN_TESTS and blocking_call.skip_for_tests:
HTTPConnection.putrequest, loop_thread_id=loop_thread_id continue
)
# Prevent sleeping in event loop. protected_function = protect_loop(
time.sleep = protect_loop( blocking_call.original_func,
time.sleep, strict=blocking_call.strict,
check_allowed=_check_sleep_call_allowed, strict_core=blocking_call.strict_core,
loop_thread_id=loop_thread_id, check_allowed=blocking_call.check_allowed,
)
glob.glob = protect_loop(
glob.glob, strict_core=False, strict=False, loop_thread_id=loop_thread_id
)
glob.iglob = protect_loop(
glob.iglob, strict_core=False, strict=False, loop_thread_id=loop_thread_id
)
os.walk = protect_loop(
os.walk, strict_core=False, strict=False, loop_thread_id=loop_thread_id
)
if not _IN_TESTS:
# Prevent files being opened inside the event loop
os.listdir = protect_loop( # type: ignore[assignment]
os.listdir, strict_core=False, strict=False, loop_thread_id=loop_thread_id
)
os.scandir = protect_loop( # type: ignore[assignment]
os.scandir, strict_core=False, strict=False, loop_thread_id=loop_thread_id
)
builtins.open = protect_loop( # type: ignore[assignment]
builtins.open,
strict_core=False,
strict=False,
check_allowed=_check_file_allowed,
loop_thread_id=loop_thread_id,
)
# unittest uses `importlib.import_module` to do mocking
# so we cannot protect it if we are running tests
importlib.import_module = protect_loop(
importlib.import_module,
strict_core=False,
strict=False,
check_allowed=_check_import_call_allowed,
loop_thread_id=loop_thread_id, loop_thread_id=loop_thread_id,
) )
setattr(blocking_call.object, blocking_call.function, protected_function)
calls.add(blocking_call)

View File

@ -35,6 +35,8 @@ import requests_mock
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from typing_extensions import AsyncGenerator, Generator from typing_extensions import AsyncGenerator, Generator
from homeassistant import block_async_io
# Setup patching if dt_util time functions before any other Home Assistant imports # Setup patching if dt_util time functions before any other Home Assistant imports
from . import patch_time # noqa: F401, isort:skip from . import patch_time # noqa: F401, isort:skip
@ -1814,3 +1816,15 @@ def service_calls(hass: HomeAssistant) -> Generator[None, None, list[ServiceCall
def snapshot(snapshot: SnapshotAssertion) -> SnapshotAssertion: def snapshot(snapshot: SnapshotAssertion) -> SnapshotAssertion:
"""Return snapshot assertion fixture with the Home Assistant extension.""" """Return snapshot assertion fixture with the Home Assistant extension."""
return snapshot.use_extension(HomeAssistantSnapshotExtension) return snapshot.use_extension(HomeAssistantSnapshotExtension)
@pytest.fixture
def disable_block_async_io() -> Generator[Any, Any, None]:
"""Fixture to disable the loop protection from block_async_io."""
yield
calls = block_async_io._BLOCKED_CALLS.calls
for blocking_call in calls:
setattr(
blocking_call.object, blocking_call.function, blocking_call.original_func
)
calls.clear()

View File

@ -17,6 +17,11 @@ from homeassistant.core import HomeAssistant
from .common import extract_stack_to_frame from .common import extract_stack_to_frame
@pytest.fixture(autouse=True)
def disable_block_async_io(disable_block_async_io):
"""Disable the loop protection from block_async_io after each test."""
async def test_protect_loop_debugger_sleep(caplog: pytest.LogCaptureFixture) -> None: async def test_protect_loop_debugger_sleep(caplog: pytest.LogCaptureFixture) -> None:
"""Test time.sleep injected by the debugger is not reported.""" """Test time.sleep injected by the debugger is not reported."""
block_async_io.enable() block_async_io.enable()
@ -214,13 +219,25 @@ async def test_protect_loop_open(caplog: pytest.LogCaptureFixture) -> None:
async def test_protect_open(caplog: pytest.LogCaptureFixture) -> None: async def test_protect_open(caplog: pytest.LogCaptureFixture) -> None:
"""Test opening a file in the event loop logs.""" """Test opening a file in the event loop logs."""
block_async_io.enable() with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable()
with contextlib.suppress(FileNotFoundError): with contextlib.suppress(FileNotFoundError):
open("/config/data_not_exist", encoding="utf8").close() open("/config/data_not_exist", encoding="utf8").close()
assert "Detected blocking call to open with args" in caplog.text assert "Detected blocking call to open with args" in caplog.text
async def test_enable_multiple_times(caplog: pytest.LogCaptureFixture) -> None:
"""Test trying to enable multiple times."""
with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable()
with pytest.raises(
RuntimeError, match="Blocking call detection is already enabled"
):
block_async_io.enable()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"path", "path",
[ [
@ -231,7 +248,8 @@ async def test_protect_open(caplog: pytest.LogCaptureFixture) -> None:
) )
async def test_protect_open_path(path: Any, caplog: pytest.LogCaptureFixture) -> None: async def test_protect_open_path(path: Any, caplog: pytest.LogCaptureFixture) -> None:
"""Test opening a file by path in the event loop logs.""" """Test opening a file by path in the event loop logs."""
block_async_io.enable() with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable()
with contextlib.suppress(FileNotFoundError): with contextlib.suppress(FileNotFoundError):
open(path, encoding="utf8").close() open(path, encoding="utf8").close()
@ -242,7 +260,8 @@ async def test_protect_loop_glob(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Test glob calls in the loop are logged.""" """Test glob calls in the loop are logged."""
block_async_io.enable() with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable()
glob.glob("/dev/null") glob.glob("/dev/null")
assert "Detected blocking call to glob with args" in caplog.text assert "Detected blocking call to glob with args" in caplog.text
caplog.clear() caplog.clear()
@ -254,7 +273,8 @@ async def test_protect_loop_iglob(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Test iglob calls in the loop are logged.""" """Test iglob calls in the loop are logged."""
block_async_io.enable() with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable()
glob.iglob("/dev/null") glob.iglob("/dev/null")
assert "Detected blocking call to iglob with args" in caplog.text assert "Detected blocking call to iglob with args" in caplog.text
caplog.clear() caplog.clear()
@ -266,7 +286,8 @@ async def test_protect_loop_scandir(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Test glob calls in the loop are logged.""" """Test glob calls in the loop are logged."""
block_async_io.enable() with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable()
with contextlib.suppress(FileNotFoundError): with contextlib.suppress(FileNotFoundError):
os.scandir("/path/that/does/not/exists") os.scandir("/path/that/does/not/exists")
assert "Detected blocking call to scandir with args" in caplog.text assert "Detected blocking call to scandir with args" in caplog.text
@ -280,7 +301,8 @@ async def test_protect_loop_listdir(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Test listdir calls in the loop are logged.""" """Test listdir calls in the loop are logged."""
block_async_io.enable() with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable()
with contextlib.suppress(FileNotFoundError): with contextlib.suppress(FileNotFoundError):
os.listdir("/path/that/does/not/exists") os.listdir("/path/that/does/not/exists")
assert "Detected blocking call to listdir with args" in caplog.text assert "Detected blocking call to listdir with args" in caplog.text
@ -293,8 +315,9 @@ async def test_protect_loop_listdir(
async def test_protect_loop_walk( async def test_protect_loop_walk(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Test glob calls in the loop are logged.""" """Test os.walk calls in the loop are logged."""
block_async_io.enable() with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable()
with contextlib.suppress(FileNotFoundError): with contextlib.suppress(FileNotFoundError):
os.walk("/path/that/does/not/exists") os.walk("/path/that/does/not/exists")
assert "Detected blocking call to walk with args" in caplog.text assert "Detected blocking call to walk with args" in caplog.text
@ -302,3 +325,13 @@ async def test_protect_loop_walk(
with contextlib.suppress(FileNotFoundError): with contextlib.suppress(FileNotFoundError):
await hass.async_add_executor_job(os.walk, "/path/that/does/not/exists") await hass.async_add_executor_job(os.walk, "/path/that/does/not/exists")
assert "Detected blocking call to walk with args" not in caplog.text assert "Detected blocking call to walk with args" not in caplog.text
async def test_open_calls_ignored_in_tests(caplog: pytest.LogCaptureFixture) -> None:
"""Test opening a file in tests is ignored."""
assert block_async_io._IN_TESTS
block_async_io.enable()
with contextlib.suppress(FileNotFoundError):
open("/config/data_not_exist", encoding="utf8").close()
assert "Detected blocking call to open with args" not in caplog.text

View File

@ -55,6 +55,11 @@ async def apply_stop_hass(stop_hass: None) -> None:
"""Make sure all hass are stopped.""" """Make sure all hass are stopped."""
@pytest.fixture(autouse=True)
def disable_block_async_io(disable_block_async_io):
"""Disable the loop protection from block_async_io after each test."""
@pytest.fixture(scope="module", autouse=True) @pytest.fixture(scope="module", autouse=True)
def mock_http_start_stop() -> Generator[None]: def mock_http_start_stop() -> Generator[None]:
"""Mock HTTP start and stop.""" """Mock HTTP start and stop."""