Reduce script overhead by avoiding creation of many tasks (#113183)

* Reduce script overhead by avoiding creation of many tasks

* no eager stop

* reduce

* make sure wait being cancelled is handled

* make sure wait being cancelled is handled

* make sure wait being cancelled is handled

* preen

* preen

* result already raises cancelled error, remove redundant code

* no need to raise it into the future

* will never set an exception

* Simplify long action script implementation

* comment

* preen

* dry

* dry

* preen

* dry

* preen

* no need to access protected

* no need to access protected

* dry

* name

* dry

* dry

* dry

* dry

* reduce name changes

* drop one more task

* stale comment

* stale comment
This commit is contained in:
J. Nick Koston 2024-03-14 14:28:27 -10:00 committed by GitHub
parent e293afe46e
commit 09934d44c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
from contextlib import asynccontextmanager, suppress from contextlib import asynccontextmanager
from contextvars import ContextVar from contextvars import ContextVar
from copy import copy from copy import copy
from dataclasses import dataclass from dataclasses import dataclass
@ -15,6 +15,7 @@ import logging
from types import MappingProxyType from types import MappingProxyType
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar, cast from typing import TYPE_CHECKING, Any, TypedDict, TypeVar, cast
import async_interrupt
import voluptuous as vol import voluptuous as vol
from homeassistant import exceptions from homeassistant import exceptions
@ -157,6 +158,16 @@ SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all"
script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None) script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None)
class ScriptStoppedError(Exception):
"""Error to indicate that the script has been stopped."""
def _set_result_unless_done(future: asyncio.Future[None]) -> None:
"""Set result of future unless it is done."""
if not future.done():
future.set_result(None)
def action_trace_append(variables, path): def action_trace_append(variables, path):
"""Append a TraceElement to trace[path].""" """Append a TraceElement to trace[path]."""
trace_element = TraceElement(variables, path) trace_element = TraceElement(variables, path)
@ -168,7 +179,7 @@ def action_trace_append(variables, path):
async def trace_action( async def trace_action(
hass: HomeAssistant, hass: HomeAssistant,
script_run: _ScriptRun, script_run: _ScriptRun,
stop: asyncio.Event, stop: asyncio.Future[None],
variables: dict[str, Any], variables: dict[str, Any],
) -> AsyncGenerator[TraceElement, None]: ) -> AsyncGenerator[TraceElement, None]:
"""Trace action execution.""" """Trace action execution."""
@ -199,13 +210,13 @@ async def trace_action(
): ):
async_dispatcher_send(hass, SCRIPT_BREAKPOINT_HIT, key, run_id, path) async_dispatcher_send(hass, SCRIPT_BREAKPOINT_HIT, key, run_id, path)
done = asyncio.Event() done = hass.loop.create_future()
@callback @callback
def async_continue_stop(command=None): def async_continue_stop(command=None):
if command == "stop": if command == "stop":
stop.set() _set_result_unless_done(stop)
done.set() _set_result_unless_done(done)
signal = SCRIPT_DEBUG_CONTINUE_STOP.format(key, run_id) signal = SCRIPT_DEBUG_CONTINUE_STOP.format(key, run_id)
remove_signal1 = async_dispatcher_connect(hass, signal, async_continue_stop) remove_signal1 = async_dispatcher_connect(hass, signal, async_continue_stop)
@ -213,10 +224,7 @@ async def trace_action(
hass, SCRIPT_DEBUG_CONTINUE_ALL, async_continue_stop hass, SCRIPT_DEBUG_CONTINUE_ALL, async_continue_stop
) )
tasks = [hass.async_create_task(flag.wait()) for flag in (stop, done)] await asyncio.wait([stop, done], return_when=asyncio.FIRST_COMPLETED)
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
for task in tasks:
task.cancel()
remove_signal1() remove_signal1()
remove_signal2() remove_signal2()
@ -393,12 +401,12 @@ class _ScriptRun:
self._log_exceptions = log_exceptions self._log_exceptions = log_exceptions
self._step = -1 self._step = -1
self._started = False self._started = False
self._stop = asyncio.Event() self._stop = hass.loop.create_future()
self._stopped = asyncio.Event() self._stopped = asyncio.Event()
self._conversation_response: str | None | UndefinedType = UNDEFINED self._conversation_response: str | None | UndefinedType = UNDEFINED
def _changed(self) -> None: def _changed(self) -> None:
if not self._stop.is_set(): if not self._stop.done():
self._script._changed() # pylint: disable=protected-access self._script._changed() # pylint: disable=protected-access
async def _async_get_condition(self, config): async def _async_get_condition(self, config):
@ -432,7 +440,7 @@ class _ScriptRun:
try: try:
self._log("Running %s", self._script.running_description) self._log("Running %s", self._script.running_description)
for self._step, self._action in enumerate(self._script.sequence): for self._step, self._action in enumerate(self._script.sequence):
if self._stop.is_set(): if self._stop.done():
script_execution_set("cancelled") script_execution_set("cancelled")
break break
await self._async_step(log_exceptions=False) await self._async_step(log_exceptions=False)
@ -471,7 +479,7 @@ class _ScriptRun:
async with trace_action( async with trace_action(
self._hass, self, self._stop, self._variables self._hass, self, self._stop, self._variables
) as trace_element: ) as trace_element:
if self._stop.is_set(): if self._stop.done():
return return
action = cv.determine_script_action(self._action) action = cv.determine_script_action(self._action)
@ -483,8 +491,8 @@ class _ScriptRun:
trace_set_result(enabled=False) trace_set_result(enabled=False)
return return
handler = f"_async_{action}_step"
try: try:
handler = f"_async_{action}_step"
await getattr(self, handler)() await getattr(self, handler)()
except Exception as ex: # pylint: disable=broad-except except Exception as ex: # pylint: disable=broad-except
self._handle_exception( self._handle_exception(
@ -502,7 +510,7 @@ class _ScriptRun:
async def async_stop(self) -> None: async def async_stop(self) -> None:
"""Stop script run.""" """Stop script run."""
self._stop.set() _set_result_unless_done(self._stop)
# If the script was never started # If the script was never started
# the stopped event will never be # the stopped event will never be
# set because the script will never # set because the script will never
@ -576,9 +584,9 @@ class _ScriptRun:
level=level, level=level,
) )
def _get_pos_time_period_template(self, key): def _get_pos_time_period_template(self, key: str) -> timedelta:
try: try:
return cv.positive_time_period( return cv.positive_time_period( # type: ignore[no-any-return]
template.render_complex(self._action[key], self._variables) template.render_complex(self._action[key], self._variables)
) )
except (exceptions.TemplateError, vol.Invalid) as ex: except (exceptions.TemplateError, vol.Invalid) as ex:
@ -593,26 +601,34 @@ class _ScriptRun:
async def _async_delay_step(self): async def _async_delay_step(self):
"""Handle delay.""" """Handle delay."""
delay = self._get_pos_time_period_template(CONF_DELAY) delay_delta = self._get_pos_time_period_template(CONF_DELAY)
self._step_log(f"delay {delay}") self._step_log(f"delay {delay_delta}")
delay = delay.total_seconds() delay = delay_delta.total_seconds()
self._changed() self._changed()
trace_set_result(delay=delay, done=False) trace_set_result(delay=delay, done=False)
futures, timeout_handle, timeout_future = self._async_futures_with_timeout(
delay
)
try: try:
async with asyncio.timeout(delay): await asyncio.wait(futures, return_when=asyncio.FIRST_COMPLETED)
await self._stop.wait() finally:
except TimeoutError: if timeout_future.done():
trace_set_result(delay=delay, done=True) trace_set_result(delay=delay, done=True)
else:
timeout_handle.cancel()
def _get_timeout_seconds_from_action(self) -> float | None:
"""Get the timeout from the action."""
if CONF_TIMEOUT in self._action:
return self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
return None
async def _async_wait_template_step(self): async def _async_wait_template_step(self):
"""Handle a wait template.""" """Handle a wait template."""
if CONF_TIMEOUT in self._action: timeout = self._get_timeout_seconds_from_action()
timeout = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
else:
timeout = None
self._step_log("wait template", timeout) self._step_log("wait template", timeout)
self._variables["wait"] = {"remaining": timeout, "completed": False} self._variables["wait"] = {"remaining": timeout, "completed": False}
@ -626,74 +642,47 @@ class _ScriptRun:
self._variables["wait"]["completed"] = True self._variables["wait"]["completed"] = True
return return
futures, timeout_handle, timeout_future = self._async_futures_with_timeout(
timeout
)
done = self._hass.loop.create_future()
futures.append(done)
@callback @callback
def async_script_wait(entity_id, from_s, to_s): def async_script_wait(entity_id, from_s, to_s):
"""Handle script after template condition is true.""" """Handle script after template condition is true."""
# pylint: disable=protected-access self._async_set_remaining_time_var(timeout_handle)
wait_var = self._variables["wait"] self._variables["wait"]["completed"] = True
if to_context and to_context._when: _set_result_unless_done(done)
wait_var["remaining"] = to_context._when - self._hass.loop.time()
else:
wait_var["remaining"] = timeout
wait_var["completed"] = True
done.set()
to_context = None
unsub = async_track_template( unsub = async_track_template(
self._hass, wait_template, async_script_wait, self._variables self._hass, wait_template, async_script_wait, self._variables
) )
self._changed() self._changed()
done = asyncio.Event() await self._async_wait_with_optional_timeout(
tasks = [ futures, timeout_handle, timeout_future, unsub
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done) )
]
try: def _async_set_remaining_time_var(
async with asyncio.timeout(timeout) as to_context: self, timeout_handle: asyncio.TimerHandle | None
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) ) -> None:
except TimeoutError as ex: """Set the remaining time variable for a wait step."""
self._variables["wait"]["remaining"] = 0.0 wait_var = self._variables["wait"]
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True): if timeout_handle:
self._log(_TIMEOUT_MSG) wait_var["remaining"] = timeout_handle.when() - self._hass.loop.time()
trace_set_result(wait=self._variables["wait"], timeout=True) else:
raise _AbortScript from ex wait_var["remaining"] = None
finally:
for task in tasks:
task.cancel()
unsub()
async def _async_run_long_action(self, long_task: asyncio.Task[_T]) -> _T | None: async def _async_run_long_action(self, long_task: asyncio.Task[_T]) -> _T | None:
"""Run a long task while monitoring for stop request.""" """Run a long task while monitoring for stop request."""
async def async_cancel_long_task() -> None:
# Stop long task and wait for it to finish.
long_task.cancel()
with suppress(Exception):
await long_task
# Wait for long task while monitoring for a stop request.
stop_task = self._hass.async_create_task(self._stop.wait())
try: try:
await asyncio.wait( async with async_interrupt.interrupt(self._stop, ScriptStoppedError, None):
{long_task, stop_task}, return_when=asyncio.FIRST_COMPLETED # if stop is set, interrupt will cancel inside the context
) # manager which will cancel long_task, and raise
# If our task is cancelled, then cancel long task, too. Note that if long task # ScriptStoppedError outside the context manager
# is cancelled otherwise the CancelledError exception will not be raised to return await long_task
# here due to the call to asyncio.wait(). Rather we'll check for that below. except ScriptStoppedError as ex:
except asyncio.CancelledError: raise asyncio.CancelledError from ex
await async_cancel_long_task()
raise
finally:
stop_task.cancel()
if long_task.cancelled():
raise asyncio.CancelledError
if long_task.done():
# Propagate any exceptions that occurred.
return long_task.result()
# Stopped before long task completed, so cancel it.
await async_cancel_long_task()
return None
async def _async_call_service_step(self): async def _async_call_service_step(self):
"""Call the service specified in the action.""" """Call the service specified in the action."""
@ -735,8 +724,9 @@ class _ScriptRun:
blocking=True, blocking=True,
context=self._context, context=self._context,
return_response=return_response, return_response=return_response,
) ),
), eager_start=True,
)
) )
if response_variable: if response_variable:
self._variables[response_variable] = response_data self._variables[response_variable] = response_data
@ -866,7 +856,7 @@ class _ScriptRun:
for iteration in range(1, count + 1): for iteration in range(1, count + 1):
set_repeat_var(iteration, count) set_repeat_var(iteration, count)
await async_run_sequence(iteration, extra_msg) await async_run_sequence(iteration, extra_msg)
if self._stop.is_set(): if self._stop.done():
break break
elif CONF_FOR_EACH in repeat: elif CONF_FOR_EACH in repeat:
@ -894,7 +884,7 @@ class _ScriptRun:
for iteration, item in enumerate(items, 1): for iteration, item in enumerate(items, 1):
set_repeat_var(iteration, count, item) set_repeat_var(iteration, count, item)
extra_msg = f" of {count} with item: {repr(item)}" extra_msg = f" of {count} with item: {repr(item)}"
if self._stop.is_set(): if self._stop.done():
break break
await async_run_sequence(iteration, extra_msg) await async_run_sequence(iteration, extra_msg)
@ -905,7 +895,7 @@ class _ScriptRun:
for iteration in itertools.count(1): for iteration in itertools.count(1):
set_repeat_var(iteration) set_repeat_var(iteration)
try: try:
if self._stop.is_set(): if self._stop.done():
break break
if not self._test_conditions(conditions, "while"): if not self._test_conditions(conditions, "while"):
break break
@ -923,7 +913,7 @@ class _ScriptRun:
set_repeat_var(iteration) set_repeat_var(iteration)
await async_run_sequence(iteration) await async_run_sequence(iteration)
try: try:
if self._stop.is_set(): if self._stop.done():
break break
if self._test_conditions(conditions, "until") in [True, None]: if self._test_conditions(conditions, "until") in [True, None]:
break break
@ -983,12 +973,35 @@ class _ScriptRun:
with trace_path("else"): with trace_path("else"):
await self._async_run_script(if_data["if_else"]) await self._async_run_script(if_data["if_else"])
def _async_futures_with_timeout(
self,
timeout: float | None,
) -> tuple[
list[asyncio.Future[None]],
asyncio.TimerHandle | None,
asyncio.Future[None] | None,
]:
"""Return a list of futures to wait for.
The list will contain the stop future.
If timeout is set, a timeout future and handle will be created
and will be added to the list of futures.
"""
timeout_handle: asyncio.TimerHandle | None = None
timeout_future: asyncio.Future[None] | None = None
futures: list[asyncio.Future[None]] = [self._stop]
if timeout:
timeout_future = self._hass.loop.create_future()
timeout_handle = self._hass.loop.call_later(
timeout, _set_result_unless_done, timeout_future
)
futures.append(timeout_future)
return futures, timeout_handle, timeout_future
async def _async_wait_for_trigger_step(self): async def _async_wait_for_trigger_step(self):
"""Wait for a trigger event.""" """Wait for a trigger event."""
if CONF_TIMEOUT in self._action: timeout = self._get_timeout_seconds_from_action()
timeout = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
else:
timeout = None
self._step_log("wait for trigger", timeout) self._step_log("wait for trigger", timeout)
@ -996,22 +1009,20 @@ class _ScriptRun:
self._variables["wait"] = {"remaining": timeout, "trigger": None} self._variables["wait"] = {"remaining": timeout, "trigger": None}
trace_set_result(wait=self._variables["wait"]) trace_set_result(wait=self._variables["wait"])
done = asyncio.Event() futures, timeout_handle, timeout_future = self._async_futures_with_timeout(
timeout
)
done = self._hass.loop.create_future()
futures.append(done)
async def async_done(variables, context=None): async def async_done(variables, context=None):
# pylint: disable=protected-access self._async_set_remaining_time_var(timeout_handle)
wait_var = self._variables["wait"] self._variables["wait"]["trigger"] = variables["trigger"]
if to_context and to_context._when: _set_result_unless_done(done)
wait_var["remaining"] = to_context._when - self._hass.loop.time()
else:
wait_var["remaining"] = timeout
wait_var["trigger"] = variables["trigger"]
done.set()
def log_cb(level, msg, **kwargs): def log_cb(level, msg, **kwargs):
self._log(msg, level=level, **kwargs) self._log(msg, level=level, **kwargs)
to_context = None
remove_triggers = await async_initialize_triggers( remove_triggers = await async_initialize_triggers(
self._hass, self._hass,
self._action[CONF_WAIT_FOR_TRIGGER], self._action[CONF_WAIT_FOR_TRIGGER],
@ -1023,24 +1034,31 @@ class _ScriptRun:
) )
if not remove_triggers: if not remove_triggers:
return return
self._changed() self._changed()
tasks = [ await self._async_wait_with_optional_timeout(
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done) futures, timeout_handle, timeout_future, remove_triggers
] )
async def _async_wait_with_optional_timeout(
self,
futures: list[asyncio.Future[None]],
timeout_handle: asyncio.TimerHandle | None,
timeout_future: asyncio.Future[None] | None,
unsub: Callable[[], None],
) -> None:
try: try:
async with asyncio.timeout(timeout) as to_context: await asyncio.wait(futures, return_when=asyncio.FIRST_COMPLETED)
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) if timeout_future and timeout_future.done():
except TimeoutError as ex: self._variables["wait"]["remaining"] = 0.0
self._variables["wait"]["remaining"] = 0.0 if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True): self._log(_TIMEOUT_MSG)
self._log(_TIMEOUT_MSG) trace_set_result(wait=self._variables["wait"], timeout=True)
trace_set_result(wait=self._variables["wait"], timeout=True) raise _AbortScript from TimeoutError()
raise _AbortScript from ex
finally: finally:
for task in tasks: if timeout_future and not timeout_future.done() and timeout_handle:
task.cancel() timeout_handle.cancel()
remove_triggers()
unsub()
async def _async_variables_step(self): async def _async_variables_step(self):
"""Set a variable value.""" """Set a variable value."""
@ -1107,7 +1125,7 @@ class _ScriptRun:
"""Execute a script.""" """Execute a script."""
result = await self._async_run_long_action( result = await self._async_run_long_action(
self._hass.async_create_task( self._hass.async_create_task(
script.async_run(self._variables, self._context) script.async_run(self._variables, self._context), eager_start=True
) )
) )
if result and result.conversation_response is not UNDEFINED: if result and result.conversation_response is not UNDEFINED:
@ -1123,29 +1141,17 @@ class _QueuedScriptRun(_ScriptRun):
"""Run script.""" """Run script."""
# Wait for previous run, if any, to finish by attempting to acquire the script's # Wait for previous run, if any, to finish by attempting to acquire the script's
# shared lock. At the same time monitor if we've been told to stop. # shared lock. At the same time monitor if we've been told to stop.
lock_task = self._hass.async_create_task(
self._script._queue_lck.acquire() # pylint: disable=protected-access
)
stop_task = self._hass.async_create_task(self._stop.wait())
try: try:
await asyncio.wait( async with async_interrupt.interrupt(self._stop, ScriptStoppedError, None):
{lock_task, stop_task}, return_when=asyncio.FIRST_COMPLETED await self._script._queue_lck.acquire() # pylint: disable=protected-access
) except ScriptStoppedError as ex:
except asyncio.CancelledError: # If we've been told to stop, then just finish up.
self._finish() self._finish()
raise raise asyncio.CancelledError from ex
else:
self.lock_acquired = lock_task.done() and not lock_task.cancelled()
finally:
lock_task.cancel()
stop_task.cancel()
# If we've been told to stop, then just finish up. Otherwise, we've acquired the self.lock_acquired = True
# lock so we can go ahead and start the run. # We've acquired the lock so we can go ahead and start the run.
if self._stop.is_set(): await super().async_run()
self._finish()
else:
await super().async_run()
def _finish(self) -> None: def _finish(self) -> None:
if self.lock_acquired: if self.lock_acquired: