From 819dd27925c0cbcb45c85ce81c32746538d6a6fa Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 16 Nov 2020 15:43:48 +0100 Subject: [PATCH] Automatically clean up executor as part of closing loop (#43284) --- homeassistant/bootstrap.py | 10 +- homeassistant/core.py | 14 +-- homeassistant/runner.py | 28 +---- tests/common.py | 14 +-- tests/test_core.py | 220 ++++++++++++++++++------------------- 5 files changed, 120 insertions(+), 166 deletions(-) diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py index 0d63307a020..eff8a04ba92 100644 --- a/homeassistant/bootstrap.py +++ b/homeassistant/bootstrap.py @@ -15,11 +15,7 @@ import yarl from homeassistant import config as conf_util, config_entries, core, loader from homeassistant.components import http -from homeassistant.const import ( - EVENT_HOMEASSISTANT_STOP, - REQUIRED_NEXT_PYTHON_DATE, - REQUIRED_NEXT_PYTHON_VER, -) +from homeassistant.const import REQUIRED_NEXT_PYTHON_DATE, REQUIRED_NEXT_PYTHON_VER from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.typing import ConfigType from homeassistant.setup import ( @@ -142,11 +138,9 @@ async def async_setup_hass( _LOGGER.warning("Detected that frontend did not load. Activating safe mode") # Ask integrations to shut down. It's messy but we can't # do a clean stop without knowing what is broken - hass.async_track_tasks() - hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP, {}) with contextlib.suppress(asyncio.TimeoutError): async with hass.timeout.async_timeout(10): - await hass.async_block_till_done() + await hass.async_stop() safe_mode = True old_config = hass.config diff --git a/homeassistant/core.py b/homeassistant/core.py index ed8ae854106..68f0b9a30b7 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -257,12 +257,9 @@ class HomeAssistant: fire_coroutine_threadsafe(self.async_start(), self.loop) # Run forever - try: - # Block until stopped - _LOGGER.info("Starting Home Assistant core loop") - self.loop.run_forever() - finally: - self.loop.close() + # Block until stopped + _LOGGER.info("Starting Home Assistant core loop") + self.loop.run_forever() return self.exit_code async def async_run(self, *, attach_signals: bool = True) -> int: @@ -559,16 +556,11 @@ class HomeAssistant: "Timed out waiting for shutdown stage 3 to complete, the shutdown will continue" ) - # Python 3.9+ and backported in runner.py - await self.loop.shutdown_default_executor() # type: ignore - self.exit_code = exit_code self.state = CoreState.stopped if self._stopped is not None: self._stopped.set() - else: - self.loop.stop() @attr.s(slots=True, frozen=True) diff --git a/homeassistant/runner.py b/homeassistant/runner.py index a5cf0f88a40..0f8bb836da5 100644 --- a/homeassistant/runner.py +++ b/homeassistant/runner.py @@ -4,7 +4,6 @@ from concurrent.futures import ThreadPoolExecutor import dataclasses import logging import sys -import threading from typing import Any, Dict, Optional from homeassistant import bootstrap @@ -77,29 +76,14 @@ class HassEventLoopPolicy(PolicyBase): # type: ignore loop.set_default_executor, "sets default executor on the event loop" ) - # Python 3.9+ - if hasattr(loop, "shutdown_default_executor"): - return loop + # Shut down executor when we shut down loop + orig_close = loop.close - # Copied from Python 3.9 source - def _do_shutdown(future: asyncio.Future) -> None: - try: - executor.shutdown(wait=True) - loop.call_soon_threadsafe(future.set_result, None) - except Exception as ex: # pylint: disable=broad-except - loop.call_soon_threadsafe(future.set_exception, ex) + def close() -> None: + executor.shutdown(wait=True) + orig_close() - async def shutdown_default_executor() -> None: - """Schedule the shutdown of the default executor.""" - future = loop.create_future() - thread = threading.Thread(target=_do_shutdown, args=(future,)) - thread.start() - try: - await future - finally: - thread.join() - - setattr(loop, "shutdown_default_executor", shutdown_default_executor) + loop.close = close # type: ignore return loop diff --git a/tests/common.py b/tests/common.py index 611becabe33..f0994308fe6 100644 --- a/tests/common.py +++ b/tests/common.py @@ -9,7 +9,6 @@ from io import StringIO import json import logging import os -import sys import threading import time import uuid @@ -109,24 +108,21 @@ def get_test_config_dir(*add_path): def get_test_home_assistant(): """Return a Home Assistant object pointing at test config directory.""" - if sys.platform == "win32": - loop = asyncio.ProactorEventLoop() - else: - loop = asyncio.new_event_loop() - + loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) hass = loop.run_until_complete(async_test_home_assistant(loop)) - stop_event = threading.Event() + loop_stop_event = threading.Event() def run_loop(): """Run event loop.""" # pylint: disable=protected-access loop._thread_ident = threading.get_ident() loop.run_forever() - stop_event.set() + loop_stop_event.set() orig_stop = hass.stop + hass._stopped = Mock(set=loop.stop) def start_hass(*mocks): """Start hass.""" @@ -135,7 +131,7 @@ def get_test_home_assistant(): def stop_hass(): """Stop hass.""" orig_stop() - stop_event.wait() + loop_stop_event.wait() loop.close() hass.start = start_hass diff --git a/tests/test_core.py b/tests/test_core.py index f08de049efa..5ebfc070ca2 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -38,7 +38,11 @@ import homeassistant.util.dt as dt_util from homeassistant.util.unit_system import METRIC_SYSTEM from tests.async_mock import MagicMock, Mock, PropertyMock, patch -from tests.common import async_mock_service, get_test_home_assistant +from tests.common import ( + async_capture_events, + async_mock_service, + get_test_home_assistant, +) PST = pytz.timezone("America/Los_Angeles") @@ -151,22 +155,14 @@ def test_async_run_hass_job_delegates_non_async(): assert len(hass.async_add_hass_job.mock_calls) == 1 -def test_stage_shutdown(): +async def test_stage_shutdown(hass): """Simulate a shutdown, test calling stuff.""" - hass = get_test_home_assistant() - test_stop = [] - test_final_write = [] - test_close = [] - test_all = [] + test_stop = async_capture_events(hass, EVENT_HOMEASSISTANT_STOP) + test_final_write = async_capture_events(hass, EVENT_HOMEASSISTANT_FINAL_WRITE) + test_close = async_capture_events(hass, EVENT_HOMEASSISTANT_CLOSE) + test_all = async_capture_events(hass, MATCH_ALL) - hass.bus.listen(EVENT_HOMEASSISTANT_STOP, lambda event: test_stop.append(event)) - hass.bus.listen( - EVENT_HOMEASSISTANT_FINAL_WRITE, lambda event: test_final_write.append(event) - ) - hass.bus.listen(EVENT_HOMEASSISTANT_CLOSE, lambda event: test_close.append(event)) - hass.bus.listen("*", lambda event: test_all.append(event)) - - hass.stop() + await hass.async_stop() assert len(test_stop) == 1 assert len(test_close) == 1 @@ -341,147 +337,139 @@ def test_state_as_dict(): assert state.as_dict() is state.as_dict() -class TestEventBus(unittest.TestCase): - """Test EventBus methods.""" +async def test_add_remove_listener(hass): + """Test remove_listener method.""" + old_count = len(hass.bus.async_listeners()) - # pylint: disable=invalid-name - def setUp(self): - """Set up things to be run when tests are started.""" - self.hass = get_test_home_assistant() - self.bus = self.hass.bus + def listener(_): + pass - # pylint: disable=invalid-name - def tearDown(self): - """Stop down stuff we started.""" - self.hass.stop() + unsub = hass.bus.async_listen("test", listener) - def test_add_remove_listener(self): - """Test remove_listener method.""" - self.hass.allow_pool = False - old_count = len(self.bus.listeners) + assert old_count + 1 == len(hass.bus.async_listeners()) - def listener(_): - pass + # Remove listener + unsub() + assert old_count == len(hass.bus.async_listeners()) - unsub = self.bus.listen("test", listener) + # Should do nothing now + unsub() - assert old_count + 1 == len(self.bus.listeners) - # Remove listener - unsub() - assert old_count == len(self.bus.listeners) +async def test_unsubscribe_listener(hass): + """Test unsubscribe listener from returned function.""" + calls = [] - # Should do nothing now - unsub() + @ha.callback + def listener(event): + """Mock listener.""" + calls.append(event) - def test_unsubscribe_listener(self): - """Test unsubscribe listener from returned function.""" - calls = [] + unsub = hass.bus.async_listen("test", listener) - @ha.callback - def listener(event): - """Mock listener.""" - calls.append(event) + hass.bus.async_fire("test") + await hass.async_block_till_done() - unsub = self.bus.listen("test", listener) + assert len(calls) == 1 - self.bus.fire("test") - self.hass.block_till_done() + unsub() - assert len(calls) == 1 + hass.bus.async_fire("event") + await hass.async_block_till_done() - unsub() + assert len(calls) == 1 - self.bus.fire("event") - self.hass.block_till_done() - assert len(calls) == 1 +async def test_listen_once_event_with_callback(hass): + """Test listen_once_event method.""" + runs = [] - def test_listen_once_event_with_callback(self): - """Test listen_once_event method.""" - runs = [] + @ha.callback + def event_handler(event): + runs.append(event) - @ha.callback - def event_handler(event): - runs.append(event) + hass.bus.async_listen_once("test_event", event_handler) - self.bus.listen_once("test_event", event_handler) + hass.bus.async_fire("test_event") + # Second time it should not increase runs + hass.bus.async_fire("test_event") - self.bus.fire("test_event") - # Second time it should not increase runs - self.bus.fire("test_event") + await hass.async_block_till_done() + assert len(runs) == 1 - self.hass.block_till_done() - assert len(runs) == 1 - def test_listen_once_event_with_coroutine(self): - """Test listen_once_event method.""" - runs = [] +async def test_listen_once_event_with_coroutine(hass): + """Test listen_once_event method.""" + runs = [] - async def event_handler(event): - runs.append(event) + async def event_handler(event): + runs.append(event) - self.bus.listen_once("test_event", event_handler) + hass.bus.async_listen_once("test_event", event_handler) - self.bus.fire("test_event") - # Second time it should not increase runs - self.bus.fire("test_event") + hass.bus.async_fire("test_event") + # Second time it should not increase runs + hass.bus.async_fire("test_event") - self.hass.block_till_done() - assert len(runs) == 1 + await hass.async_block_till_done() + assert len(runs) == 1 - def test_listen_once_event_with_thread(self): - """Test listen_once_event method.""" - runs = [] - def event_handler(event): - runs.append(event) +async def test_listen_once_event_with_thread(hass): + """Test listen_once_event method.""" + runs = [] - self.bus.listen_once("test_event", event_handler) + def event_handler(event): + runs.append(event) - self.bus.fire("test_event") - # Second time it should not increase runs - self.bus.fire("test_event") + hass.bus.async_listen_once("test_event", event_handler) - self.hass.block_till_done() - assert len(runs) == 1 + hass.bus.async_fire("test_event") + # Second time it should not increase runs + hass.bus.async_fire("test_event") - def test_thread_event_listener(self): - """Test thread event listener.""" - thread_calls = [] + await hass.async_block_till_done() + assert len(runs) == 1 - def thread_listener(event): - thread_calls.append(event) - self.bus.listen("test_thread", thread_listener) - self.bus.fire("test_thread") - self.hass.block_till_done() - assert len(thread_calls) == 1 +async def test_thread_event_listener(hass): + """Test thread event listener.""" + thread_calls = [] - def test_callback_event_listener(self): - """Test callback event listener.""" - callback_calls = [] + def thread_listener(event): + thread_calls.append(event) - @ha.callback - def callback_listener(event): - callback_calls.append(event) + hass.bus.async_listen("test_thread", thread_listener) + hass.bus.async_fire("test_thread") + await hass.async_block_till_done() + assert len(thread_calls) == 1 - self.bus.listen("test_callback", callback_listener) - self.bus.fire("test_callback") - self.hass.block_till_done() - assert len(callback_calls) == 1 - def test_coroutine_event_listener(self): - """Test coroutine event listener.""" - coroutine_calls = [] +async def test_callback_event_listener(hass): + """Test callback event listener.""" + callback_calls = [] - async def coroutine_listener(event): - coroutine_calls.append(event) + @ha.callback + def callback_listener(event): + callback_calls.append(event) - self.bus.listen("test_coroutine", coroutine_listener) - self.bus.fire("test_coroutine") - self.hass.block_till_done() - assert len(coroutine_calls) == 1 + hass.bus.async_listen("test_callback", callback_listener) + hass.bus.async_fire("test_callback") + await hass.async_block_till_done() + assert len(callback_calls) == 1 + + +async def test_coroutine_event_listener(hass): + """Test coroutine event listener.""" + coroutine_calls = [] + + async def coroutine_listener(event): + coroutine_calls.append(event) + + hass.bus.async_listen("test_coroutine", coroutine_listener) + hass.bus.async_fire("test_coroutine") + await hass.async_block_till_done() + assert len(coroutine_calls) == 1 def test_state_init():