Keep task references while running (#87970)

* Keep task references while running

* Update pilight tests pointing at correct logger call

* Fix graphite tests

* Fix profiler tests

* More graphite test fixes

* Remove extra sleep

* Fix tests

* Shutdown background tasks as part of stage 1

* Remove unnecessary sleep in test

* Remove unused method on mock hass

* Skip on cancelled too

* Remove background tasks

* Test trigger variables without actually sleeping

* Fix graphite

* One more graphite grrrrrrr
This commit is contained in:
Paulus Schoutsen 2023-02-13 23:16:59 -05:00 committed by GitHub
parent e41af8928b
commit d54f59478f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 97 additions and 210 deletions

View File

@ -69,7 +69,7 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool:
else: else:
_LOGGER.debug("No connection check for UDP possible") _LOGGER.debug("No connection check for UDP possible")
GraphiteFeeder(hass, host, port, protocol, prefix) hass.data[DOMAIN] = GraphiteFeeder(hass, host, port, protocol, prefix)
return True return True

View File

@ -278,8 +278,7 @@ class HomeAssistant:
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize new Home Assistant object.""" """Initialize new Home Assistant object."""
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
self._pending_tasks: list[asyncio.Future[Any]] = [] self._tasks: set[asyncio.Future[Any]] = set()
self._track_task = True
self.bus = EventBus(self) self.bus = EventBus(self)
self.services = ServiceRegistry(self) self.services = ServiceRegistry(self)
self.states = StateMachine(self.bus, self.loop) self.states = StateMachine(self.bus, self.loop)
@ -355,12 +354,14 @@ class HomeAssistant:
self.bus.async_fire(EVENT_CORE_CONFIG_UPDATE) self.bus.async_fire(EVENT_CORE_CONFIG_UPDATE)
self.bus.async_fire(EVENT_HOMEASSISTANT_START) self.bus.async_fire(EVENT_HOMEASSISTANT_START)
try: if not self._tasks:
# Only block for EVENT_HOMEASSISTANT_START listener pending: set[asyncio.Future[Any]] | None = None
self.async_stop_track_tasks() else:
async with self.timeout.async_timeout(TIMEOUT_EVENT_START): _done, pending = await asyncio.wait(
await self.async_block_till_done() self._tasks, timeout=TIMEOUT_EVENT_START
except asyncio.TimeoutError: )
if pending:
_LOGGER.warning( _LOGGER.warning(
( (
"Something is blocking Home Assistant from wrapping up the start up" "Something is blocking Home Assistant from wrapping up the start up"
@ -496,9 +497,8 @@ class HomeAssistant:
hassjob.target = cast(Callable[..., _R], hassjob.target) hassjob.target = cast(Callable[..., _R], hassjob.target)
task = self.loop.run_in_executor(None, hassjob.target, *args) task = self.loop.run_in_executor(None, hassjob.target, *args)
# If a task is scheduled self._tasks.add(task)
if self._track_task: task.add_done_callback(self._tasks.remove)
self._pending_tasks.append(task)
return task return task
@ -518,9 +518,8 @@ class HomeAssistant:
target: target to call. target: target to call.
""" """
task = self.loop.create_task(target) task = self.loop.create_task(target)
self._tasks.add(task)
if self._track_task: task.add_done_callback(self._tasks.remove)
self._pending_tasks.append(task)
return task return task
@ -530,23 +529,11 @@ class HomeAssistant:
) -> asyncio.Future[_T]: ) -> asyncio.Future[_T]:
"""Add an executor job from within the event loop.""" """Add an executor job from within the event loop."""
task = self.loop.run_in_executor(None, target, *args) task = self.loop.run_in_executor(None, target, *args)
self._tasks.add(task)
# If a task is scheduled task.add_done_callback(self._tasks.remove)
if self._track_task:
self._pending_tasks.append(task)
return task return task
@callback
def async_track_tasks(self) -> None:
"""Track tasks so you can wait for all tasks to be done."""
self._track_task = True
@callback
def async_stop_track_tasks(self) -> None:
"""Stop track tasks so you can't wait for all tasks to be done."""
self._track_task = False
@overload @overload
@callback @callback
def async_run_hass_job( def async_run_hass_job(
@ -640,29 +627,25 @@ class HomeAssistant:
# To flush out any call_soon_threadsafe # To flush out any call_soon_threadsafe
await asyncio.sleep(0) await asyncio.sleep(0)
start_time: float | None = None start_time: float | None = None
current_task = asyncio.current_task()
while self._pending_tasks: while tasks := [task for task in self._tasks if task is not current_task]:
pending = [task for task in self._pending_tasks if not task.done()] await self._await_and_log_pending(tasks)
self._pending_tasks.clear()
if pending:
await self._await_and_log_pending(pending)
if start_time is None: if start_time is None:
# Avoid calling monotonic() until we know # Avoid calling monotonic() until we know
# we may need to start logging blocked tasks. # we may need to start logging blocked tasks.
start_time = 0 start_time = 0
elif start_time == 0: elif start_time == 0:
# If we have waited twice then we set the start # If we have waited twice then we set the start
# time # time
start_time = monotonic() start_time = monotonic()
elif monotonic() - start_time > BLOCK_LOG_TIMEOUT: elif monotonic() - start_time > BLOCK_LOG_TIMEOUT:
# We have waited at least three loops and new tasks # We have waited at least three loops and new tasks
# continue to block. At this point we start # continue to block. At this point we start
# logging all waiting tasks. # logging all waiting tasks.
for task in pending: for task in tasks:
_LOGGER.debug("Waiting for task: %s", task) _LOGGER.debug("Waiting for task: %s", task)
else:
await asyncio.sleep(0)
async def _await_and_log_pending(self, pending: Collection[Awaitable[Any]]) -> None: async def _await_and_log_pending(self, pending: Collection[Awaitable[Any]]) -> None:
"""Await and log tasks that take a long time.""" """Await and log tasks that take a long time."""
@ -706,7 +689,6 @@ class HomeAssistant:
# stage 1 # stage 1
self.state = CoreState.stopping self.state = CoreState.stopping
self.async_track_tasks()
self.bus.async_fire(EVENT_HOMEASSISTANT_STOP) self.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
try: try:
async with self.timeout.async_timeout(STAGE_1_SHUTDOWN_TIMEOUT): async with self.timeout.async_timeout(STAGE_1_SHUTDOWN_TIMEOUT):

View File

@ -3,14 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import OrderedDict from collections import OrderedDict
from collections.abc import ( from collections.abc import Awaitable, Callable, Generator, Mapping, Sequence
Awaitable,
Callable,
Collection,
Generator,
Mapping,
Sequence,
)
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
import functools as ft import functools as ft
@ -22,8 +15,6 @@ import os
import pathlib import pathlib
import threading import threading
import time import time
from time import monotonic
import types
from typing import Any, NoReturn from typing import Any, NoReturn
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
@ -51,7 +42,6 @@ from homeassistant.const import (
STATE_ON, STATE_ON,
) )
from homeassistant.core import ( from homeassistant.core import (
BLOCK_LOG_TIMEOUT,
CoreState, CoreState,
Event, Event,
HomeAssistant, HomeAssistant,
@ -221,76 +211,9 @@ async def async_test_home_assistant(event_loop, load_registries=True):
return orig_async_create_task(coroutine) return orig_async_create_task(coroutine)
async def async_wait_for_task_count(self, max_remaining_tasks: int = 0) -> None:
"""Block until at most max_remaining_tasks remain.
Based on HomeAssistant.async_block_till_done
"""
# To flush out any call_soon_threadsafe
await asyncio.sleep(0)
start_time: float | None = None
while len(self._pending_tasks) > max_remaining_tasks:
pending: Collection[Awaitable[Any]] = [
task for task in self._pending_tasks if not task.done()
]
self._pending_tasks.clear()
if len(pending) > max_remaining_tasks:
remaining_pending = await self._await_count_and_log_pending(
pending, max_remaining_tasks=max_remaining_tasks
)
self._pending_tasks.extend(remaining_pending)
if start_time is None:
# Avoid calling monotonic() until we know
# we may need to start logging blocked tasks.
start_time = 0
elif start_time == 0:
# If we have waited twice then we set the start
# time
start_time = monotonic()
elif monotonic() - start_time > BLOCK_LOG_TIMEOUT:
# We have waited at least three loops and new tasks
# continue to block. At this point we start
# logging all waiting tasks.
for task in pending:
_LOGGER.debug("Waiting for task: %s", task)
else:
self._pending_tasks.extend(pending)
await asyncio.sleep(0)
async def _await_count_and_log_pending(
self, pending: Collection[Awaitable[Any]], max_remaining_tasks: int = 0
) -> Collection[Awaitable[Any]]:
"""Block at most max_remaining_tasks remain and log tasks that take a long time.
Based on HomeAssistant._await_and_log_pending
"""
wait_time = 0
return_when = asyncio.ALL_COMPLETED
if max_remaining_tasks:
return_when = asyncio.FIRST_COMPLETED
while len(pending) > max_remaining_tasks:
_, pending = await asyncio.wait(
pending, timeout=BLOCK_LOG_TIMEOUT, return_when=return_when
)
if not pending or max_remaining_tasks:
return pending
wait_time += BLOCK_LOG_TIMEOUT
for task in pending:
_LOGGER.debug("Waited %s seconds for task: %s", wait_time, task)
return []
hass.async_add_job = async_add_job hass.async_add_job = async_add_job
hass.async_add_executor_job = async_add_executor_job hass.async_add_executor_job = async_add_executor_job
hass.async_create_task = async_create_task hass.async_create_task = async_create_task
hass.async_wait_for_task_count = types.MethodType(async_wait_for_task_count, hass)
hass._await_count_and_log_pending = types.MethodType(
_await_count_and_log_pending, hass
)
hass.data[loader.DATA_CUSTOM_COMPONENTS] = {} hass.data[loader.DATA_CUSTOM_COMPONENTS] = {}
@ -328,17 +251,6 @@ async def async_test_home_assistant(event_loop, load_registries=True):
hass.state = CoreState.running hass.state = CoreState.running
# Mock async_start
orig_start = hass.async_start
async def mock_async_start():
"""Start the mocking."""
# We only mock time during tests and we want to track tasks
with patch.object(hass, "async_stop_track_tasks"):
await orig_start()
hass.async_start = mock_async_start
@callback @callback
def clear_instance(event): def clear_instance(event):
"""Clear global instance.""" """Clear global instance."""

View File

@ -239,5 +239,6 @@ async def test_initialize_start(hass: HomeAssistant) -> None:
) as mock_activate: ) as mock_activate:
hass.bus.fire(EVENT_HOMEASSISTANT_START) hass.bus.fire(EVENT_HOMEASSISTANT_START)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done()
assert len(mock_activate.mock_calls) == 1 assert len(mock_activate.mock_calls) == 1

View File

@ -1,5 +1,4 @@
"""The tests for the Graphite component.""" """The tests for the Graphite component."""
import asyncio
import socket import socket
from unittest import mock from unittest import mock
from unittest.mock import patch from unittest.mock import patch
@ -91,9 +90,11 @@ async def test_start(hass: HomeAssistant, mock_socket, mock_time) -> None:
mock_socket.reset_mock() mock_socket.reset_mock()
await hass.async_start() await hass.async_start()
await hass.async_block_till_done()
hass.states.async_set("test.entity", STATE_ON) hass.states.async_set("test.entity", STATE_ON)
await asyncio.sleep(0.1) await hass.async_block_till_done()
hass.data[graphite.DOMAIN]._queue.join()
assert mock_socket.return_value.connect.call_count == 1 assert mock_socket.return_value.connect.call_count == 1
assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003)) assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003))
@ -114,9 +115,11 @@ async def test_shutdown(hass: HomeAssistant, mock_socket, mock_time) -> None:
mock_socket.reset_mock() mock_socket.reset_mock()
await hass.async_start() await hass.async_start()
await hass.async_block_till_done()
hass.states.async_set("test.entity", STATE_ON) hass.states.async_set("test.entity", STATE_ON)
await asyncio.sleep(0.1) await hass.async_block_till_done()
hass.data[graphite.DOMAIN]._queue.join()
assert mock_socket.return_value.connect.call_count == 1 assert mock_socket.return_value.connect.call_count == 1
assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003)) assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003))
@ -134,7 +137,7 @@ async def test_shutdown(hass: HomeAssistant, mock_socket, mock_time) -> None:
await hass.async_block_till_done() await hass.async_block_till_done()
hass.states.async_set("test.entity", STATE_OFF) hass.states.async_set("test.entity", STATE_OFF)
await asyncio.sleep(0.1) await hass.async_block_till_done()
assert mock_socket.return_value.connect.call_count == 0 assert mock_socket.return_value.connect.call_count == 0
assert mock_socket.return_value.sendall.call_count == 0 assert mock_socket.return_value.sendall.call_count == 0
@ -156,9 +159,11 @@ async def test_report_attributes(hass: HomeAssistant, mock_socket, mock_time) ->
mock_socket.reset_mock() mock_socket.reset_mock()
await hass.async_start() await hass.async_start()
await hass.async_block_till_done()
hass.states.async_set("test.entity", STATE_ON, attrs) hass.states.async_set("test.entity", STATE_ON, attrs)
await asyncio.sleep(0.1) await hass.async_block_till_done()
hass.data[graphite.DOMAIN]._queue.join()
assert mock_socket.return_value.connect.call_count == 1 assert mock_socket.return_value.connect.call_count == 1
assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003)) assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003))
@ -186,9 +191,11 @@ async def test_report_with_string_state(
mock_socket.reset_mock() mock_socket.reset_mock()
await hass.async_start() await hass.async_start()
await hass.async_block_till_done()
hass.states.async_set("test.entity", "above_horizon", {"foo": 1.0}) hass.states.async_set("test.entity", "above_horizon", {"foo": 1.0})
await asyncio.sleep(0.1) await hass.async_block_till_done()
hass.data[graphite.DOMAIN]._queue.join()
assert mock_socket.return_value.connect.call_count == 1 assert mock_socket.return_value.connect.call_count == 1
assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003)) assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003))
@ -203,7 +210,8 @@ async def test_report_with_string_state(
mock_socket.reset_mock() mock_socket.reset_mock()
hass.states.async_set("test.entity", "not_float") hass.states.async_set("test.entity", "not_float")
await asyncio.sleep(0.1) await hass.async_block_till_done()
hass.data[graphite.DOMAIN]._queue.join()
assert mock_socket.return_value.connect.call_count == 0 assert mock_socket.return_value.connect.call_count == 0
assert mock_socket.return_value.sendall.call_count == 0 assert mock_socket.return_value.sendall.call_count == 0
@ -221,13 +229,15 @@ async def test_report_with_binary_state(
mock_socket.reset_mock() mock_socket.reset_mock()
await hass.async_start() await hass.async_start()
await hass.async_block_till_done()
expected = [ expected = [
"ha.test.entity.foo 1.000000 12345", "ha.test.entity.foo 1.000000 12345",
"ha.test.entity.state 1.000000 12345", "ha.test.entity.state 1.000000 12345",
] ]
hass.states.async_set("test.entity", STATE_ON, {"foo": 1.0}) hass.states.async_set("test.entity", STATE_ON, {"foo": 1.0})
await asyncio.sleep(0.1) await hass.async_block_till_done()
hass.data[graphite.DOMAIN]._queue.join()
assert mock_socket.return_value.connect.call_count == 1 assert mock_socket.return_value.connect.call_count == 1
assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003)) assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003))
@ -246,7 +256,8 @@ async def test_report_with_binary_state(
"ha.test.entity.state 0.000000 12345", "ha.test.entity.state 0.000000 12345",
] ]
hass.states.async_set("test.entity", STATE_OFF, {"foo": 1.0}) hass.states.async_set("test.entity", STATE_OFF, {"foo": 1.0})
await asyncio.sleep(0.1) await hass.async_block_till_done()
hass.data[graphite.DOMAIN]._queue.join()
assert mock_socket.return_value.connect.call_count == 1 assert mock_socket.return_value.connect.call_count == 1
assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003)) assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003))
@ -282,10 +293,12 @@ async def test_send_to_graphite_errors(
mock_socket.reset_mock() mock_socket.reset_mock()
await hass.async_start() await hass.async_start()
await hass.async_block_till_done()
mock_socket.return_value.connect.side_effect = error mock_socket.return_value.connect.side_effect = error
hass.states.async_set("test.entity", STATE_ON) hass.states.async_set("test.entity", STATE_ON)
await asyncio.sleep(0.1) await hass.async_block_till_done()
hass.data[graphite.DOMAIN]._queue.join()
assert log_text in caplog.text assert log_text in caplog.text

View File

@ -71,6 +71,7 @@ async def test_system_status_subscription(
): ):
freezer.tick(TEST_TIME_ADVANCE_INTERVAL) freezer.tick(TEST_TIME_ADVANCE_INTERVAL)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done()
response = await client.receive_json() response = await client.receive_json()
assert response["event"] == { assert response["event"] == {

View File

@ -239,7 +239,7 @@ async def test_receive_code(mock_debug, hass):
}, },
**PilightDaemonSim.test_message["message"], **PilightDaemonSim.test_message["message"],
) )
debug_log_call = mock_debug.call_args_list[-3] debug_log_call = mock_debug.call_args_list[-1]
# Check if all message parts are put on event bus # Check if all message parts are put on event bus
for key, value in expected_message.items(): for key, value in expected_message.items():
@ -272,7 +272,7 @@ async def test_whitelist_exact_match(mock_debug, hass):
}, },
**PilightDaemonSim.test_message["message"], **PilightDaemonSim.test_message["message"],
) )
debug_log_call = mock_debug.call_args_list[-3] debug_log_call = mock_debug.call_args_list[-1]
# Check if all message parts are put on event bus # Check if all message parts are put on event bus
for key, value in expected_message.items(): for key, value in expected_message.items():
@ -303,7 +303,7 @@ async def test_whitelist_partial_match(mock_debug, hass):
}, },
**PilightDaemonSim.test_message["message"], **PilightDaemonSim.test_message["message"],
) )
debug_log_call = mock_debug.call_args_list[-3] debug_log_call = mock_debug.call_args_list[-1]
# Check if all message parts are put on event bus # Check if all message parts are put on event bus
for key, value in expected_message.items(): for key, value in expected_message.items():
@ -337,7 +337,7 @@ async def test_whitelist_or_match(mock_debug, hass):
}, },
**PilightDaemonSim.test_message["message"], **PilightDaemonSim.test_message["message"],
) )
debug_log_call = mock_debug.call_args_list[-3] debug_log_call = mock_debug.call_args_list[-1]
# Check if all message parts are put on event bus # Check if all message parts are put on event bus
for key, value in expected_message.items(): for key, value in expected_message.items():
@ -360,7 +360,7 @@ async def test_whitelist_no_match(mock_debug, hass):
await hass.async_start() await hass.async_start()
await hass.async_block_till_done() await hass.async_block_till_done()
debug_log_call = mock_debug.call_args_list[-3] debug_log_call = mock_debug.call_args_list[-1]
assert "Event pilight_received" not in debug_log_call assert "Event pilight_received" not in debug_log_call

