Refactor asyncio loop protection to improve performance (#117295)

This commit is contained in:
J. Nick Koston 2024-05-13 07:01:55 +09:00 committed by GitHub
parent aae39759d9
commit d06932bbc2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 91 additions and 59 deletions

View File

@ -4,6 +4,7 @@ from contextlib import suppress
from http.client import HTTPConnection from http.client import HTTPConnection
import importlib import importlib
import sys import sys
import threading
import time import time
from typing import Any from typing import Any
@ -25,7 +26,7 @@ def _check_sleep_call_allowed(mapped_args: dict[str, Any]) -> bool:
# I/O and we are trying to avoid blocking calls. # I/O and we are trying to avoid blocking calls.
# #
# frame[0] is us # frame[0] is us
# frame[1] is check_loop # frame[1] is raise_for_blocking_call
# frame[2] is protected_loop_func # frame[2] is protected_loop_func
# frame[3] is the offender # frame[3] is the offender
with suppress(ValueError): with suppress(ValueError):
@ -35,14 +36,18 @@ def _check_sleep_call_allowed(mapped_args: dict[str, Any]) -> bool:
def enable() -> None: def enable() -> None:
"""Enable the detection of blocking calls in the event loop.""" """Enable the detection of blocking calls in the event loop."""
loop_thread_id = threading.get_ident()
# Prevent urllib3 and requests doing I/O in event loop # Prevent urllib3 and requests doing I/O in event loop
HTTPConnection.putrequest = protect_loop( # type: ignore[method-assign] HTTPConnection.putrequest = protect_loop( # type: ignore[method-assign]
HTTPConnection.putrequest HTTPConnection.putrequest, loop_thread_id=loop_thread_id
) )
# Prevent sleeping in event loop. Non-strict since 2022.02 # Prevent sleeping in event loop. Non-strict since 2022.02
time.sleep = protect_loop( time.sleep = protect_loop(
time.sleep, strict=False, check_allowed=_check_sleep_call_allowed time.sleep,
strict=False,
check_allowed=_check_sleep_call_allowed,
loop_thread_id=loop_thread_id,
) )
# Currently disabled. pytz doing I/O when getting timezone. # Currently disabled. pytz doing I/O when getting timezone.
@ -57,4 +62,5 @@ def enable() -> None:
strict_core=False, strict_core=False,
strict=False, strict=False,
check_allowed=_check_import_call_allowed, check_allowed=_check_import_call_allowed,
loop_thread_id=loop_thread_id,
) )

View File

