mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 01:38:02 +00:00
Add support for simultaneous runs of Script helper - Part 2 (#32442)
* Add limit parameter to service call methods * Break out prep part of async_call_from_config for use elsewhere * Minor cleanup * Fix improper use of asyncio.wait * Fix state update Call change listener immediately if its a callback * Fix exception handling and logging * Merge Script helper if_running/run_mode parameters into script_mode - Remove background/blocking _ScriptRun subclasses which are no longer needed. * Add queued script mode * Disable timeout when making fully blocking script call * Don't call change listener when restarting script This makes restart mode behavior consistent with parallel & queue modes. * Changes per review - Call all script services (except script.turn_off) with no time limit. - Fix handling of lock in _QueuedScriptRun and add comments to make it clearer how this code works. * Changes per review 2 - Move cancel shielding "up" from _ScriptRun.async_run to Script.async_run (and apply to new style scripts only.) This makes sure Script class also properly handles cancellation which it wasn't doing before. - In _ScriptRun._async_call_service_step, instead of using script.turn_off service, just cancel service call and let it handle the cancellation accordingly. * Fix bugs - Add missing call to change listener in Script.async_run in cancelled path. - Cancel service task if ServiceRegistry.async_call cancelled. * Revert last changes to ServiceRegistry.async_call * Minor Script helper fixes & test improvements - Don't log asyncio.CancelledError exceptions. - Make change_listener a public attribute. - Test overhaul - Parametrize tests. - Use common test functions. - Mock timeout so tests don't need to wait for real time to elapse. - Add common function for waiting for script action step.
This commit is contained in:
parent
da761fdd39
commit
5f5cb8bea8
@ -93,7 +93,7 @@ SOURCE_DISCOVERED = "discovered"
|
||||
SOURCE_STORAGE = "storage"
|
||||
SOURCE_YAML = "yaml"
|
||||
|
||||
# How long to wait till things that run on startup have to finish.
|
||||
# How long to wait until things that run on startup have to finish.
|
||||
TIMEOUT_EVENT_START = 15
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -249,7 +249,7 @@ class HomeAssistant:
|
||||
try:
|
||||
# Only block for EVENT_HOMEASSISTANT_START listener
|
||||
self.async_stop_track_tasks()
|
||||
with timeout(TIMEOUT_EVENT_START):
|
||||
async with timeout(TIMEOUT_EVENT_START):
|
||||
await self.async_block_till_done()
|
||||
except asyncio.TimeoutError:
|
||||
_LOGGER.warning(
|
||||
@ -374,13 +374,13 @@ class HomeAssistant:
|
||||
self.async_add_job(target, *args)
|
||||
|
||||
def block_till_done(self) -> None:
|
||||
"""Block till all pending work is done."""
|
||||
"""Block until all pending work is done."""
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.async_block_till_done(), self.loop
|
||||
).result()
|
||||
|
||||
async def async_block_till_done(self) -> None:
|
||||
"""Block till all pending work is done."""
|
||||
"""Block until all pending work is done."""
|
||||
# To flush out any call_soon_threadsafe
|
||||
await asyncio.sleep(0)
|
||||
|
||||
@ -1150,25 +1150,15 @@ class ServiceRegistry:
|
||||
service_data: Optional[Dict] = None,
|
||||
blocking: bool = False,
|
||||
context: Optional[Context] = None,
|
||||
limit: Optional[float] = SERVICE_CALL_LIMIT,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
Call a service.
|
||||
|
||||
Specify blocking=True to wait till service is executed.
|
||||
Waits a maximum of SERVICE_CALL_LIMIT.
|
||||
|
||||
If blocking = True, will return boolean if service executed
|
||||
successfully within SERVICE_CALL_LIMIT.
|
||||
|
||||
This method will fire an event to call the service.
|
||||
This event will be picked up by this ServiceRegistry and any
|
||||
other ServiceRegistry that is listening on the EventBus.
|
||||
|
||||
Because the service is sent as an event you are not allowed to use
|
||||
the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data.
|
||||
See description of async_call for details.
|
||||
"""
|
||||
return asyncio.run_coroutine_threadsafe(
|
||||
self.async_call(domain, service, service_data, blocking, context),
|
||||
self.async_call(domain, service, service_data, blocking, context, limit),
|
||||
self._hass.loop,
|
||||
).result()
|
||||
|
||||
@ -1179,19 +1169,18 @@ class ServiceRegistry:
|
||||
service_data: Optional[Dict] = None,
|
||||
blocking: bool = False,
|
||||
context: Optional[Context] = None,
|
||||
limit: Optional[float] = SERVICE_CALL_LIMIT,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
Call a service.
|
||||
|
||||
Specify blocking=True to wait till service is executed.
|
||||
Waits a maximum of SERVICE_CALL_LIMIT.
|
||||
Specify blocking=True to wait until service is executed.
|
||||
Waits a maximum of limit, which may be None for no timeout.
|
||||
|
||||
If blocking = True, will return boolean if service executed
|
||||
successfully within SERVICE_CALL_LIMIT.
|
||||
successfully within limit.
|
||||
|
||||
This method will fire an event to call the service.
|
||||
This event will be picked up by this ServiceRegistry and any
|
||||
other ServiceRegistry that is listening on the EventBus.
|
||||
This method will fire an event to indicate the service has been called.
|
||||
|
||||
Because the service is sent as an event you are not allowed to use
|
||||
the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data.
|
||||
@ -1230,7 +1219,7 @@ class ServiceRegistry:
|
||||
return None
|
||||
|
||||
try:
|
||||
with timeout(SERVICE_CALL_LIMIT):
|
||||
async with timeout(limit):
|
||||
await asyncio.shield(self._execute_service(handler, service_call))
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
|
@ -7,6 +7,7 @@ from itertools import islice
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast
|
||||
|
||||
from async_timeout import timeout
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import exceptions
|
||||
@ -14,6 +15,7 @@ import homeassistant.components.device_automation as device_automation
|
||||
import homeassistant.components.scene as scene
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID,
|
||||
CONF_ALIAS,
|
||||
CONF_CONDITION,
|
||||
CONF_CONTINUE_ON_TIMEOUT,
|
||||
CONF_DELAY,
|
||||
@ -25,47 +27,53 @@ from homeassistant.const import (
|
||||
CONF_SCENE,
|
||||
CONF_TIMEOUT,
|
||||
CONF_WAIT_TEMPLATE,
|
||||
SERVICE_TURN_OFF,
|
||||
SERVICE_TURN_ON,
|
||||
)
|
||||
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
SERVICE_CALL_LIMIT,
|
||||
Context,
|
||||
HomeAssistant,
|
||||
callback,
|
||||
is_callback,
|
||||
)
|
||||
from homeassistant.helpers import (
|
||||
condition,
|
||||
config_validation as cv,
|
||||
service,
|
||||
template as template,
|
||||
)
|
||||
from homeassistant.helpers.event import (
|
||||
async_track_point_in_utc_time,
|
||||
async_track_template,
|
||||
)
|
||||
from homeassistant.helpers.service import (
|
||||
CONF_SERVICE_DATA,
|
||||
async_prepare_call_from_config,
|
||||
)
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.util import slugify
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
CONF_ALIAS = "alias"
|
||||
|
||||
IF_RUNNING_ERROR = "error"
|
||||
IF_RUNNING_IGNORE = "ignore"
|
||||
IF_RUNNING_PARALLEL = "parallel"
|
||||
IF_RUNNING_RESTART = "restart"
|
||||
# First choice is default
|
||||
IF_RUNNING_CHOICES = [
|
||||
IF_RUNNING_PARALLEL,
|
||||
IF_RUNNING_ERROR,
|
||||
IF_RUNNING_IGNORE,
|
||||
IF_RUNNING_RESTART,
|
||||
SCRIPT_MODE_ERROR = "error"
|
||||
SCRIPT_MODE_IGNORE = "ignore"
|
||||
SCRIPT_MODE_LEGACY = "legacy"
|
||||
SCRIPT_MODE_PARALLEL = "parallel"
|
||||
SCRIPT_MODE_QUEUE = "queue"
|
||||
SCRIPT_MODE_RESTART = "restart"
|
||||
SCRIPT_MODE_CHOICES = [
|
||||
SCRIPT_MODE_ERROR,
|
||||
SCRIPT_MODE_IGNORE,
|
||||
SCRIPT_MODE_LEGACY,
|
||||
SCRIPT_MODE_PARALLEL,
|
||||
SCRIPT_MODE_QUEUE,
|
||||
SCRIPT_MODE_RESTART,
|
||||
]
|
||||
DEFAULT_SCRIPT_MODE = SCRIPT_MODE_LEGACY
|
||||
|
||||
RUN_MODE_BACKGROUND = "background"
|
||||
RUN_MODE_BLOCKING = "blocking"
|
||||
RUN_MODE_LEGACY = "legacy"
|
||||
# First choice is default
|
||||
RUN_MODE_CHOICES = [
|
||||
RUN_MODE_BLOCKING,
|
||||
RUN_MODE_BACKGROUND,
|
||||
RUN_MODE_LEGACY,
|
||||
]
|
||||
DEFAULT_QUEUE_MAX = 10
|
||||
|
||||
_LOG_EXCEPTION = logging.ERROR + 1
|
||||
_TIMEOUT_MSG = "Timeout reached, abort script."
|
||||
@ -102,6 +110,14 @@ class _SuspendScript(Exception):
|
||||
"""Throw if script needs to suspend."""
|
||||
|
||||
|
||||
class AlreadyRunning(exceptions.HomeAssistantError):
|
||||
"""Throw if script already running and user wants error."""
|
||||
|
||||
|
||||
class QueueFull(exceptions.HomeAssistantError):
|
||||
"""Throw if script already running, user wants new run queued, but queue is full."""
|
||||
|
||||
|
||||
class _ScriptRunBase(ABC):
|
||||
"""Common data & methods for managing Script sequence run."""
|
||||
|
||||
@ -137,11 +153,11 @@ class _ScriptRunBase(ABC):
|
||||
await getattr(
|
||||
self, f"_async_{cv.determine_script_action(self._action)}_step"
|
||||
)()
|
||||
except Exception as err:
|
||||
if not isinstance(err, (_SuspendScript, _StopScript)) and (
|
||||
self._log_exceptions or log_exceptions
|
||||
):
|
||||
self._log_exception(err)
|
||||
except Exception as ex:
|
||||
if not isinstance(
|
||||
ex, (_SuspendScript, _StopScript, asyncio.CancelledError)
|
||||
) and (self._log_exceptions or log_exceptions):
|
||||
self._log_exception(ex)
|
||||
raise
|
||||
|
||||
@abstractmethod
|
||||
@ -166,6 +182,12 @@ class _ScriptRunBase(ABC):
|
||||
elif isinstance(exception, exceptions.ServiceNotFound):
|
||||
error_desc = "Service not found"
|
||||
|
||||
elif isinstance(exception, AlreadyRunning):
|
||||
error_desc = "Already running"
|
||||
|
||||
elif isinstance(exception, QueueFull):
|
||||
error_desc = "Run queue is full"
|
||||
|
||||
else:
|
||||
error_desc = "Unexpected error"
|
||||
level = _LOG_EXCEPTION
|
||||
@ -189,12 +211,13 @@ class _ScriptRunBase(ABC):
|
||||
template.render_complex(self._action[CONF_DELAY], self._variables)
|
||||
)
|
||||
except (exceptions.TemplateError, vol.Invalid) as ex:
|
||||
self._raise(
|
||||
self._log(
|
||||
"Error rendering %s delay template: %s",
|
||||
self._script.name,
|
||||
ex,
|
||||
exception=_StopScript,
|
||||
level=logging.ERROR,
|
||||
)
|
||||
raise _StopScript
|
||||
|
||||
self._script.last_action = self._action.get(CONF_ALIAS, f"delay {delay}")
|
||||
self._log("Executing step %s", self._script.last_action)
|
||||
@ -220,18 +243,14 @@ class _ScriptRunBase(ABC):
|
||||
self._hass, wait_template, async_script_wait, self._variables
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def _async_call_service_step(self):
|
||||
"""Call the service specified in the action."""
|
||||
|
||||
def _prep_call_service_step(self):
|
||||
self._script.last_action = self._action.get(CONF_ALIAS, "call service")
|
||||
self._log("Executing step %s", self._script.last_action)
|
||||
await service.async_call_from_config(
|
||||
self._hass,
|
||||
self._action,
|
||||
blocking=True,
|
||||
variables=self._variables,
|
||||
validate_config=False,
|
||||
context=self._context,
|
||||
)
|
||||
return async_prepare_call_from_config(self._hass, self._action, self._variables)
|
||||
|
||||
async def _async_device_step(self):
|
||||
"""Perform the device automation specified in the action."""
|
||||
@ -298,10 +317,6 @@ class _ScriptRunBase(ABC):
|
||||
def _log(self, msg, *args, level=logging.INFO):
|
||||
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
|
||||
|
||||
def _raise(self, msg, *args, exception=None):
|
||||
# pylint: disable=protected-access
|
||||
self._script._raise(msg, *args, exception=exception)
|
||||
|
||||
|
||||
class _ScriptRun(_ScriptRunBase):
|
||||
"""Manage Script sequence run."""
|
||||
@ -318,24 +333,33 @@ class _ScriptRun(_ScriptRunBase):
|
||||
self._stop = asyncio.Event()
|
||||
self._stopped = asyncio.Event()
|
||||
|
||||
async def _async_run(self, propagate_exceptions=True):
|
||||
self._log("Running script")
|
||||
def _changed(self):
|
||||
if not self._stop.is_set():
|
||||
super()._changed()
|
||||
|
||||
async def async_run(self) -> None:
|
||||
"""Run script."""
|
||||
try:
|
||||
if self._stop.is_set():
|
||||
return
|
||||
self._script.last_triggered = utcnow()
|
||||
self._changed()
|
||||
self._log("Running script")
|
||||
for self._step, self._action in enumerate(self._script.sequence):
|
||||
if self._stop.is_set():
|
||||
break
|
||||
await self._async_step(not propagate_exceptions)
|
||||
await self._async_step(log_exceptions=False)
|
||||
except _StopScript:
|
||||
pass
|
||||
except Exception: # pylint: disable=broad-except
|
||||
if propagate_exceptions:
|
||||
raise
|
||||
finally:
|
||||
if not self._stop.is_set():
|
||||
self._changed()
|
||||
self._finish()
|
||||
|
||||
def _finish(self):
|
||||
self._script._runs.remove(self) # pylint: disable=protected-access
|
||||
if not self._script.is_running:
|
||||
self._script.last_action = None
|
||||
self._script._runs.remove(self) # pylint: disable=protected-access
|
||||
self._stopped.set()
|
||||
self._changed()
|
||||
self._stopped.set()
|
||||
|
||||
async def async_stop(self) -> None:
|
||||
"""Stop script run."""
|
||||
@ -344,10 +368,13 @@ class _ScriptRun(_ScriptRunBase):
|
||||
|
||||
async def _async_delay_step(self):
|
||||
"""Handle delay."""
|
||||
timeout = self._prep_delay_step().total_seconds()
|
||||
if not self._stop.is_set():
|
||||
self._changed()
|
||||
await asyncio.wait({self._stop.wait()}, timeout=timeout)
|
||||
delay = self._prep_delay_step().total_seconds()
|
||||
self._changed()
|
||||
try:
|
||||
async with timeout(delay):
|
||||
await self._stop.wait()
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
async def _async_wait_template_step(self):
|
||||
"""Handle a wait template."""
|
||||
@ -361,21 +388,20 @@ class _ScriptRun(_ScriptRunBase):
|
||||
if not unsub:
|
||||
return
|
||||
|
||||
if not self._stop.is_set():
|
||||
self._changed()
|
||||
self._changed()
|
||||
try:
|
||||
timeout = self._action[CONF_TIMEOUT].total_seconds()
|
||||
delay = self._action[CONF_TIMEOUT].total_seconds()
|
||||
except KeyError:
|
||||
timeout = None
|
||||
delay = None
|
||||
done = asyncio.Event()
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.wait(
|
||||
async with timeout(delay):
|
||||
_, pending = await asyncio.wait(
|
||||
{self._stop.wait(), done.wait()},
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
),
|
||||
timeout,
|
||||
)
|
||||
)
|
||||
for pending_task in pending:
|
||||
pending_task.cancel()
|
||||
except asyncio.TimeoutError:
|
||||
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
|
||||
self._log(_TIMEOUT_MSG)
|
||||
@ -383,25 +409,78 @@ class _ScriptRun(_ScriptRunBase):
|
||||
finally:
|
||||
unsub()
|
||||
|
||||
async def _async_call_service_step(self):
|
||||
"""Call the service specified in the action."""
|
||||
domain, service, service_data = self._prep_call_service_step()
|
||||
|
||||
class _BackgroundScriptRun(_ScriptRun):
|
||||
"""Manage background Script sequence run."""
|
||||
# If this might start a script then disable the call timeout.
|
||||
# Otherwise use the normal service call limit.
|
||||
if domain == "script" and service != SERVICE_TURN_OFF:
|
||||
limit = None
|
||||
else:
|
||||
limit = SERVICE_CALL_LIMIT
|
||||
|
||||
coro = self._hass.services.async_call(
|
||||
domain,
|
||||
service,
|
||||
service_data,
|
||||
blocking=True,
|
||||
context=self._context,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
if limit is not None:
|
||||
# There is a call limit, so just wait for it to finish.
|
||||
await coro
|
||||
return
|
||||
|
||||
# No call limit (i.e., potentially starting one or more fully blocking scripts)
|
||||
# so watch for a stop request.
|
||||
done, pending = await asyncio.wait(
|
||||
{self._stop.wait(), coro}, return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
# Note that cancelling the service call, if it has not yet returned, will also
|
||||
# stop any non-background script runs that it may have started.
|
||||
for pending_task in pending:
|
||||
pending_task.cancel()
|
||||
# Propagate any exceptions that might have happened.
|
||||
for done_task in done:
|
||||
done_task.result()
|
||||
|
||||
|
||||
class _QueuedScriptRun(_ScriptRun):
|
||||
"""Manage queued Script sequence run."""
|
||||
|
||||
lock_acquired = False
|
||||
|
||||
async def async_run(self) -> None:
|
||||
"""Run script."""
|
||||
self._hass.async_create_task(self._async_run(False))
|
||||
# Wait for previous run, if any, to finish by attempting to acquire the script's
|
||||
# shared lock. At the same time monitor if we've been told to stop.
|
||||
lock_task = self._hass.async_create_task(
|
||||
self._script._queue_lck.acquire() # pylint: disable=protected-access
|
||||
)
|
||||
done, pending = await asyncio.wait(
|
||||
{self._stop.wait(), lock_task}, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
for pending_task in pending:
|
||||
pending_task.cancel()
|
||||
self.lock_acquired = lock_task in done
|
||||
|
||||
# If we've been told to stop, then just finish up. Otherwise, we've acquired the
|
||||
# lock so we can go ahead and start the run.
|
||||
if self._stop.is_set():
|
||||
self._finish()
|
||||
else:
|
||||
await super().async_run()
|
||||
|
||||
class _BlockingScriptRun(_ScriptRun):
|
||||
"""Manage blocking Script sequence run."""
|
||||
|
||||
async def async_run(self) -> None:
|
||||
"""Run script."""
|
||||
try:
|
||||
await asyncio.shield(self._async_run())
|
||||
except asyncio.CancelledError:
|
||||
await self.async_stop()
|
||||
raise
|
||||
def _finish(self):
|
||||
# pylint: disable=protected-access
|
||||
self._script._queue_len -= 1
|
||||
if self.lock_acquired:
|
||||
self._script._queue_lck.release()
|
||||
self.lock_acquired = False
|
||||
super()._finish()
|
||||
|
||||
|
||||
class _LegacyScriptRun(_ScriptRunBase):
|
||||
@ -445,6 +524,7 @@ class _LegacyScriptRun(_ScriptRunBase):
|
||||
|
||||
async def _async_run(self, propagate_exceptions=True):
|
||||
if self._cur == -1:
|
||||
self._script.last_triggered = utcnow()
|
||||
self._log("Running script")
|
||||
self._cur = 0
|
||||
|
||||
@ -457,7 +537,7 @@ class _LegacyScriptRun(_ScriptRunBase):
|
||||
for self._step, self._action in islice(
|
||||
enumerate(self._script.sequence), self._cur, None
|
||||
):
|
||||
await self._async_step(not propagate_exceptions)
|
||||
await self._async_step(log_exceptions=not propagate_exceptions)
|
||||
except _StopScript:
|
||||
pass
|
||||
except _SuspendScript:
|
||||
@ -469,11 +549,12 @@ class _LegacyScriptRun(_ScriptRunBase):
|
||||
if propagate_exceptions:
|
||||
raise
|
||||
finally:
|
||||
if self._cur != -1:
|
||||
self._changed()
|
||||
_cur_was = self._cur
|
||||
if not suspended:
|
||||
self._script.last_action = None
|
||||
await self.async_stop()
|
||||
if _cur_was != -1:
|
||||
self._changed()
|
||||
|
||||
async def async_stop(self) -> None:
|
||||
"""Stop script run."""
|
||||
@ -512,9 +593,9 @@ class _LegacyScriptRun(_ScriptRunBase):
|
||||
|
||||
@callback
|
||||
def async_script_timeout(now):
|
||||
"""Call after timeout is retrieve."""
|
||||
"""Call after timeout has expired."""
|
||||
with suppress(ValueError):
|
||||
self._async_listener.remove(unsub)
|
||||
self._async_listener.remove(unsub_timeout)
|
||||
|
||||
# Check if we want to continue to execute
|
||||
# the script after the timeout
|
||||
@ -530,13 +611,19 @@ class _LegacyScriptRun(_ScriptRunBase):
|
||||
self._async_listener.append(unsub_wait)
|
||||
|
||||
if CONF_TIMEOUT in self._action:
|
||||
unsub = async_track_point_in_utc_time(
|
||||
unsub_timeout = async_track_point_in_utc_time(
|
||||
self._hass, async_script_timeout, utcnow() + self._action[CONF_TIMEOUT]
|
||||
)
|
||||
self._async_listener.append(unsub)
|
||||
self._async_listener.append(unsub_timeout)
|
||||
|
||||
raise _SuspendScript
|
||||
|
||||
async def _async_call_service_step(self):
|
||||
"""Call the service specified in the action."""
|
||||
await self._hass.services.async_call(
|
||||
*self._prep_call_service_step(), blocking=True, context=self._context
|
||||
)
|
||||
|
||||
def _async_remove_listener(self):
|
||||
"""Remove listeners, if any."""
|
||||
for unsub in self._async_listener:
|
||||
@ -553,47 +640,60 @@ class Script:
|
||||
sequence: Sequence[Dict[str, Any]],
|
||||
name: Optional[str] = None,
|
||||
change_listener: Optional[Callable[..., Any]] = None,
|
||||
if_running: Optional[str] = None,
|
||||
run_mode: Optional[str] = None,
|
||||
script_mode: str = DEFAULT_SCRIPT_MODE,
|
||||
queue_max: int = DEFAULT_QUEUE_MAX,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
log_exceptions: bool = True,
|
||||
) -> None:
|
||||
"""Initialize the script."""
|
||||
self._logger = logger or logging.getLogger(__name__)
|
||||
self._hass = hass
|
||||
self.sequence = sequence
|
||||
template.attach(hass, self.sequence)
|
||||
self.name = name
|
||||
self._change_listener = change_listener
|
||||
self.change_listener = change_listener
|
||||
self._script_mode = script_mode
|
||||
if logger:
|
||||
self._logger = logger
|
||||
else:
|
||||
logger_name = __name__
|
||||
if name:
|
||||
logger_name = ".".join([logger_name, slugify(name)])
|
||||
self._logger = logging.getLogger(logger_name)
|
||||
self._log_exceptions = log_exceptions
|
||||
|
||||
self.last_action = None
|
||||
self.last_triggered: Optional[datetime] = None
|
||||
self.can_cancel = any(
|
||||
self.can_cancel = not self.is_legacy or any(
|
||||
CONF_DELAY in action or CONF_WAIT_TEMPLATE in action
|
||||
for action in self.sequence
|
||||
)
|
||||
if not if_running and not run_mode:
|
||||
self._if_running = IF_RUNNING_PARALLEL
|
||||
self._run_mode = RUN_MODE_LEGACY
|
||||
elif if_running and run_mode == RUN_MODE_LEGACY:
|
||||
self._raise('Cannot use if_running if run_mode is "legacy"')
|
||||
else:
|
||||
self._if_running = if_running or IF_RUNNING_CHOICES[0]
|
||||
self._run_mode = run_mode or RUN_MODE_CHOICES[0]
|
||||
|
||||
self._runs: List[_ScriptRunBase] = []
|
||||
self._log_exceptions = log_exceptions
|
||||
if script_mode == SCRIPT_MODE_QUEUE:
|
||||
self._queue_max = queue_max
|
||||
self._queue_len = 0
|
||||
self._queue_lck = asyncio.Lock()
|
||||
self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {}
|
||||
self._referenced_entities: Optional[Set[str]] = None
|
||||
self._referenced_devices: Optional[Set[str]] = None
|
||||
|
||||
def _changed(self):
|
||||
if self._change_listener:
|
||||
self._hass.async_add_job(self._change_listener)
|
||||
if self.change_listener:
|
||||
if is_callback(self.change_listener):
|
||||
self.change_listener()
|
||||
else:
|
||||
self._hass.async_add_job(self.change_listener)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Return true if script is on."""
|
||||
return len(self._runs) > 0
|
||||
|
||||
@property
|
||||
def is_legacy(self) -> bool:
|
||||
"""Return if using legacy mode."""
|
||||
return self._script_mode == SCRIPT_MODE_LEGACY
|
||||
|
||||
@property
|
||||
def referenced_devices(self):
|
||||
"""Return a set of referenced devices."""
|
||||
@ -626,7 +726,7 @@ class Script:
|
||||
action = cv.determine_script_action(step)
|
||||
|
||||
if action == cv.SCRIPT_ACTION_CALL_SERVICE:
|
||||
data = step.get(service.CONF_SERVICE_DATA)
|
||||
data = step.get(CONF_SERVICE_DATA)
|
||||
if not data:
|
||||
continue
|
||||
|
||||
@ -661,18 +761,26 @@ class Script:
|
||||
) -> None:
|
||||
"""Run script."""
|
||||
if self.is_running:
|
||||
if self._if_running == IF_RUNNING_IGNORE:
|
||||
if self._script_mode == SCRIPT_MODE_IGNORE:
|
||||
self._log("Skipping script")
|
||||
return
|
||||
|
||||
if self._if_running == IF_RUNNING_ERROR:
|
||||
self._raise("Already running")
|
||||
if self._if_running == IF_RUNNING_RESTART:
|
||||
self._log("Restarting script")
|
||||
await self.async_stop()
|
||||
if self._script_mode == SCRIPT_MODE_ERROR:
|
||||
raise AlreadyRunning
|
||||
|
||||
self.last_triggered = utcnow()
|
||||
if self._run_mode == RUN_MODE_LEGACY:
|
||||
if self._script_mode == SCRIPT_MODE_RESTART:
|
||||
self._log("Restarting script")
|
||||
await self.async_stop(update_state=False)
|
||||
elif self._script_mode == SCRIPT_MODE_QUEUE:
|
||||
self._log(
|
||||
"Queueing script behind %i run%s",
|
||||
self._queue_len,
|
||||
"s" if self._queue_len > 1 else "",
|
||||
)
|
||||
if self._queue_len >= self._queue_max:
|
||||
raise QueueFull
|
||||
|
||||
if self.is_legacy:
|
||||
if self._runs:
|
||||
shared = cast(Optional[_LegacyScriptRun], self._runs[0])
|
||||
else:
|
||||
@ -681,23 +789,31 @@ class Script:
|
||||
self._hass, self, variables, context, self._log_exceptions, shared
|
||||
)
|
||||
else:
|
||||
if self._run_mode == RUN_MODE_BACKGROUND:
|
||||
run = _BackgroundScriptRun(
|
||||
self._hass, self, variables, context, self._log_exceptions
|
||||
)
|
||||
if self._script_mode != SCRIPT_MODE_QUEUE:
|
||||
cls = _ScriptRun
|
||||
else:
|
||||
run = _BlockingScriptRun(
|
||||
self._hass, self, variables, context, self._log_exceptions
|
||||
)
|
||||
cls = _QueuedScriptRun
|
||||
self._queue_len += 1
|
||||
run = cls(self._hass, self, variables, context, self._log_exceptions)
|
||||
self._runs.append(run)
|
||||
await run.async_run()
|
||||
|
||||
async def async_stop(self) -> None:
|
||||
try:
|
||||
if self.is_legacy:
|
||||
await run.async_run()
|
||||
else:
|
||||
await asyncio.shield(run.async_run())
|
||||
except asyncio.CancelledError:
|
||||
await run.async_stop()
|
||||
self._changed()
|
||||
raise
|
||||
|
||||
async def async_stop(self, update_state: bool = True) -> None:
|
||||
"""Stop running script."""
|
||||
if not self.is_running:
|
||||
return
|
||||
await asyncio.shield(asyncio.gather(*(run.async_stop() for run in self._runs)))
|
||||
self._changed()
|
||||
if update_state:
|
||||
self._changed()
|
||||
|
||||
def _log(self, msg, *args, level=logging.INFO):
|
||||
if self.name:
|
||||
@ -706,9 +822,3 @@ class Script:
|
||||
self._logger.exception(msg, *args)
|
||||
else:
|
||||
self._logger.log(level, msg, *args)
|
||||
|
||||
def _raise(self, msg, *args, exception=None):
|
||||
if not exception:
|
||||
exception = exceptions.HomeAssistantError
|
||||
self._log(msg, *args, level=logging.ERROR)
|
||||
raise exception(msg % args)
|
||||
|
@ -56,12 +56,27 @@ async def async_call_from_config(
|
||||
hass, config, blocking=False, variables=None, validate_config=True, context=None
|
||||
):
|
||||
"""Call a service based on a config hash."""
|
||||
try:
|
||||
parms = async_prepare_call_from_config(hass, config, variables, validate_config)
|
||||
except HomeAssistantError as ex:
|
||||
if blocking:
|
||||
raise
|
||||
_LOGGER.error(ex)
|
||||
else:
|
||||
await hass.services.async_call(*parms, blocking, context)
|
||||
|
||||
|
||||
@ha.callback
|
||||
@bind_hass
|
||||
def async_prepare_call_from_config(hass, config, variables=None, validate_config=False):
|
||||
"""Prepare to call a service based on a config hash."""
|
||||
if validate_config:
|
||||
try:
|
||||
config = cv.SERVICE_SCHEMA(config)
|
||||
except vol.Invalid as ex:
|
||||
_LOGGER.error("Invalid config for calling service: %s", ex)
|
||||
return
|
||||
raise HomeAssistantError(
|
||||
f"Invalid config for calling service: {ex}"
|
||||
) from ex
|
||||
|
||||
if CONF_SERVICE in config:
|
||||
domain_service = config[CONF_SERVICE]
|
||||
@ -71,17 +86,15 @@ async def async_call_from_config(
|
||||
domain_service = config[CONF_SERVICE_TEMPLATE].async_render(variables)
|
||||
domain_service = cv.service(domain_service)
|
||||
except TemplateError as ex:
|
||||
if blocking:
|
||||
raise
|
||||
_LOGGER.error("Error rendering service name template: %s", ex)
|
||||
return
|
||||
except vol.Invalid:
|
||||
if blocking:
|
||||
raise
|
||||
_LOGGER.error("Template rendered invalid service: %s", domain_service)
|
||||
return
|
||||
raise HomeAssistantError(
|
||||
f"Error rendering service name template: {ex}"
|
||||
) from ex
|
||||
except vol.Invalid as ex:
|
||||
raise HomeAssistantError(
|
||||
f"Template rendered invalid service: {domain_service}"
|
||||
) from ex
|
||||
|
||||
domain, service_name = domain_service.split(".", 1)
|
||||
domain, service = domain_service.split(".", 1)
|
||||
service_data = dict(config.get(CONF_SERVICE_DATA, {}))
|
||||
|
||||
if CONF_SERVICE_DATA_TEMPLATE in config:
|
||||
@ -91,15 +104,12 @@ async def async_call_from_config(
|
||||
template.render_complex(config[CONF_SERVICE_DATA_TEMPLATE], variables)
|
||||
)
|
||||
except TemplateError as ex:
|
||||
_LOGGER.error("Error rendering data template: %s", ex)
|
||||
return
|
||||
raise HomeAssistantError(f"Error rendering data template: {ex}") from ex
|
||||
|
||||
if CONF_SERVICE_ENTITY_ID in config:
|
||||
service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
|
||||
|
||||
await hass.services.async_call(
|
||||
domain, service_name, service_data, blocking=blocking, context=context
|
||||
)
|
||||
return domain, service, service_data
|
||||
|
||||
|
||||
@bind_hass
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user