View File

@ -43,8 +43,9 @@ async def test_basic_usage(hass, tmpdir):
return last_filename return last_filename
with patch("cProfile.Profile"), patch.object(hass.config, "path", _mock_path): with patch("cProfile.Profile"), patch.object(hass.config, "path", _mock_path):
await hass.services.async_call(DOMAIN, SERVICE_START, {CONF_SECONDS: 0.000001}) await hass.services.async_call(
await hass.async_block_till_done() DOMAIN, SERVICE_START, {CONF_SECONDS: 0.000001}, blocking=True
)
assert os.path.exists(last_filename) assert os.path.exists(last_filename)
@ -72,8 +73,9 @@ async def test_memory_usage(hass, tmpdir):
return last_filename return last_filename
with patch("guppy.hpy") as mock_hpy, patch.object(hass.config, "path", _mock_path): with patch("guppy.hpy") as mock_hpy, patch.object(hass.config, "path", _mock_path):
await hass.services.async_call(DOMAIN, SERVICE_MEMORY, {CONF_SECONDS: 0.000001}) await hass.services.async_call(
await hass.async_block_till_done() DOMAIN, SERVICE_MEMORY, {CONF_SECONDS: 0.000001}, blocking=True
)
mock_hpy.assert_called_once() mock_hpy.assert_called_once()
@ -97,9 +99,8 @@ async def test_object_growth_logging(
with patch("objgraph.growth"): with patch("objgraph.growth"):
await hass.services.async_call( await hass.services.async_call(
DOMAIN, SERVICE_START_LOG_OBJECTS, {CONF_SCAN_INTERVAL: 10} DOMAIN, SERVICE_START_LOG_OBJECTS, {CONF_SCAN_INTERVAL: 10}, blocking=True
) )
await hass.async_block_till_done()
assert "Growth" in caplog.text assert "Growth" in caplog.text
caplog.clear() caplog.clear()
@ -108,8 +109,7 @@ async def test_object_growth_logging(
await hass.async_block_till_done() await hass.async_block_till_done()
assert "Growth" in caplog.text assert "Growth" in caplog.text
await hass.services.async_call(DOMAIN, SERVICE_STOP_LOG_OBJECTS, {}) await hass.services.async_call(DOMAIN, SERVICE_STOP_LOG_OBJECTS, {}, blocking=True)
await hass.async_block_till_done()
caplog.clear() caplog.clear()
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=21)) async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=21))
@ -150,9 +150,8 @@ async def test_dump_log_object(
assert hass.services.has_service(DOMAIN, SERVICE_DUMP_LOG_OBJECTS) assert hass.services.has_service(DOMAIN, SERVICE_DUMP_LOG_OBJECTS)
await hass.services.async_call( await hass.services.async_call(
DOMAIN, SERVICE_DUMP_LOG_OBJECTS, {CONF_TYPE: "DumpLogDummy"} DOMAIN, SERVICE_DUMP_LOG_OBJECTS, {CONF_TYPE: "DumpLogDummy"}, blocking=True
) )
await hass.async_block_till_done()
assert "<DumpLogDummy success>" in caplog.text assert "<DumpLogDummy success>" in caplog.text
assert "Failed to serialize" in caplog.text assert "Failed to serialize" in caplog.text
@ -174,8 +173,7 @@ async def test_log_thread_frames(
assert hass.services.has_service(DOMAIN, SERVICE_LOG_THREAD_FRAMES) assert hass.services.has_service(DOMAIN, SERVICE_LOG_THREAD_FRAMES)
await hass.services.async_call(DOMAIN, SERVICE_LOG_THREAD_FRAMES, {}) await hass.services.async_call(DOMAIN, SERVICE_LOG_THREAD_FRAMES, {}, blocking=True)
await hass.async_block_till_done()
assert "SyncWorker_0" in caplog.text assert "SyncWorker_0" in caplog.text
caplog.clear() caplog.clear()
@ -197,8 +195,9 @@ async def test_log_scheduled(
assert hass.services.has_service(DOMAIN, SERVICE_LOG_EVENT_LOOP_SCHEDULED) assert hass.services.has_service(DOMAIN, SERVICE_LOG_EVENT_LOOP_SCHEDULED)
await hass.services.async_call(DOMAIN, SERVICE_LOG_EVENT_LOOP_SCHEDULED, {}) await hass.services.async_call(
await hass.async_block_till_done() DOMAIN, SERVICE_LOG_EVENT_LOOP_SCHEDULED, {}, blocking=True
)
assert "Scheduled" in caplog.text assert "Scheduled" in caplog.text
caplog.clear() caplog.clear()

View File

@ -815,15 +815,13 @@ async def test_wait_for_trigger_variables(hass: HomeAssistant) -> None:
actions = [ actions = [
{ {
"alias": "variables", "alias": "variables",
"variables": {"seconds": 5}, "variables": {"state": "off"},
}, },
{ {
"alias": wait_alias, "alias": wait_alias,
"wait_for_trigger": { "wait_for_trigger": {
"platform": "state", "platform": "template",
"entity_id": "switch.test", "value_template": "{{ states.switch.test.state == state }}",
"to": "off",
"for": {"seconds": "{{ seconds }}"},
}, },
}, },
] ]
@ -839,9 +837,6 @@ async def test_wait_for_trigger_variables(hass: HomeAssistant) -> None:
assert script_obj.is_running assert script_obj.is_running
assert script_obj.last_action == wait_alias assert script_obj.last_action == wait_alias
hass.states.async_set("switch.test", "off") hass.states.async_set("switch.test", "off")
# the script task + 2 tasks created by wait_for_trigger script step
await hass.async_wait_for_task_count(3)
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=10))
await hass.async_block_till_done() await hass.async_block_till_done()
except (AssertionError, asyncio.TimeoutError): except (AssertionError, asyncio.TimeoutError):
await script_obj.async_stop() await script_obj.async_stop()

View File

@ -774,12 +774,10 @@ async def test_warning_logged_on_wrap_up_timeout(hass, caplog):
def gen_domain_setup(domain): def gen_domain_setup(domain):
async def async_setup(hass, config): async def async_setup(hass, config):
await asyncio.sleep(0.1)
async def _background_task(): async def _background_task():
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
await hass.async_create_task(_background_task()) hass.async_create_task(_background_task())
return True return True
return async_setup return async_setup

View File

@ -3559,14 +3559,17 @@ async def test_initializing_flows_canceled_on_shutdown(hass: HomeAssistant, mana
"""Mock Reauth.""" """Mock Reauth."""
await asyncio.sleep(1) await asyncio.sleep(1)
mock_integration(hass, MockModule("test"))
mock_entity_platform(hass, "config_flow.test", None)
with patch.dict( with patch.dict(
config_entries.HANDLERS, {"comp": MockFlowHandler, "test": MockFlowHandler} config_entries.HANDLERS, {"comp": MockFlowHandler, "test": MockFlowHandler}
): ):
task = asyncio.create_task( task = asyncio.create_task(
manager.flow.async_init("test", context={"source": "reauth"}) manager.flow.async_init("test", context={"source": "reauth"})
) )
await hass.async_block_till_done() await hass.async_block_till_done()
await manager.flow.async_shutdown() await manager.flow.async_shutdown()
with pytest.raises(asyncio.exceptions.CancelledError): with pytest.raises(asyncio.exceptions.CancelledError):
await task await task

View File

@ -124,7 +124,7 @@ def test_async_add_job_add_hass_threaded_job_to_pool() -> None:
ha.HomeAssistant.async_add_hass_job(hass, ha.HassJob(job)) ha.HomeAssistant.async_add_hass_job(hass, ha.HassJob(job))
assert len(hass.loop.call_soon.mock_calls) == 0 assert len(hass.loop.call_soon.mock_calls) == 0
assert len(hass.loop.create_task.mock_calls) == 0 assert len(hass.loop.create_task.mock_calls) == 0
assert len(hass.loop.run_in_executor.mock_calls) == 1 assert len(hass.loop.run_in_executor.mock_calls) == 2
def test_async_create_task_schedule_coroutine(event_loop): def test_async_create_task_schedule_coroutine(event_loop):
@ -205,7 +205,7 @@ async def test_shutdown_calls_block_till_done_after_shutdown_run_callback_thread
assert stop_calls[-1] == "async_block_till_done" assert stop_calls[-1] == "async_block_till_done"
async def test_pending_sheduler(hass: HomeAssistant) -> None: async def test_pending_scheduler(hass: HomeAssistant) -> None:
"""Add a coro to pending tasks.""" """Add a coro to pending tasks."""
call_count = [] call_count = []
@ -216,9 +216,9 @@ async def test_pending_sheduler(hass: HomeAssistant) -> None:
for _ in range(3): for _ in range(3):
hass.async_add_job(test_coro()) hass.async_add_job(test_coro())
await asyncio.wait(hass._pending_tasks) await asyncio.wait(hass._tasks)
assert len(hass._pending_tasks) == 3 assert len(hass._tasks) == 0
assert len(call_count) == 3 assert len(call_count) == 3
@ -240,7 +240,7 @@ async def test_async_add_job_pending_tasks_coro(hass: HomeAssistant) -> None:
await wait_finish_callback() await wait_finish_callback()
assert len(hass._pending_tasks) == 2 assert len(hass._tasks) == 2
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(call_count) == 2 assert len(call_count) == 2
@ -263,7 +263,7 @@ async def test_async_create_task_pending_tasks_coro(hass: HomeAssistant) -> None
await wait_finish_callback() await wait_finish_callback()
assert len(hass._pending_tasks) == 2 assert len(hass._tasks) == 2
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(call_count) == 2 assert len(call_count) == 2
@ -286,7 +286,6 @@ async def test_async_add_job_pending_tasks_executor(hass: HomeAssistant) -> None
await wait_finish_callback() await wait_finish_callback()
assert len(hass._pending_tasks) == 2
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(call_count) == 2 assert len(call_count) == 2
@ -312,7 +311,7 @@ async def test_async_add_job_pending_tasks_callback(hass: HomeAssistant) -> None
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(hass._pending_tasks) == 0 assert len(hass._tasks) == 0
assert len(call_count) == 2 assert len(call_count) == 2
@ -1144,11 +1143,10 @@ async def test_start_taking_too_long(event_loop, caplog):
"""Test when async_start takes too long.""" """Test when async_start takes too long."""
hass = ha.HomeAssistant() hass = ha.HomeAssistant()
caplog.set_level(logging.WARNING) caplog.set_level(logging.WARNING)
hass.async_create_task(asyncio.sleep(0))
try: try:
with patch.object( with patch("asyncio.wait", return_value=(set(), {asyncio.Future()})):
hass, "async_block_till_done", side_effect=asyncio.TimeoutError
):
await hass.async_start() await hass.async_start()
assert hass.state == ha.CoreState.running assert hass.state == ha.CoreState.running
@ -1159,21 +1157,6 @@ async def test_start_taking_too_long(event_loop, caplog):
assert hass.state == ha.CoreState.stopped assert hass.state == ha.CoreState.stopped
async def test_track_task_functions(event_loop):
"""Test function to start/stop track task and initial state."""
hass = ha.HomeAssistant()
try:
assert hass._track_task
hass.async_stop_track_tasks()
assert not hass._track_task
hass.async_track_tasks()
assert hass._track_task
finally:
await hass.async_stop()
async def test_service_executed_with_subservices(hass: HomeAssistant) -> None: async def test_service_executed_with_subservices(hass: HomeAssistant) -> None:
"""Test we block correctly till all services done.""" """Test we block correctly till all services done."""
calls = async_mock_service(hass, "test", "inner") calls = async_mock_service(hass, "test", "inner")