mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Reduce script overhead by avoiding creation of many tasks (#113183)
* Reduce script overhead by avoiding creation of many tasks * no eager stop * reduce * make sure wait being cancelled is handled * make sure wait being cancelled is handled * make sure wait being cancelled is handled * preen * preen * result already raises cancelled error, remove redundant code * no need to raise it into the future * will never set an exception * Simplify long action script implementation * comment * preen * dry * dry * preen * dry * preen * no need to access protected * no need to access protected * dry * name * dry * dry * dry * dry * reduce name changes * drop one more task * stale comment * stale comment
This commit is contained in:
parent
e293afe46e
commit
09934d44c4
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
|
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
|
||||||
from contextlib import asynccontextmanager, suppress
|
from contextlib import asynccontextmanager
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -15,6 +15,7 @@ import logging
|
|||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar, cast
|
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar, cast
|
||||||
|
|
||||||
|
import async_interrupt
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import exceptions
|
from homeassistant import exceptions
|
||||||
@ -157,6 +158,16 @@ SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all"
|
|||||||
script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None)
|
script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptStoppedError(Exception):
|
||||||
|
"""Error to indicate that the script has been stopped."""
|
||||||
|
|
||||||
|
|
||||||
|
def _set_result_unless_done(future: asyncio.Future[None]) -> None:
|
||||||
|
"""Set result of future unless it is done."""
|
||||||
|
if not future.done():
|
||||||
|
future.set_result(None)
|
||||||
|
|
||||||
|
|
||||||
def action_trace_append(variables, path):
|
def action_trace_append(variables, path):
|
||||||
"""Append a TraceElement to trace[path]."""
|
"""Append a TraceElement to trace[path]."""
|
||||||
trace_element = TraceElement(variables, path)
|
trace_element = TraceElement(variables, path)
|
||||||
@ -168,7 +179,7 @@ def action_trace_append(variables, path):
|
|||||||
async def trace_action(
|
async def trace_action(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
script_run: _ScriptRun,
|
script_run: _ScriptRun,
|
||||||
stop: asyncio.Event,
|
stop: asyncio.Future[None],
|
||||||
variables: dict[str, Any],
|
variables: dict[str, Any],
|
||||||
) -> AsyncGenerator[TraceElement, None]:
|
) -> AsyncGenerator[TraceElement, None]:
|
||||||
"""Trace action execution."""
|
"""Trace action execution."""
|
||||||
@ -199,13 +210,13 @@ async def trace_action(
|
|||||||
):
|
):
|
||||||
async_dispatcher_send(hass, SCRIPT_BREAKPOINT_HIT, key, run_id, path)
|
async_dispatcher_send(hass, SCRIPT_BREAKPOINT_HIT, key, run_id, path)
|
||||||
|
|
||||||
done = asyncio.Event()
|
done = hass.loop.create_future()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_continue_stop(command=None):
|
def async_continue_stop(command=None):
|
||||||
if command == "stop":
|
if command == "stop":
|
||||||
stop.set()
|
_set_result_unless_done(stop)
|
||||||
done.set()
|
_set_result_unless_done(done)
|
||||||
|
|
||||||
signal = SCRIPT_DEBUG_CONTINUE_STOP.format(key, run_id)
|
signal = SCRIPT_DEBUG_CONTINUE_STOP.format(key, run_id)
|
||||||
remove_signal1 = async_dispatcher_connect(hass, signal, async_continue_stop)
|
remove_signal1 = async_dispatcher_connect(hass, signal, async_continue_stop)
|
||||||
@ -213,10 +224,7 @@ async def trace_action(
|
|||||||
hass, SCRIPT_DEBUG_CONTINUE_ALL, async_continue_stop
|
hass, SCRIPT_DEBUG_CONTINUE_ALL, async_continue_stop
|
||||||
)
|
)
|
||||||
|
|
||||||
tasks = [hass.async_create_task(flag.wait()) for flag in (stop, done)]
|
await asyncio.wait([stop, done], return_when=asyncio.FIRST_COMPLETED)
|
||||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
|
||||||
for task in tasks:
|
|
||||||
task.cancel()
|
|
||||||
remove_signal1()
|
remove_signal1()
|
||||||
remove_signal2()
|
remove_signal2()
|
||||||
|
|
||||||
@ -393,12 +401,12 @@ class _ScriptRun:
|
|||||||
self._log_exceptions = log_exceptions
|
self._log_exceptions = log_exceptions
|
||||||
self._step = -1
|
self._step = -1
|
||||||
self._started = False
|
self._started = False
|
||||||
self._stop = asyncio.Event()
|
self._stop = hass.loop.create_future()
|
||||||
self._stopped = asyncio.Event()
|
self._stopped = asyncio.Event()
|
||||||
self._conversation_response: str | None | UndefinedType = UNDEFINED
|
self._conversation_response: str | None | UndefinedType = UNDEFINED
|
||||||
|
|
||||||
def _changed(self) -> None:
|
def _changed(self) -> None:
|
||||||
if not self._stop.is_set():
|
if not self._stop.done():
|
||||||
self._script._changed() # pylint: disable=protected-access
|
self._script._changed() # pylint: disable=protected-access
|
||||||
|
|
||||||
async def _async_get_condition(self, config):
|
async def _async_get_condition(self, config):
|
||||||
@ -432,7 +440,7 @@ class _ScriptRun:
|
|||||||
try:
|
try:
|
||||||
self._log("Running %s", self._script.running_description)
|
self._log("Running %s", self._script.running_description)
|
||||||
for self._step, self._action in enumerate(self._script.sequence):
|
for self._step, self._action in enumerate(self._script.sequence):
|
||||||
if self._stop.is_set():
|
if self._stop.done():
|
||||||
script_execution_set("cancelled")
|
script_execution_set("cancelled")
|
||||||
break
|
break
|
||||||
await self._async_step(log_exceptions=False)
|
await self._async_step(log_exceptions=False)
|
||||||
@ -471,7 +479,7 @@ class _ScriptRun:
|
|||||||
async with trace_action(
|
async with trace_action(
|
||||||
self._hass, self, self._stop, self._variables
|
self._hass, self, self._stop, self._variables
|
||||||
) as trace_element:
|
) as trace_element:
|
||||||
if self._stop.is_set():
|
if self._stop.done():
|
||||||
return
|
return
|
||||||
|
|
||||||
action = cv.determine_script_action(self._action)
|
action = cv.determine_script_action(self._action)
|
||||||
@ -483,8 +491,8 @@ class _ScriptRun:
|
|||||||
trace_set_result(enabled=False)
|
trace_set_result(enabled=False)
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
|
||||||
handler = f"_async_{action}_step"
|
handler = f"_async_{action}_step"
|
||||||
|
try:
|
||||||
await getattr(self, handler)()
|
await getattr(self, handler)()
|
||||||
except Exception as ex: # pylint: disable=broad-except
|
except Exception as ex: # pylint: disable=broad-except
|
||||||
self._handle_exception(
|
self._handle_exception(
|
||||||
@ -502,7 +510,7 @@ class _ScriptRun:
|
|||||||
|
|
||||||
async def async_stop(self) -> None:
|
async def async_stop(self) -> None:
|
||||||
"""Stop script run."""
|
"""Stop script run."""
|
||||||
self._stop.set()
|
_set_result_unless_done(self._stop)
|
||||||
# If the script was never started
|
# If the script was never started
|
||||||
# the stopped event will never be
|
# the stopped event will never be
|
||||||
# set because the script will never
|
# set because the script will never
|
||||||
@ -576,9 +584,9 @@ class _ScriptRun:
|
|||||||
level=level,
|
level=level,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_pos_time_period_template(self, key):
|
def _get_pos_time_period_template(self, key: str) -> timedelta:
|
||||||
try:
|
try:
|
||||||
return cv.positive_time_period(
|
return cv.positive_time_period( # type: ignore[no-any-return]
|
||||||
template.render_complex(self._action[key], self._variables)
|
template.render_complex(self._action[key], self._variables)
|
||||||
)
|
)
|
||||||
except (exceptions.TemplateError, vol.Invalid) as ex:
|
except (exceptions.TemplateError, vol.Invalid) as ex:
|
||||||
@ -593,26 +601,34 @@ class _ScriptRun:
|
|||||||
|
|
||||||
async def _async_delay_step(self):
|
async def _async_delay_step(self):
|
||||||
"""Handle delay."""
|
"""Handle delay."""
|
||||||
delay = self._get_pos_time_period_template(CONF_DELAY)
|
delay_delta = self._get_pos_time_period_template(CONF_DELAY)
|
||||||
|
|
||||||
self._step_log(f"delay {delay}")
|
self._step_log(f"delay {delay_delta}")
|
||||||
|
|
||||||
delay = delay.total_seconds()
|
delay = delay_delta.total_seconds()
|
||||||
self._changed()
|
self._changed()
|
||||||
trace_set_result(delay=delay, done=False)
|
trace_set_result(delay=delay, done=False)
|
||||||
|
futures, timeout_handle, timeout_future = self._async_futures_with_timeout(
|
||||||
|
delay
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with asyncio.timeout(delay):
|
await asyncio.wait(futures, return_when=asyncio.FIRST_COMPLETED)
|
||||||
await self._stop.wait()
|
finally:
|
||||||
except TimeoutError:
|
if timeout_future.done():
|
||||||
trace_set_result(delay=delay, done=True)
|
trace_set_result(delay=delay, done=True)
|
||||||
|
else:
|
||||||
|
timeout_handle.cancel()
|
||||||
|
|
||||||
|
def _get_timeout_seconds_from_action(self) -> float | None:
|
||||||
|
"""Get the timeout from the action."""
|
||||||
|
if CONF_TIMEOUT in self._action:
|
||||||
|
return self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
|
||||||
|
return None
|
||||||
|
|
||||||
async def _async_wait_template_step(self):
|
async def _async_wait_template_step(self):
|
||||||
"""Handle a wait template."""
|
"""Handle a wait template."""
|
||||||
if CONF_TIMEOUT in self._action:
|
timeout = self._get_timeout_seconds_from_action()
|
||||||
timeout = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
|
|
||||||
else:
|
|
||||||
timeout = None
|
|
||||||
|
|
||||||
self._step_log("wait template", timeout)
|
self._step_log("wait template", timeout)
|
||||||
|
|
||||||
self._variables["wait"] = {"remaining": timeout, "completed": False}
|
self._variables["wait"] = {"remaining": timeout, "completed": False}
|
||||||
@ -626,74 +642,47 @@ class _ScriptRun:
|
|||||||
self._variables["wait"]["completed"] = True
|
self._variables["wait"]["completed"] = True
|
||||||
return
|
return
|
||||||
|
|
||||||
|
futures, timeout_handle, timeout_future = self._async_futures_with_timeout(
|
||||||
|
timeout
|
||||||
|
)
|
||||||
|
done = self._hass.loop.create_future()
|
||||||
|
futures.append(done)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_script_wait(entity_id, from_s, to_s):
|
def async_script_wait(entity_id, from_s, to_s):
|
||||||
"""Handle script after template condition is true."""
|
"""Handle script after template condition is true."""
|
||||||
# pylint: disable=protected-access
|
self._async_set_remaining_time_var(timeout_handle)
|
||||||
wait_var = self._variables["wait"]
|
self._variables["wait"]["completed"] = True
|
||||||
if to_context and to_context._when:
|
_set_result_unless_done(done)
|
||||||
wait_var["remaining"] = to_context._when - self._hass.loop.time()
|
|
||||||
else:
|
|
||||||
wait_var["remaining"] = timeout
|
|
||||||
wait_var["completed"] = True
|
|
||||||
done.set()
|
|
||||||
|
|
||||||
to_context = None
|
|
||||||
unsub = async_track_template(
|
unsub = async_track_template(
|
||||||
self._hass, wait_template, async_script_wait, self._variables
|
self._hass, wait_template, async_script_wait, self._variables
|
||||||
)
|
)
|
||||||
|
|
||||||
self._changed()
|
self._changed()
|
||||||
done = asyncio.Event()
|
await self._async_wait_with_optional_timeout(
|
||||||
tasks = [
|
futures, timeout_handle, timeout_future, unsub
|
||||||
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
|
)
|
||||||
]
|
|
||||||
try:
|
def _async_set_remaining_time_var(
|
||||||
async with asyncio.timeout(timeout) as to_context:
|
self, timeout_handle: asyncio.TimerHandle | None
|
||||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
) -> None:
|
||||||
except TimeoutError as ex:
|
"""Set the remaining time variable for a wait step."""
|
||||||
self._variables["wait"]["remaining"] = 0.0
|
wait_var = self._variables["wait"]
|
||||||
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
|
if timeout_handle:
|
||||||
self._log(_TIMEOUT_MSG)
|
wait_var["remaining"] = timeout_handle.when() - self._hass.loop.time()
|
||||||
trace_set_result(wait=self._variables["wait"], timeout=True)
|
else:
|
||||||
raise _AbortScript from ex
|
wait_var["remaining"] = None
|
||||||
finally:
|
|
||||||
for task in tasks:
|
|
||||||
task.cancel()
|
|
||||||
unsub()
|
|
||||||
|
|
||||||
async def _async_run_long_action(self, long_task: asyncio.Task[_T]) -> _T | None:
|
async def _async_run_long_action(self, long_task: asyncio.Task[_T]) -> _T | None:
|
||||||
"""Run a long task while monitoring for stop request."""
|
"""Run a long task while monitoring for stop request."""
|
||||||
|
|
||||||
async def async_cancel_long_task() -> None:
|
|
||||||
# Stop long task and wait for it to finish.
|
|
||||||
long_task.cancel()
|
|
||||||
with suppress(Exception):
|
|
||||||
await long_task
|
|
||||||
|
|
||||||
# Wait for long task while monitoring for a stop request.
|
|
||||||
stop_task = self._hass.async_create_task(self._stop.wait())
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait(
|
async with async_interrupt.interrupt(self._stop, ScriptStoppedError, None):
|
||||||
{long_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
|
# if stop is set, interrupt will cancel inside the context
|
||||||
)
|
# manager which will cancel long_task, and raise
|
||||||
# If our task is cancelled, then cancel long task, too. Note that if long task
|
# ScriptStoppedError outside the context manager
|
||||||
# is cancelled otherwise the CancelledError exception will not be raised to
|
return await long_task
|
||||||
# here due to the call to asyncio.wait(). Rather we'll check for that below.
|
except ScriptStoppedError as ex:
|
||||||
except asyncio.CancelledError:
|
raise asyncio.CancelledError from ex
|
||||||
await async_cancel_long_task()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
stop_task.cancel()
|
|
||||||
|
|
||||||
if long_task.cancelled():
|
|
||||||
raise asyncio.CancelledError
|
|
||||||
if long_task.done():
|
|
||||||
# Propagate any exceptions that occurred.
|
|
||||||
return long_task.result()
|
|
||||||
# Stopped before long task completed, so cancel it.
|
|
||||||
await async_cancel_long_task()
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _async_call_service_step(self):
|
async def _async_call_service_step(self):
|
||||||
"""Call the service specified in the action."""
|
"""Call the service specified in the action."""
|
||||||
@ -735,8 +724,9 @@ class _ScriptRun:
|
|||||||
blocking=True,
|
blocking=True,
|
||||||
context=self._context,
|
context=self._context,
|
||||||
return_response=return_response,
|
return_response=return_response,
|
||||||
)
|
|
||||||
),
|
),
|
||||||
|
eager_start=True,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if response_variable:
|
if response_variable:
|
||||||
self._variables[response_variable] = response_data
|
self._variables[response_variable] = response_data
|
||||||
@ -866,7 +856,7 @@ class _ScriptRun:
|
|||||||
for iteration in range(1, count + 1):
|
for iteration in range(1, count + 1):
|
||||||
set_repeat_var(iteration, count)
|
set_repeat_var(iteration, count)
|
||||||
await async_run_sequence(iteration, extra_msg)
|
await async_run_sequence(iteration, extra_msg)
|
||||||
if self._stop.is_set():
|
if self._stop.done():
|
||||||
break
|
break
|
||||||
|
|
||||||
elif CONF_FOR_EACH in repeat:
|
elif CONF_FOR_EACH in repeat:
|
||||||
@ -894,7 +884,7 @@ class _ScriptRun:
|
|||||||
for iteration, item in enumerate(items, 1):
|
for iteration, item in enumerate(items, 1):
|
||||||
set_repeat_var(iteration, count, item)
|
set_repeat_var(iteration, count, item)
|
||||||
extra_msg = f" of {count} with item: {repr(item)}"
|
extra_msg = f" of {count} with item: {repr(item)}"
|
||||||
if self._stop.is_set():
|
if self._stop.done():
|
||||||
break
|
break
|
||||||
await async_run_sequence(iteration, extra_msg)
|
await async_run_sequence(iteration, extra_msg)
|
||||||
|
|
||||||
@ -905,7 +895,7 @@ class _ScriptRun:
|
|||||||
for iteration in itertools.count(1):
|
for iteration in itertools.count(1):
|
||||||
set_repeat_var(iteration)
|
set_repeat_var(iteration)
|
||||||
try:
|
try:
|
||||||
if self._stop.is_set():
|
if self._stop.done():
|
||||||
break
|
break
|
||||||
if not self._test_conditions(conditions, "while"):
|
if not self._test_conditions(conditions, "while"):
|
||||||
break
|
break
|
||||||
@ -923,7 +913,7 @@ class _ScriptRun:
|
|||||||
set_repeat_var(iteration)
|
set_repeat_var(iteration)
|
||||||
await async_run_sequence(iteration)
|
await async_run_sequence(iteration)
|
||||||
try:
|
try:
|
||||||
if self._stop.is_set():
|
if self._stop.done():
|
||||||
break
|
break
|
||||||
if self._test_conditions(conditions, "until") in [True, None]:
|
if self._test_conditions(conditions, "until") in [True, None]:
|
||||||
break
|
break
|
||||||
@ -983,12 +973,35 @@ class _ScriptRun:
|
|||||||
with trace_path("else"):
|
with trace_path("else"):
|
||||||
await self._async_run_script(if_data["if_else"])
|
await self._async_run_script(if_data["if_else"])
|
||||||
|
|
||||||
|
def _async_futures_with_timeout(
|
||||||
|
self,
|
||||||
|
timeout: float | None,
|
||||||
|
) -> tuple[
|
||||||
|
list[asyncio.Future[None]],
|
||||||
|
asyncio.TimerHandle | None,
|
||||||
|
asyncio.Future[None] | None,
|
||||||
|
]:
|
||||||
|
"""Return a list of futures to wait for.
|
||||||
|
|
||||||
|
The list will contain the stop future.
|
||||||
|
|
||||||
|
If timeout is set, a timeout future and handle will be created
|
||||||
|
and will be added to the list of futures.
|
||||||
|
"""
|
||||||
|
timeout_handle: asyncio.TimerHandle | None = None
|
||||||
|
timeout_future: asyncio.Future[None] | None = None
|
||||||
|
futures: list[asyncio.Future[None]] = [self._stop]
|
||||||
|
if timeout:
|
||||||
|
timeout_future = self._hass.loop.create_future()
|
||||||
|
timeout_handle = self._hass.loop.call_later(
|
||||||
|
timeout, _set_result_unless_done, timeout_future
|
||||||
|
)
|
||||||
|
futures.append(timeout_future)
|
||||||
|
return futures, timeout_handle, timeout_future
|
||||||
|
|
||||||
async def _async_wait_for_trigger_step(self):
|
async def _async_wait_for_trigger_step(self):
|
||||||
"""Wait for a trigger event."""
|
"""Wait for a trigger event."""
|
||||||
if CONF_TIMEOUT in self._action:
|
timeout = self._get_timeout_seconds_from_action()
|
||||||
timeout = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
|
|
||||||
else:
|
|
||||||
timeout = None
|
|
||||||
|
|
||||||
self._step_log("wait for trigger", timeout)
|
self._step_log("wait for trigger", timeout)
|
||||||
|
|
||||||
@ -996,22 +1009,20 @@ class _ScriptRun:
|
|||||||
self._variables["wait"] = {"remaining": timeout, "trigger": None}
|
self._variables["wait"] = {"remaining": timeout, "trigger": None}
|
||||||
trace_set_result(wait=self._variables["wait"])
|
trace_set_result(wait=self._variables["wait"])
|
||||||
|
|
||||||
done = asyncio.Event()
|
futures, timeout_handle, timeout_future = self._async_futures_with_timeout(
|
||||||
|
timeout
|
||||||
|
)
|
||||||
|
done = self._hass.loop.create_future()
|
||||||
|
futures.append(done)
|
||||||
|
|
||||||
async def async_done(variables, context=None):
|
async def async_done(variables, context=None):
|
||||||
# pylint: disable=protected-access
|
self._async_set_remaining_time_var(timeout_handle)
|
||||||
wait_var = self._variables["wait"]
|
self._variables["wait"]["trigger"] = variables["trigger"]
|
||||||
if to_context and to_context._when:
|
_set_result_unless_done(done)
|
||||||
wait_var["remaining"] = to_context._when - self._hass.loop.time()
|
|
||||||
else:
|
|
||||||
wait_var["remaining"] = timeout
|
|
||||||
wait_var["trigger"] = variables["trigger"]
|
|
||||||
done.set()
|
|
||||||
|
|
||||||
def log_cb(level, msg, **kwargs):
|
def log_cb(level, msg, **kwargs):
|
||||||
self._log(msg, level=level, **kwargs)
|
self._log(msg, level=level, **kwargs)
|
||||||
|
|
||||||
to_context = None
|
|
||||||
remove_triggers = await async_initialize_triggers(
|
remove_triggers = await async_initialize_triggers(
|
||||||
self._hass,
|
self._hass,
|
||||||
self._action[CONF_WAIT_FOR_TRIGGER],
|
self._action[CONF_WAIT_FOR_TRIGGER],
|
||||||
@ -1023,24 +1034,31 @@ class _ScriptRun:
|
|||||||
)
|
)
|
||||||
if not remove_triggers:
|
if not remove_triggers:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._changed()
|
self._changed()
|
||||||
tasks = [
|
await self._async_wait_with_optional_timeout(
|
||||||
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
|
futures, timeout_handle, timeout_future, remove_triggers
|
||||||
]
|
)
|
||||||
|
|
||||||
|
async def _async_wait_with_optional_timeout(
|
||||||
|
self,
|
||||||
|
futures: list[asyncio.Future[None]],
|
||||||
|
timeout_handle: asyncio.TimerHandle | None,
|
||||||
|
timeout_future: asyncio.Future[None] | None,
|
||||||
|
unsub: Callable[[], None],
|
||||||
|
) -> None:
|
||||||
try:
|
try:
|
||||||
async with asyncio.timeout(timeout) as to_context:
|
await asyncio.wait(futures, return_when=asyncio.FIRST_COMPLETED)
|
||||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
if timeout_future and timeout_future.done():
|
||||||
except TimeoutError as ex:
|
|
||||||
self._variables["wait"]["remaining"] = 0.0
|
self._variables["wait"]["remaining"] = 0.0
|
||||||
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
|
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
|
||||||
self._log(_TIMEOUT_MSG)
|
self._log(_TIMEOUT_MSG)
|
||||||
trace_set_result(wait=self._variables["wait"], timeout=True)
|
trace_set_result(wait=self._variables["wait"], timeout=True)
|
||||||
raise _AbortScript from ex
|
raise _AbortScript from TimeoutError()
|
||||||
finally:
|
finally:
|
||||||
for task in tasks:
|
if timeout_future and not timeout_future.done() and timeout_handle:
|
||||||
task.cancel()
|
timeout_handle.cancel()
|
||||||
remove_triggers()
|
|
||||||
|
unsub()
|
||||||
|
|
||||||
async def _async_variables_step(self):
|
async def _async_variables_step(self):
|
||||||
"""Set a variable value."""
|
"""Set a variable value."""
|
||||||
@ -1107,7 +1125,7 @@ class _ScriptRun:
|
|||||||
"""Execute a script."""
|
"""Execute a script."""
|
||||||
result = await self._async_run_long_action(
|
result = await self._async_run_long_action(
|
||||||
self._hass.async_create_task(
|
self._hass.async_create_task(
|
||||||
script.async_run(self._variables, self._context)
|
script.async_run(self._variables, self._context), eager_start=True
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if result and result.conversation_response is not UNDEFINED:
|
if result and result.conversation_response is not UNDEFINED:
|
||||||
@ -1123,28 +1141,16 @@ class _QueuedScriptRun(_ScriptRun):
|
|||||||
"""Run script."""
|
"""Run script."""
|
||||||
# Wait for previous run, if any, to finish by attempting to acquire the script's
|
# Wait for previous run, if any, to finish by attempting to acquire the script's
|
||||||
# shared lock. At the same time monitor if we've been told to stop.
|
# shared lock. At the same time monitor if we've been told to stop.
|
||||||
lock_task = self._hass.async_create_task(
|
|
||||||
self._script._queue_lck.acquire() # pylint: disable=protected-access
|
|
||||||
)
|
|
||||||
stop_task = self._hass.async_create_task(self._stop.wait())
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait(
|
async with async_interrupt.interrupt(self._stop, ScriptStoppedError, None):
|
||||||
{lock_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
|
await self._script._queue_lck.acquire() # pylint: disable=protected-access
|
||||||
)
|
except ScriptStoppedError as ex:
|
||||||
except asyncio.CancelledError:
|
# If we've been told to stop, then just finish up.
|
||||||
self._finish()
|
self._finish()
|
||||||
raise
|
raise asyncio.CancelledError from ex
|
||||||
else:
|
|
||||||
self.lock_acquired = lock_task.done() and not lock_task.cancelled()
|
|
||||||
finally:
|
|
||||||
lock_task.cancel()
|
|
||||||
stop_task.cancel()
|
|
||||||
|
|
||||||
# If we've been told to stop, then just finish up. Otherwise, we've acquired the
|
self.lock_acquired = True
|
||||||
# lock so we can go ahead and start the run.
|
# We've acquired the lock so we can go ahead and start the run.
|
||||||
if self._stop.is_set():
|
|
||||||
self._finish()
|
|
||||||
else:
|
|
||||||
await super().async_run()
|
await super().async_run()
|
||||||
|
|
||||||
def _finish(self) -> None:
|
def _finish(self) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user