diff --git a/homeassistant/core.py b/homeassistant/core.py index 097e1ed7165..b4a86a46557 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -14,6 +14,7 @@ from collections.abc import ( Iterable, Mapping, ) +import concurrent.futures from contextlib import suppress from contextvars import ContextVar import datetime @@ -79,11 +80,7 @@ from .exceptions import ( ) from .helpers.aiohttp_compat import restore_original_aiohttp_cancel_behavior from .util import dt as dt_util, location, ulid as ulid_util -from .util.async_ import ( - fire_coroutine_threadsafe, - run_callback_threadsafe, - shutdown_run_callback_threadsafe, -) +from .util.async_ import run_callback_threadsafe, shutdown_run_callback_threadsafe from .util.read_only_dict import ReadOnlyDict from .util.timeout import TimeoutManager from .util.unit_system import ( @@ -294,6 +291,7 @@ class HomeAssistant: self._stopped: asyncio.Event | None = None # Timeout handler for Core/Helper namespace self.timeout: TimeoutManager = TimeoutManager() + self._stop_future: concurrent.futures.Future[None] | None = None @property def is_running(self) -> bool: @@ -312,12 +310,14 @@ class HomeAssistant: For regular use, use "await hass.run()". """ # Register the async start - fire_coroutine_threadsafe(self.async_start(), self.loop) - + _future = asyncio.run_coroutine_threadsafe(self.async_start(), self.loop) # Run forever # Block until stopped _LOGGER.info("Starting Home Assistant core loop") self.loop.run_forever() + # The future is never retrieved but we still hold a reference to it + # to prevent the task from being garbage collected prematurely. + del _future return self.exit_code async def async_run(self, *, attach_signals: bool = True) -> int: @@ -682,7 +682,11 @@ class HomeAssistant: """Stop Home Assistant and shuts down all threads.""" if self.state == CoreState.not_running: # just ignore return - fire_coroutine_threadsafe(self.async_stop(), self.loop) + # The future is never retrieved, and we only hold a reference + # to it to prevent it from being garbage collected. + self._stop_future = asyncio.run_coroutine_threadsafe( + self.async_stop(), self.loop + ) async def async_stop(self, exit_code: int = 0, *, force: bool = False) -> None: """Stop Home Assistant and shuts down all threads. diff --git a/homeassistant/util/async_.py b/homeassistant/util/async_.py index f5164da4808..5b119a58c22 100644 --- a/homeassistant/util/async_.py +++ b/homeassistant/util/async_.py @@ -1,9 +1,9 @@ """Asyncio utilities.""" from __future__ import annotations -from asyncio import Semaphore, coroutines, ensure_future, gather, get_running_loop +from asyncio import Semaphore, gather, get_running_loop from asyncio.events import AbstractEventLoop -from collections.abc import Awaitable, Callable, Coroutine +from collections.abc import Awaitable, Callable import concurrent.futures import functools import logging @@ -20,29 +20,6 @@ _R = TypeVar("_R") _P = ParamSpec("_P") -def fire_coroutine_threadsafe( - coro: Coroutine[Any, Any, Any], loop: AbstractEventLoop -) -> None: - """Submit a coroutine object to a given event loop. - - This method does not provide a way to retrieve the result and - is intended for fire-and-forget use. This reduces the - work involved to fire the function on the loop. - """ - ident = loop.__dict__.get("_thread_ident") - if ident is not None and ident == threading.get_ident(): - raise RuntimeError("Cannot be called from within the event loop") - - if not coroutines.iscoroutine(coro): - raise TypeError(f"A coroutine object is required: {coro}") - - def callback() -> None: - """Handle the firing of a coroutine.""" - ensure_future(coro, loop=loop) - - loop.call_soon_threadsafe(callback) - - def run_callback_threadsafe( loop: AbstractEventLoop, callback: Callable[..., _T], *args: Any ) -> concurrent.futures.Future[_T]: diff --git a/tests/util/test_async.py b/tests/util/test_async.py index 48d9ee02f8d..7b0cc916ec7 100644 --- a/tests/util/test_async.py +++ b/tests/util/test_async.py @@ -10,43 +10,6 @@ from homeassistant.core import HomeAssistant from homeassistant.util import async_ as hasync -@patch("asyncio.coroutines.iscoroutine") -@patch("concurrent.futures.Future") -@patch("threading.get_ident") -def test_fire_coroutine_threadsafe_from_inside_event_loop( - mock_ident, _, mock_iscoroutine -) -> None: - """Testing calling fire_coroutine_threadsafe from inside an event loop.""" - coro = MagicMock() - loop = MagicMock() - - loop._thread_ident = None - mock_ident.return_value = 5 - mock_iscoroutine.return_value = True - hasync.fire_coroutine_threadsafe(coro, loop) - assert len(loop.call_soon_threadsafe.mock_calls) == 1 - - loop._thread_ident = 5 - mock_ident.return_value = 5 - mock_iscoroutine.return_value = True - with pytest.raises(RuntimeError): - hasync.fire_coroutine_threadsafe(coro, loop) - assert len(loop.call_soon_threadsafe.mock_calls) == 1 - - loop._thread_ident = 1 - mock_ident.return_value = 5 - mock_iscoroutine.return_value = False - with pytest.raises(TypeError): - hasync.fire_coroutine_threadsafe(coro, loop) - assert len(loop.call_soon_threadsafe.mock_calls) == 1 - - loop._thread_ident = 1 - mock_ident.return_value = 5 - mock_iscoroutine.return_value = True - hasync.fire_coroutine_threadsafe(coro, loop) - assert len(loop.call_soon_threadsafe.mock_calls) == 2 - - @patch("concurrent.futures.Future") @patch("threading.get_ident") def test_run_callback_threadsafe_from_inside_event_loop(mock_ident, _) -> None: