mirror of
https://github.com/home-assistant/core.git
synced 2025-07-17 02:07:09 +00:00
Ensure shutdown does not deadlock (#49282)
This commit is contained in:
parent
afd79a675c
commit
04a0ca14e0
@ -87,7 +87,7 @@ if TYPE_CHECKING:
|
|||||||
from homeassistant.config_entries import ConfigEntries
|
from homeassistant.config_entries import ConfigEntries
|
||||||
|
|
||||||
|
|
||||||
STAGE_1_SHUTDOWN_TIMEOUT = 120
|
STAGE_1_SHUTDOWN_TIMEOUT = 100
|
||||||
STAGE_2_SHUTDOWN_TIMEOUT = 60
|
STAGE_2_SHUTDOWN_TIMEOUT = 60
|
||||||
STAGE_3_SHUTDOWN_TIMEOUT = 30
|
STAGE_3_SHUTDOWN_TIMEOUT = 30
|
||||||
|
|
||||||
|
@ -2,14 +2,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
|
import threading
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant import bootstrap
|
from homeassistant import bootstrap
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.helpers.frame import warn_use
|
from homeassistant.helpers.frame import warn_use
|
||||||
|
from homeassistant.util.executor import InterruptibleThreadPoolExecutor
|
||||||
|
from homeassistant.util.thread import deadlock_safe_shutdown
|
||||||
|
|
||||||
# mypy: disallow-any-generics
|
# mypy: disallow-any-generics
|
||||||
|
|
||||||
@ -64,7 +66,7 @@ class HassEventLoopPolicy(asyncio.DefaultEventLoopPolicy): # type: ignore[valid
|
|||||||
if self.debug:
|
if self.debug:
|
||||||
loop.set_debug(True)
|
loop.set_debug(True)
|
||||||
|
|
||||||
executor = ThreadPoolExecutor(
|
executor = InterruptibleThreadPoolExecutor(
|
||||||
thread_name_prefix="SyncWorker", max_workers=MAX_EXECUTOR_WORKERS
|
thread_name_prefix="SyncWorker", max_workers=MAX_EXECUTOR_WORKERS
|
||||||
)
|
)
|
||||||
loop.set_default_executor(executor)
|
loop.set_default_executor(executor)
|
||||||
@ -76,7 +78,7 @@ class HassEventLoopPolicy(asyncio.DefaultEventLoopPolicy): # type: ignore[valid
|
|||||||
orig_close = loop.close
|
orig_close = loop.close
|
||||||
|
|
||||||
def close() -> None:
|
def close() -> None:
|
||||||
executor.shutdown(wait=True)
|
executor.logged_shutdown()
|
||||||
orig_close()
|
orig_close()
|
||||||
|
|
||||||
loop.close = close # type: ignore
|
loop.close = close # type: ignore
|
||||||
@ -104,6 +106,9 @@ async def setup_and_run_hass(runtime_config: RuntimeConfig) -> int:
|
|||||||
if hass is None:
|
if hass is None:
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
# threading._shutdown can deadlock forever
|
||||||
|
threading._shutdown = deadlock_safe_shutdown # type: ignore[attr-defined] # pylint: disable=protected-access
|
||||||
|
|
||||||
return await hass.async_run()
|
return await hass.async_run()
|
||||||
|
|
||||||
|
|
||||||
|
108
homeassistant/util/executor.py
Normal file
108
homeassistant/util/executor.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
"""Executor util helpers."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import logging
|
||||||
|
import queue
|
||||||
|
import sys
|
||||||
|
from threading import Thread
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from homeassistant.util.thread import async_raise
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MAX_LOG_ATTEMPTS = 2
|
||||||
|
|
||||||
|
_JOIN_ATTEMPTS = 10
|
||||||
|
|
||||||
|
EXECUTOR_SHUTDOWN_TIMEOUT = 10
|
||||||
|
|
||||||
|
|
||||||
|
def _log_thread_running_at_shutdown(name: str, ident: int) -> None:
|
||||||
|
"""Log the stack of a thread that was still running at shutdown."""
|
||||||
|
frames = sys._current_frames() # pylint: disable=protected-access
|
||||||
|
stack = frames.get(ident)
|
||||||
|
formatted_stack = traceback.format_stack(stack)
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Thread[%s] is still running at shutdown: %s",
|
||||||
|
name,
|
||||||
|
"".join(formatted_stack).strip(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def join_or_interrupt_threads(
    threads: set[Thread], timeout: float, log: bool
) -> set[Thread]:
    """Attempt to join or interrupt a set of threads.

    Args:
        threads: Threads to wait for. May be empty.
        timeout: Total number of seconds to wait, split evenly across
            ``threads``.
        log: When true, log the stack of every thread that is still alive
            after its join timed out.

    Returns:
        The threads that are no longer alive (or have no ident). Threads
        that are still alive are not returned; ``SystemExit`` is raised in
        them via ``async_raise`` so a later attempt can join them.
    """
    joined: set[Thread] = set()

    if not threads:
        # Nothing to wait for; also avoids dividing the timeout by zero.
        return joined

    timeout_per_thread = timeout / len(threads)

    for thread in threads:
        thread.join(timeout=timeout_per_thread)

        if not thread.is_alive() or thread.ident is None:
            joined.add(thread)
            continue

        if log:
            _log_thread_running_at_shutdown(thread.name, thread.ident)

        # The thread ignored the join timeout: interrupt it.
        async_raise(thread.ident, SystemExit)

    return joined
|
||||||
|
|
||||||
|
|
||||||
|
class InterruptibleThreadPoolExecutor(ThreadPoolExecutor):
    """Thread pool executor whose shutdown is bounded in time."""

    def logged_shutdown(self) -> None:
        """Shut down, cancelling queued work and interrupting stuck workers.

        The locked section mirrors the cpython 3.9 ``shutdown`` logic; the
        bounded join afterwards is specific to this class.
        """
        with self._shutdown_lock:  # type: ignore[attr-defined]
            self._shutdown = True
            work_queue = self._work_queue
            # Empty the queue, cancelling the future of every work item
            # that had not been picked up yet.
            draining = True
            while draining:
                try:
                    item = work_queue.get_nowait()
                except queue.Empty:
                    draining = False
                else:
                    if item is not None:
                        item.future.cancel()
            # Wake up workers blocked in _work_queue.get(block=True) so
            # they do not wait forever.
            work_queue.put(None)

        # Not part of the cpython backport, hence kept in its own method.
        self.join_threads_or_timeout()

    def join_threads_or_timeout(self) -> None:
        """Join worker threads until all are done or the timeout elapses."""
        pending = set(self._threads)  # type: ignore[attr-defined]
        started = time.monotonic()
        budget: float = EXECUTOR_SHUTDOWN_TIMEOUT
        attempt = 0

        while pending:
            attempt += 1
            finished = join_or_interrupt_threads(
                pending,
                budget / _JOIN_ATTEMPTS,
                attempt <= MAX_LOG_ATTEMPTS,
            )
            pending -= finished

            elapsed = time.monotonic() - started
            budget = EXECUTOR_SHUTDOWN_TIMEOUT - elapsed
            if budget <= 0:
                return
|
@ -1,16 +1,45 @@
|
|||||||
"""Threading util helpers."""
|
"""Threading util helpers."""
|
||||||
import ctypes
|
import ctypes
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
THREADING_SHUTDOWN_TIMEOUT = 10
|
||||||
|
|
||||||
def _async_raise(tid: int, exctype: Any) -> None:
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def deadlock_safe_shutdown() -> None:
    """Join the remaining non-daemon threads with a bounded timeout.

    Intended as a replacement for threading._shutdown, which can deadlock
    forever; see
    https://github.com/justengel/continuous_threading#shutdown-update
    for additional detail.
    """
    waiting = []
    for candidate in threading.enumerate():
        if candidate is threading.main_thread():
            continue
        if candidate.daemon or not candidate.is_alive():
            continue
        waiting.append(candidate)

    if not waiting:
        return

    # The overall budget is shared evenly between the threads to join.
    per_thread_timeout = THREADING_SHUTDOWN_TIMEOUT / len(waiting)
    for candidate in waiting:
        try:
            candidate.join(per_thread_timeout)
        except Exception as err:  # pylint: disable=broad-except
            _LOGGER.warning("Failed to join thread: %s", err)
|
||||||
|
|
||||||
|
|
||||||
|
def async_raise(tid: int, exctype: Any) -> None:
|
||||||
"""Raise an exception in the threads with id tid."""
|
"""Raise an exception in the threads with id tid."""
|
||||||
if not inspect.isclass(exctype):
|
if not inspect.isclass(exctype):
|
||||||
raise TypeError("Only types can be raised (not instances)")
|
raise TypeError("Only types can be raised (not instances)")
|
||||||
|
|
||||||
c_tid = ctypes.c_long(tid)
|
c_tid = ctypes.c_ulong(tid) # changed in python 3.7+
|
||||||
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(c_tid, ctypes.py_object(exctype))
|
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(c_tid, ctypes.py_object(exctype))
|
||||||
|
|
||||||
if res == 1:
|
if res == 1:
|
||||||
@ -33,4 +62,4 @@ class ThreadWithException(threading.Thread):
|
|||||||
def raise_exc(self, exctype: Any) -> None:
|
def raise_exc(self, exctype: Any) -> None:
|
||||||
"""Raise the given exception type in the context of this thread."""
|
"""Raise the given exception type in the context of this thread."""
|
||||||
assert self.ident
|
assert self.ident
|
||||||
_async_raise(self.ident, exctype)
|
async_raise(self.ident, exctype)
|
||||||
|
39
tests/test_runner.py
Normal file
39
tests/test_runner.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
"""Test the runner."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from homeassistant import core, runner
|
||||||
|
from homeassistant.util import executor, thread
|
||||||
|
|
||||||
|
# https://github.com/home-assistant/supervisor/blob/main/supervisor/docker/homeassistant.py
|
||||||
|
SUPERVISOR_HARD_TIMEOUT = 220
|
||||||
|
|
||||||
|
TIMEOUT_SAFETY_MARGIN = 10
|
||||||
|
|
||||||
|
|
||||||
|
async def test_cumulative_shutdown_timeout_less_than_supervisor():
    """Check all shutdown timeouts plus a margin fit in the supervisor limit."""
    shutdown_timeouts = (
        core.STAGE_1_SHUTDOWN_TIMEOUT,
        core.STAGE_2_SHUTDOWN_TIMEOUT,
        core.STAGE_3_SHUTDOWN_TIMEOUT,
        executor.EXECUTOR_SHUTDOWN_TIMEOUT,
        thread.THREADING_SHUTDOWN_TIMEOUT,
    )
    worst_case = sum(shutdown_timeouts) + TIMEOUT_SAFETY_MARGIN

    assert worst_case <= SUPERVISOR_HARD_TIMEOUT
|
||||||
|
|
||||||
|
|
||||||
|
async def test_setup_and_run_hass(hass, tmpdir):
    """Check setup_and_run_hass installs the safe shutdown and runs hass."""
    config_dir = tmpdir.mkdir("config")
    runtime_config = runner.RuntimeConfig(config_dir)

    setup_patch = patch("homeassistant.bootstrap.async_setup_hass", return_value=hass)
    shutdown_patch = patch("threading._shutdown")
    run_patch = patch("homeassistant.core.HomeAssistant.async_run")

    with setup_patch, shutdown_patch, run_patch as mock_run:
        await runner.setup_and_run_hass(runtime_config)
        # Must be checked while the patch is active: leaving the context
        # restores the original threading._shutdown.
        assert threading._shutdown == thread.deadlock_safe_shutdown

    assert mock_run.called
|
91
tests/util/test_executor.py
Normal file
91
tests/util/test_executor.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
"""Test Home Assistant executor util."""
|
||||||
|
|
||||||
|
import concurrent.futures
|
||||||
|
import time
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.util import executor
|
||||||
|
from homeassistant.util.executor import InterruptibleThreadPoolExecutor
|
||||||
|
|
||||||
|
|
||||||
|
async def test_executor_shutdown_can_interrupt_threads(caplog):
    """Check shutdown interrupts workers that would otherwise never finish."""
    iexecutor = InterruptibleThreadPoolExecutor()

    def _loop_sleep_in_executor():
        while True:
            # This source line is matched against the logged stack below.
            time.sleep(0.1)

    sleep_futures = [iexecutor.submit(_loop_sleep_in_executor) for _ in range(100)]

    iexecutor.logged_shutdown()

    # Queued jobs are cancelled; running jobs are interrupted.
    expected_errors = (concurrent.futures.CancelledError, SystemExit)
    for future in sleep_futures:
        with pytest.raises(expected_errors):
            future.result()

    assert "is still running at shutdown" in caplog.text
    assert "time.sleep(0.1)" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_executor_shutdown_only_logs_max_attempts(caplog):
    """Check a slow worker is logged no more than MAX_LOG_ATTEMPTS times."""
    iexecutor = InterruptibleThreadPoolExecutor()

    def _loop_sleep_in_executor():
        # This source line is matched against the logged stack below.
        time.sleep(0.2)

    iexecutor.submit(_loop_sleep_in_executor)

    with patch.object(executor, "EXECUTOR_SHUTDOWN_TIMEOUT", 0.3):
        iexecutor.logged_shutdown()

    assert "time.sleep(0.2)" in caplog.text
    warning_count = caplog.text.count("is still running at shutdown")
    assert warning_count == executor.MAX_LOG_ATTEMPTS

    # Second shutdown, run with the unpatched timeout.
    iexecutor.logged_shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_executor_shutdown_does_not_log_shutdown_on_first_attempt(caplog):
    """Check nothing is logged when every worker finishes promptly."""
    iexecutor = InterruptibleThreadPoolExecutor()

    def _do_nothing():
        return

    job_count = 5
    for _ in range(job_count):
        iexecutor.submit(_do_nothing)

    iexecutor.logged_shutdown()

    assert "is still running at shutdown" not in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_overall_timeout_reached(caplog):
    """Check shutdown returns once the overall timeout has been used up."""
    iexecutor = InterruptibleThreadPoolExecutor()

    def _loop_sleep_in_executor():
        time.sleep(1)

    for _ in range(6):
        iexecutor.submit(_loop_sleep_in_executor)

    started = time.monotonic()
    with patch.object(executor, "EXECUTOR_SHUTDOWN_TIMEOUT", 0.5):
        iexecutor.logged_shutdown()
    elapsed = time.monotonic() - started

    # The workers sleep for 1s; shutdown must not have waited for them.
    assert elapsed < 1

    iexecutor.logged_shutdown()
|
@ -1,9 +1,11 @@
|
|||||||
"""Test Home Assistant thread utils."""
|
"""Test Home Assistant thread utils."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.util import thread
|
||||||
from homeassistant.util.async_ import run_callback_threadsafe
|
from homeassistant.util.async_ import run_callback_threadsafe
|
||||||
from homeassistant.util.thread import ThreadWithException
|
from homeassistant.util.thread import ThreadWithException
|
||||||
|
|
||||||
@ -53,3 +55,57 @@ async def test_thread_fails_raise(hass):
|
|||||||
|
|
||||||
class _EmptyClass:
|
class _EmptyClass:
|
||||||
"""An empty class."""
|
"""An empty class."""
|
||||||
|
|
||||||
|
|
||||||
|
async def test_deadlock_safe_shutdown_no_threads():
    """Check shutdown joins nothing when no thread qualifies for a join."""

    def _thread_mock(*, daemon, alive):
        return Mock(join=Mock(), daemon=daemon, is_alive=Mock(return_value=alive))

    dead_thread_mock = _thread_mock(daemon=False, alive=False)
    daemon_thread_mock = _thread_mock(daemon=True, alive=True)
    mock_threads = [dead_thread_mock, daemon_thread_mock]

    with patch("homeassistant.util.threading.enumerate", return_value=mock_threads):
        thread.deadlock_safe_shutdown()

    for skipped in (dead_thread_mock, daemon_thread_mock):
        assert not skipped.join.called
|
||||||
|
|
||||||
|
|
||||||
|
async def test_deadlock_safe_shutdown():
    """Check only live non-daemon threads are joined, with a shared timeout."""

    def _thread_mock(*, daemon, alive, join_error=None):
        return Mock(
            join=Mock(side_effect=join_error),
            daemon=daemon,
            is_alive=Mock(return_value=alive),
        )

    normal_thread_mock = _thread_mock(daemon=False, alive=True)
    dead_thread_mock = _thread_mock(daemon=False, alive=False)
    daemon_thread_mock = _thread_mock(daemon=True, alive=True)
    exception_thread_mock = _thread_mock(
        daemon=False, alive=True, join_error=Exception
    )
    mock_threads = [
        normal_thread_mock,
        dead_thread_mock,
        daemon_thread_mock,
        exception_thread_mock,
    ]

    with patch("homeassistant.util.threading.enumerate", return_value=mock_threads):
        thread.deadlock_safe_shutdown()

    # Two threads qualify for a join, so each gets half of the budget.
    expected_args = (thread.THREADING_SHUTDOWN_TIMEOUT / 2,)

    assert normal_thread_mock.join.call_args[0] == expected_args
    assert not dead_thread_mock.join.called
    assert not daemon_thread_mock.join.called
    assert exception_thread_mock.join.call_args[0] == expected_args
|
||||||
|
Loading…
x
Reference in New Issue
Block a user