@ -1,5 +1,6 @@
"""A pool for sqlite connections.""" """A pool for sqlite connections."""
import asyncio
import logging import logging
import threading import threading
import traceback import traceback
@ -14,7 +15,7 @@ from sqlalchemy.pool import (
) )
from homeassistant.helpers.frame import report from homeassistant.helpers.frame import report
from homeassistant.util.loop import check_loop from homeassistant.util.loop import raise_for_blocking_call
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -86,15 +87,22 @@ class RecorderPool(SingletonThreadPool, NullPool):
if threading.get_ident() in self.recorder_and_worker_thread_ids: if threading.get_ident() in self.recorder_and_worker_thread_ids:
super().dispose() super().dispose()
def _do_get(self) -> ConnectionPoolEntry: def _do_get(self) -> ConnectionPoolEntry: # type: ignore[return]
if threading.get_ident() in self.recorder_and_worker_thread_ids: if threading.get_ident() in self.recorder_and_worker_thread_ids:
return super()._do_get() return super()._do_get()
check_loop( try:
asyncio.get_running_loop()
except RuntimeError:
# Not in an event loop but not in the recorder or worker thread
# which is allowed but discouraged since its much slower
return self._do_get_db_connection_protected()
# In the event loop, raise an exception
raise_for_blocking_call(
self._do_get_db_connection_protected, self._do_get_db_connection_protected,
strict=True, strict=True,
advise_msg=ADVISE_MSG, advise_msg=ADVISE_MSG,
) )
return self._do_get_db_connection_protected() # raise_for_blocking_call will raise an exception
def _do_get_db_connection_protected(self) -> ConnectionPoolEntry: def _do_get_db_connection_protected(self) -> ConnectionPoolEntry:
report( report(

View File

@ -2,12 +2,12 @@
from __future__ import annotations from __future__ import annotations
from asyncio import get_running_loop
from collections.abc import Callable from collections.abc import Callable
from contextlib import suppress from contextlib import suppress
import functools import functools
import linecache import linecache
import logging import logging
import threading
from typing import Any, ParamSpec, TypeVar from typing import Any, ParamSpec, TypeVar
from homeassistant.core import HomeAssistant, async_get_hass from homeassistant.core import HomeAssistant, async_get_hass
@ -31,7 +31,7 @@ def _get_line_from_cache(filename: str, lineno: int) -> str:
return (linecache.getline(filename, lineno) or "?").strip() return (linecache.getline(filename, lineno) or "?").strip()
def check_loop( def raise_for_blocking_call(
func: Callable[..., Any], func: Callable[..., Any],
check_allowed: Callable[[dict[str, Any]], bool] | None = None, check_allowed: Callable[[dict[str, Any]], bool] | None = None,
strict: bool = True, strict: bool = True,
@ -44,15 +44,6 @@ def check_loop(
The default advisory message is 'Use `await hass.async_add_executor_job()' The default advisory message is 'Use `await hass.async_add_executor_job()'
Set `advise_msg` to an alternate message if the solution differs. Set `advise_msg` to an alternate message if the solution differs.
""" """
try:
get_running_loop()
in_loop = True
except RuntimeError:
in_loop = False
if not in_loop:
return
if check_allowed is not None and check_allowed(mapped_args): if check_allowed is not None and check_allowed(mapped_args):
return return
@ -125,6 +116,7 @@ def check_loop(
def protect_loop( def protect_loop(
func: Callable[_P, _R], func: Callable[_P, _R],
loop_thread_id: int,
strict: bool = True, strict: bool = True,
strict_core: bool = True, strict_core: bool = True,
check_allowed: Callable[[dict[str, Any]], bool] | None = None, check_allowed: Callable[[dict[str, Any]], bool] | None = None,
@ -133,14 +125,15 @@ def protect_loop(
@functools.wraps(func) @functools.wraps(func)
def protected_loop_func(*args: _P.args, **kwargs: _P.kwargs) -> _R: def protected_loop_func(*args: _P.args, **kwargs: _P.kwargs) -> _R:
check_loop( if threading.get_ident() == loop_thread_id:
func, raise_for_blocking_call(
strict=strict, func,
strict_core=strict_core, strict=strict,
check_allowed=check_allowed, strict_core=strict_core,
args=args, check_allowed=check_allowed,
kwargs=kwargs, args=args,
) kwargs=kwargs,
)
return func(*args, **kwargs) return func(*args, **kwargs)
return protected_loop_func return protected_loop_func

View File

@ -159,14 +159,18 @@ async def test_shutdown_before_startup_finishes(
await recorder_helper.async_wait_recorder(hass) await recorder_helper.async_wait_recorder(hass)
instance = get_instance(hass) instance = get_instance(hass)
session = await hass.async_add_executor_job(instance.get_session) session = await instance.async_add_executor_job(instance.get_session)
with patch.object(instance, "engine"): with patch.object(instance, "engine"):
hass.bus.async_fire(EVENT_HOMEASSISTANT_FINAL_WRITE) hass.bus.async_fire(EVENT_HOMEASSISTANT_FINAL_WRITE)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_stop() await hass.async_stop()
run_info = await hass.async_add_executor_job(run_information_with_session, session) def _run_information_with_session():
instance.recorder_and_worker_thread_ids.add(threading.get_ident())
return run_information_with_session(session)
run_info = await instance.async_add_executor_job(_run_information_with_session)
assert run_info.run_id == 1 assert run_info.run_id == 1
assert run_info.start is not None assert run_info.start is not None
@ -1693,7 +1697,8 @@ async def test_database_corruption_while_running(
await hass.async_block_till_done() await hass.async_block_till_done()
caplog.clear() caplog.clear()
original_start_time = get_instance(hass).recorder_runs_manager.recording_start instance = get_instance(hass)
original_start_time = instance.recorder_runs_manager.recording_start
hass.states.async_set("test.lost", "on", {}) hass.states.async_set("test.lost", "on", {})
@ -1737,11 +1742,11 @@ async def test_database_corruption_while_running(
assert db_states[0].event_id is None assert db_states[0].event_id is None
return db_states[0].to_native() return db_states[0].to_native()
state = await hass.async_add_executor_job(_get_last_state) state = await instance.async_add_executor_job(_get_last_state)
assert state.entity_id == "test.two" assert state.entity_id == "test.two"
assert state.state == "on" assert state.state == "on"
new_start_time = get_instance(hass).recorder_runs_manager.recording_start new_start_time = instance.recorder_runs_manager.recording_start
assert original_start_time < new_start_time assert original_start_time < new_start_time
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
@ -1850,7 +1855,7 @@ async def test_database_lock_and_unlock(
assert instance.unlock_database() assert instance.unlock_database()
await task await task
db_events = await hass.async_add_executor_job(_get_db_events) db_events = await instance.async_add_executor_job(_get_db_events)
assert len(db_events) == 1 assert len(db_events) == 1

View File

@ -9,12 +9,13 @@ import importlib
import json import json
from pathlib import Path from pathlib import Path
import sys import sys
import threading
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from homeassistant.components import recorder from homeassistant.components import recorder
from homeassistant.components.recorder import SQLITE_URL_PREFIX from homeassistant.components.recorder import SQLITE_URL_PREFIX, get_instance
from homeassistant.components.recorder.util import session_scope from homeassistant.components.recorder.util import session_scope
from homeassistant.helpers import recorder as recorder_helper from homeassistant.helpers import recorder as recorder_helper
from homeassistant.setup import setup_component from homeassistant.setup import setup_component
@ -176,6 +177,7 @@ def test_delete_duplicates(caplog: pytest.LogCaptureFixture, tmp_path: Path) ->
): ):
recorder_helper.async_initialize_recorder(hass) recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
wait_recording_done(hass) wait_recording_done(hass)
wait_recording_done(hass) wait_recording_done(hass)
@ -358,6 +360,7 @@ def test_delete_duplicates_many(
): ):
recorder_helper.async_initialize_recorder(hass) recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
wait_recording_done(hass) wait_recording_done(hass)
wait_recording_done(hass) wait_recording_done(hass)
@ -517,6 +520,7 @@ def test_delete_duplicates_non_identical(
): ):
recorder_helper.async_initialize_recorder(hass) recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
wait_recording_done(hass) wait_recording_done(hass)
wait_recording_done(hass) wait_recording_done(hass)
@ -631,6 +635,7 @@ def test_delete_duplicates_short_term(
): ):
recorder_helper.async_initialize_recorder(hass) recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
wait_recording_done(hass) wait_recording_done(hass)
wait_recording_done(hass) wait_recording_done(hass)

View File

@ -4,6 +4,7 @@ from datetime import UTC, datetime, timedelta
import os import os
from pathlib import Path from pathlib import Path
import sqlite3 import sqlite3
import threading
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
import pytest import pytest
@ -843,9 +844,7 @@ async def test_periodic_db_cleanups(
assert str(text_obj) == "PRAGMA wal_checkpoint(TRUNCATE);" assert str(text_obj) == "PRAGMA wal_checkpoint(TRUNCATE);"
@patch("homeassistant.components.recorder.pool.check_loop")
async def test_write_lock_db( async def test_write_lock_db(
skip_check_loop,
async_setup_recorder_instance: RecorderInstanceGenerator, async_setup_recorder_instance: RecorderInstanceGenerator,
hass: HomeAssistant, hass: HomeAssistant,
tmp_path: Path, tmp_path: Path,
@ -864,6 +863,7 @@ async def test_write_lock_db(
with instance.engine.connect() as connection: with instance.engine.connect() as connection:
connection.execute(text("DROP TABLE events;")) connection.execute(text("DROP TABLE events;"))
instance.recorder_and_worker_thread_ids.add(threading.get_ident())
with util.write_lock_db_sqlite(instance), pytest.raises(OperationalError): with util.write_lock_db_sqlite(instance), pytest.raises(OperationalError):
# Database should be locked now, try writing SQL command # Database should be locked now, try writing SQL command
# This needs to be called in another thread since # This needs to be called in another thread since
@ -872,7 +872,7 @@ async def test_write_lock_db(
# in the same thread as the one holding the lock since it # in the same thread as the one holding the lock since it
# would be allowed to proceed as the goal is to prevent # would be allowed to proceed as the goal is to prevent
# all the other threads from accessing the database # all the other threads from accessing the database
await hass.async_add_executor_job(_drop_table) await instance.async_add_executor_job(_drop_table)
def test_is_second_sunday() -> None: def test_is_second_sunday() -> None:

View File

@ -2,11 +2,13 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from pathlib import Path from pathlib import Path
import threading
from unittest.mock import patch from unittest.mock import patch
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
import pytest import pytest
from homeassistant.components.recorder import get_instance
from homeassistant.components.recorder.history import get_significant_states from homeassistant.components.recorder.history import get_significant_states
from homeassistant.components.recorder.statistics import ( from homeassistant.components.recorder.statistics import (
get_latest_short_term_statistics_with_session, get_latest_short_term_statistics_with_session,
@ -57,6 +59,7 @@ def test_compile_missing_statistics(
recorder_helper.async_initialize_recorder(hass) recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "sensor", {}) setup_component(hass, "sensor", {})
setup_component(hass, "recorder", {"recorder": config}) setup_component(hass, "recorder", {"recorder": config})
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
hass.start() hass.start()
wait_recording_done(hass) wait_recording_done(hass)
wait_recording_done(hass) wait_recording_done(hass)
@ -98,6 +101,7 @@ def test_compile_missing_statistics(
setup_component(hass, "sensor", {}) setup_component(hass, "sensor", {})
hass.states.set("sensor.test1", "0", POWER_SENSOR_ATTRIBUTES) hass.states.set("sensor.test1", "0", POWER_SENSOR_ATTRIBUTES)
setup_component(hass, "recorder", {"recorder": config}) setup_component(hass, "recorder", {"recorder": config})
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
hass.start() hass.start()
wait_recording_done(hass) wait_recording_done(hass)
wait_recording_done(hass) wait_recording_done(hass)

View File

@ -1,9 +1,11 @@
"""Tests for async util methods from Python source.""" """Tests for async util methods from Python source."""
import threading
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
from homeassistant.core import HomeAssistant
from homeassistant.util import loop as haloop from homeassistant.util import loop as haloop
from tests.common import extract_stack_to_frame from tests.common import extract_stack_to_frame
@ -13,22 +15,24 @@ def banned_function():
"""Mock banned function.""" """Mock banned function."""
async def test_check_loop_async() -> None: async def test_raise_for_blocking_call_async() -> None:
"""Test check_loop detects when called from event loop without integration context.""" """Test raise_for_blocking_call detects when called from event loop without integration context."""
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
haloop.check_loop(banned_function) haloop.raise_for_blocking_call(banned_function)
async def test_check_loop_async_non_strict_core( async def test_raise_for_blocking_call_async_non_strict_core(
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
"""Test non_strict_core check_loop detects from event loop without integration context.""" """Test non_strict_core raise_for_blocking_call detects from event loop without integration context."""
haloop.check_loop(banned_function, strict_core=False) haloop.raise_for_blocking_call(banned_function, strict_core=False)
assert "Detected blocking call to banned_function" in caplog.text assert "Detected blocking call to banned_function" in caplog.text
async def test_check_loop_async_integration(caplog: pytest.LogCaptureFixture) -> None: async def test_raise_for_blocking_call_async_integration(
"""Test check_loop detects and raises when called from event loop from integration context.""" caplog: pytest.LogCaptureFixture,
) -> None:
"""Test raise_for_blocking_call detects and raises when called from event loop from integration context."""
frames = extract_stack_to_frame( frames = extract_stack_to_frame(
[ [
Mock( Mock(
@ -67,7 +71,7 @@ async def test_check_loop_async_integration(caplog: pytest.LogCaptureFixture) ->
return_value=frames, return_value=frames,
), ),
): ):
haloop.check_loop(banned_function) haloop.raise_for_blocking_call(banned_function)
assert ( assert (
"Detected blocking call to banned_function inside the event loop by integration" "Detected blocking call to banned_function inside the event loop by integration"
" 'hue' at homeassistant/components/hue/light.py, line 23: self.light.is_on " " 'hue' at homeassistant/components/hue/light.py, line 23: self.light.is_on "
@ -77,10 +81,10 @@ async def test_check_loop_async_integration(caplog: pytest.LogCaptureFixture) ->
) )
async def test_check_loop_async_integration_non_strict( async def test_raise_for_blocking_call_async_integration_non_strict(
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
"""Test check_loop detects when called from event loop from integration context.""" """Test raise_for_blocking_call detects when called from event loop from integration context."""
frames = extract_stack_to_frame( frames = extract_stack_to_frame(
[ [
Mock( Mock(
@ -118,7 +122,7 @@ async def test_check_loop_async_integration_non_strict(
return_value=frames, return_value=frames,
), ),
): ):
haloop.check_loop(banned_function, strict=False) haloop.raise_for_blocking_call(banned_function, strict=False)
assert ( assert (
"Detected blocking call to banned_function inside the event loop by integration" "Detected blocking call to banned_function inside the event loop by integration"
" 'hue' at homeassistant/components/hue/light.py, line 23: self.light.is_on " " 'hue' at homeassistant/components/hue/light.py, line 23: self.light.is_on "
@ -128,8 +132,10 @@ async def test_check_loop_async_integration_non_strict(
) )
async def test_check_loop_async_custom(caplog: pytest.LogCaptureFixture) -> None: async def test_raise_for_blocking_call_async_custom(
"""Test check_loop detects when called from event loop with custom component context.""" caplog: pytest.LogCaptureFixture,
) -> None:
"""Test raise_for_blocking_call detects when called from event loop with custom component context."""
frames = extract_stack_to_frame( frames = extract_stack_to_frame(
[ [
Mock( Mock(
@ -168,7 +174,7 @@ async def test_check_loop_async_custom(caplog: pytest.LogCaptureFixture) -> None
return_value=frames, return_value=frames,
), ),
): ):
haloop.check_loop(banned_function) haloop.raise_for_blocking_call(banned_function)
assert ( assert (
"Detected blocking call to banned_function inside the event loop by custom " "Detected blocking call to banned_function inside the event loop by custom "
"integration 'hue' at custom_components/hue/light.py, line 23: self.light.is_on" "integration 'hue' at custom_components/hue/light.py, line 23: self.light.is_on"
@ -178,18 +184,23 @@ async def test_check_loop_async_custom(caplog: pytest.LogCaptureFixture) -> None
) in caplog.text ) in caplog.text
def test_check_loop_sync(caplog: pytest.LogCaptureFixture) -> None: async def test_raise_for_blocking_call_sync(
"""Test check_loop does nothing when called from thread.""" hass: HomeAssistant, caplog: pytest.LogCaptureFixture
haloop.check_loop(banned_function) ) -> None:
"""Test raise_for_blocking_call does nothing when called from thread."""
func = haloop.protect_loop(banned_function, threading.get_ident())
await hass.async_add_executor_job(func)
assert "Detected blocking call inside the event loop" not in caplog.text assert "Detected blocking call inside the event loop" not in caplog.text
def test_protect_loop_sync() -> None: async def test_protect_loop_async() -> None:
"""Test protect_loop calls check_loop.""" """Test protect_loop calls raise_for_blocking_call."""
func = Mock() func = Mock()
with patch("homeassistant.util.loop.check_loop") as mock_check_loop: with patch(
haloop.protect_loop(func)(1, test=2) "homeassistant.util.loop.raise_for_blocking_call"
mock_check_loop.assert_called_once_with( ) as mock_raise_for_blocking_call:
haloop.protect_loop(func, threading.get_ident())(1, test=2)
mock_raise_for_blocking_call.assert_called_once_with(
func, func,
strict=True, strict=True,
args=(1,), args=(1,),