From 5f5cb8bea8b92ebb0f30ed30b680b82d5f441250 Mon Sep 17 00:00:00 2001 From: Phil Bruckner Date: Wed, 11 Mar 2020 18:34:50 -0500 Subject: [PATCH] Add support for simultaneous runs of Script helper - Part 2 (#32442) * Add limit parameter to service call methods * Break out prep part of async_call_from_config for use elsewhere * Minor cleanup * Fix improper use of asyncio.wait * Fix state update Call change listener immediately if its a callback * Fix exception handling and logging * Merge Script helper if_running/run_mode parameters into script_mode - Remove background/blocking _ScriptRun subclasses which are no longer needed. * Add queued script mode * Disable timeout when making fully blocking script call * Don't call change listener when restarting script This makes restart mode behavior consistent with parallel & queue modes. * Changes per review - Call all script services (except script.turn_off) with no time limit. - Fix handling of lock in _QueuedScriptRun and add comments to make it clearer how this code works. * Changes per review 2 - Move cancel shielding "up" from _ScriptRun.async_run to Script.async_run (and apply to new style scripts only.) This makes sure Script class also properly handles cancellation which it wasn't doing before. - In _ScriptRun._async_call_service_step, instead of using script.turn_off service, just cancel service call and let it handle the cancellation accordingly. * Fix bugs - Add missing call to change listener in Script.async_run in cancelled path. - Cancel service task if ServiceRegistry.async_call cancelled. * Revert last changes to ServiceRegistry.async_call * Minor Script helper fixes & test improvements - Don't log asyncio.CancelledError exceptions. - Make change_listener a public attribute. - Test overhaul - Parametrize tests. - Use common test functions. - Mock timeout so tests don't need to wait for real time to elapse. - Add common function for waiting for script action step. --- homeassistant/core.py | 37 +- homeassistant/helpers/script.py | 362 ++++-- homeassistant/helpers/service.py | 44 +- tests/helpers/test_script.py | 1949 ++++++++++-------------------- 4 files changed, 948 insertions(+), 1444 deletions(-) diff --git a/homeassistant/core.py b/homeassistant/core.py index a1d9a83d1ad..afd1e4daa1a 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -93,7 +93,7 @@ SOURCE_DISCOVERED = "discovered" SOURCE_STORAGE = "storage" SOURCE_YAML = "yaml" -# How long to wait till things that run on startup have to finish. +# How long to wait until things that run on startup have to finish. TIMEOUT_EVENT_START = 15 _LOGGER = logging.getLogger(__name__) @@ -249,7 +249,7 @@ class HomeAssistant: try: # Only block for EVENT_HOMEASSISTANT_START listener self.async_stop_track_tasks() - with timeout(TIMEOUT_EVENT_START): + async with timeout(TIMEOUT_EVENT_START): await self.async_block_till_done() except asyncio.TimeoutError: _LOGGER.warning( @@ -374,13 +374,13 @@ class HomeAssistant: self.async_add_job(target, *args) def block_till_done(self) -> None: - """Block till all pending work is done.""" + """Block until all pending work is done.""" asyncio.run_coroutine_threadsafe( self.async_block_till_done(), self.loop ).result() async def async_block_till_done(self) -> None: - """Block till all pending work is done.""" + """Block until all pending work is done.""" # To flush out any call_soon_threadsafe await asyncio.sleep(0) @@ -1150,25 +1150,15 @@ class ServiceRegistry: service_data: Optional[Dict] = None, blocking: bool = False, context: Optional[Context] = None, + limit: Optional[float] = SERVICE_CALL_LIMIT, ) -> Optional[bool]: """ Call a service. - Specify blocking=True to wait till service is executed. - Waits a maximum of SERVICE_CALL_LIMIT. - - If blocking = True, will return boolean if service executed - successfully within SERVICE_CALL_LIMIT. - - This method will fire an event to call the service. - This event will be picked up by this ServiceRegistry and any - other ServiceRegistry that is listening on the EventBus. - - Because the service is sent as an event you are not allowed to use - the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data. + See description of async_call for details. """ return asyncio.run_coroutine_threadsafe( - self.async_call(domain, service, service_data, blocking, context), + self.async_call(domain, service, service_data, blocking, context, limit), self._hass.loop, ).result() @@ -1179,19 +1169,18 @@ class ServiceRegistry: service_data: Optional[Dict] = None, blocking: bool = False, context: Optional[Context] = None, + limit: Optional[float] = SERVICE_CALL_LIMIT, ) -> Optional[bool]: """ Call a service. - Specify blocking=True to wait till service is executed. - Waits a maximum of SERVICE_CALL_LIMIT. + Specify blocking=True to wait until service is executed. + Waits a maximum of limit, which may be None for no timeout. If blocking = True, will return boolean if service executed - successfully within SERVICE_CALL_LIMIT. + successfully within limit. - This method will fire an event to call the service. - This event will be picked up by this ServiceRegistry and any - other ServiceRegistry that is listening on the EventBus. + This method will fire an event to indicate the service has been called. Because the service is sent as an event you are not allowed to use the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data. @@ -1230,7 +1219,7 @@ class ServiceRegistry: return None try: - with timeout(SERVICE_CALL_LIMIT): + async with timeout(limit): await asyncio.shield(self._execute_service(handler, service_call)) return True except asyncio.TimeoutError: diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 937a675aada..7d1088eebe4 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -7,6 +7,7 @@ from itertools import islice import logging from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast +from async_timeout import timeout import voluptuous as vol from homeassistant import exceptions @@ -14,6 +15,7 @@ import homeassistant.components.device_automation as device_automation import homeassistant.components.scene as scene from homeassistant.const import ( ATTR_ENTITY_ID, + CONF_ALIAS, CONF_CONDITION, CONF_CONTINUE_ON_TIMEOUT, CONF_DELAY, @@ -25,47 +27,53 @@ from homeassistant.const import ( CONF_SCENE, CONF_TIMEOUT, CONF_WAIT_TEMPLATE, + SERVICE_TURN_OFF, SERVICE_TURN_ON, ) -from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback +from homeassistant.core import ( + CALLBACK_TYPE, + SERVICE_CALL_LIMIT, + Context, + HomeAssistant, + callback, + is_callback, +) from homeassistant.helpers import ( condition, config_validation as cv, - service, template as template, ) from homeassistant.helpers.event import ( async_track_point_in_utc_time, async_track_template, ) +from homeassistant.helpers.service import ( + CONF_SERVICE_DATA, + async_prepare_call_from_config, +) from homeassistant.helpers.typing import ConfigType +from homeassistant.util import slugify from homeassistant.util.dt import utcnow # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs -CONF_ALIAS = "alias" - -IF_RUNNING_ERROR = "error" -IF_RUNNING_IGNORE = "ignore" -IF_RUNNING_PARALLEL = "parallel" -IF_RUNNING_RESTART = "restart" -# First choice is default -IF_RUNNING_CHOICES = [ - IF_RUNNING_PARALLEL, - IF_RUNNING_ERROR, - IF_RUNNING_IGNORE, - IF_RUNNING_RESTART, +SCRIPT_MODE_ERROR = "error" +SCRIPT_MODE_IGNORE = "ignore" +SCRIPT_MODE_LEGACY = "legacy" +SCRIPT_MODE_PARALLEL = "parallel" +SCRIPT_MODE_QUEUE = "queue" +SCRIPT_MODE_RESTART = "restart" +SCRIPT_MODE_CHOICES = [ + SCRIPT_MODE_ERROR, + SCRIPT_MODE_IGNORE, + SCRIPT_MODE_LEGACY, + SCRIPT_MODE_PARALLEL, + SCRIPT_MODE_QUEUE, + SCRIPT_MODE_RESTART, ] +DEFAULT_SCRIPT_MODE = SCRIPT_MODE_LEGACY -RUN_MODE_BACKGROUND = "background" -RUN_MODE_BLOCKING = "blocking" -RUN_MODE_LEGACY = "legacy" -# First choice is default -RUN_MODE_CHOICES = [ - RUN_MODE_BLOCKING, - RUN_MODE_BACKGROUND, - RUN_MODE_LEGACY, -] +DEFAULT_QUEUE_MAX = 10 _LOG_EXCEPTION = logging.ERROR + 1 _TIMEOUT_MSG = "Timeout reached, abort script." @@ -102,6 +110,14 @@ class _SuspendScript(Exception): """Throw if script needs to suspend.""" +class AlreadyRunning(exceptions.HomeAssistantError): + """Throw if script already running and user wants error.""" + + +class QueueFull(exceptions.HomeAssistantError): + """Throw if script already running, user wants new run queued, but queue is full.""" + + class _ScriptRunBase(ABC): """Common data & methods for managing Script sequence run.""" @@ -137,11 +153,11 @@ class _ScriptRunBase(ABC): await getattr( self, f"_async_{cv.determine_script_action(self._action)}_step" )() - except Exception as err: - if not isinstance(err, (_SuspendScript, _StopScript)) and ( - self._log_exceptions or log_exceptions - ): - self._log_exception(err) + except Exception as ex: + if not isinstance( + ex, (_SuspendScript, _StopScript, asyncio.CancelledError) + ) and (self._log_exceptions or log_exceptions): + self._log_exception(ex) raise @abstractmethod @@ -166,6 +182,12 @@ class _ScriptRunBase(ABC): elif isinstance(exception, exceptions.ServiceNotFound): error_desc = "Service not found" + elif isinstance(exception, AlreadyRunning): + error_desc = "Already running" + + elif isinstance(exception, QueueFull): + error_desc = "Run queue is full" + else: error_desc = "Unexpected error" level = _LOG_EXCEPTION @@ -189,12 +211,13 @@ class _ScriptRunBase(ABC): template.render_complex(self._action[CONF_DELAY], self._variables) ) except (exceptions.TemplateError, vol.Invalid) as ex: - self._raise( + self._log( "Error rendering %s delay template: %s", self._script.name, ex, - exception=_StopScript, + level=logging.ERROR, ) + raise _StopScript self._script.last_action = self._action.get(CONF_ALIAS, f"delay {delay}") self._log("Executing step %s", self._script.last_action) @@ -220,18 +243,14 @@ class _ScriptRunBase(ABC): self._hass, wait_template, async_script_wait, self._variables ) + @abstractmethod async def _async_call_service_step(self): """Call the service specified in the action.""" + + def _prep_call_service_step(self): self._script.last_action = self._action.get(CONF_ALIAS, "call service") self._log("Executing step %s", self._script.last_action) - await service.async_call_from_config( - self._hass, - self._action, - blocking=True, - variables=self._variables, - validate_config=False, - context=self._context, - ) + return async_prepare_call_from_config(self._hass, self._action, self._variables) async def _async_device_step(self): """Perform the device automation specified in the action.""" @@ -298,10 +317,6 @@ class _ScriptRunBase(ABC): def _log(self, msg, *args, level=logging.INFO): self._script._log(msg, *args, level=level) # pylint: disable=protected-access - def _raise(self, msg, *args, exception=None): - # pylint: disable=protected-access - self._script._raise(msg, *args, exception=exception) - class _ScriptRun(_ScriptRunBase): """Manage Script sequence run.""" @@ -318,24 +333,33 @@ class _ScriptRun(_ScriptRunBase): self._stop = asyncio.Event() self._stopped = asyncio.Event() - async def _async_run(self, propagate_exceptions=True): - self._log("Running script") + def _changed(self): + if not self._stop.is_set(): + super()._changed() + + async def async_run(self) -> None: + """Run script.""" try: + if self._stop.is_set(): + return + self._script.last_triggered = utcnow() + self._changed() + self._log("Running script") for self._step, self._action in enumerate(self._script.sequence): if self._stop.is_set(): break - await self._async_step(not propagate_exceptions) + await self._async_step(log_exceptions=False) except _StopScript: pass - except Exception: # pylint: disable=broad-except - if propagate_exceptions: - raise finally: - if not self._stop.is_set(): - self._changed() + self._finish() + + def _finish(self): + self._script._runs.remove(self) # pylint: disable=protected-access + if not self._script.is_running: self._script.last_action = None - self._script._runs.remove(self) # pylint: disable=protected-access - self._stopped.set() + self._changed() + self._stopped.set() async def async_stop(self) -> None: """Stop script run.""" @@ -344,10 +368,13 @@ class _ScriptRun(_ScriptRunBase): async def _async_delay_step(self): """Handle delay.""" - timeout = self._prep_delay_step().total_seconds() - if not self._stop.is_set(): - self._changed() - await asyncio.wait({self._stop.wait()}, timeout=timeout) + delay = self._prep_delay_step().total_seconds() + self._changed() + try: + async with timeout(delay): + await self._stop.wait() + except asyncio.TimeoutError: + pass async def _async_wait_template_step(self): """Handle a wait template.""" @@ -361,21 +388,20 @@ class _ScriptRun(_ScriptRunBase): if not unsub: return - if not self._stop.is_set(): - self._changed() + self._changed() try: - timeout = self._action[CONF_TIMEOUT].total_seconds() + delay = self._action[CONF_TIMEOUT].total_seconds() except KeyError: - timeout = None + delay = None done = asyncio.Event() try: - await asyncio.wait_for( - asyncio.wait( + async with timeout(delay): + _, pending = await asyncio.wait( {self._stop.wait(), done.wait()}, return_when=asyncio.FIRST_COMPLETED, - ), - timeout, - ) + ) + for pending_task in pending: + pending_task.cancel() except asyncio.TimeoutError: if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True): self._log(_TIMEOUT_MSG) @@ -383,25 +409,78 @@ class _ScriptRun(_ScriptRunBase): finally: unsub() + async def _async_call_service_step(self): + """Call the service specified in the action.""" + domain, service, service_data = self._prep_call_service_step() -class _BackgroundScriptRun(_ScriptRun): - """Manage background Script sequence run.""" + # If this might start a script then disable the call timeout. + # Otherwise use the normal service call limit. + if domain == "script" and service != SERVICE_TURN_OFF: + limit = None + else: + limit = SERVICE_CALL_LIMIT + + coro = self._hass.services.async_call( + domain, + service, + service_data, + blocking=True, + context=self._context, + limit=limit, + ) + + if limit is not None: + # There is a call limit, so just wait for it to finish. + await coro + return + + # No call limit (i.e., potentially starting one or more fully blocking scripts) + # so watch for a stop request. + done, pending = await asyncio.wait( + {self._stop.wait(), coro}, return_when=asyncio.FIRST_COMPLETED, + ) + # Note that cancelling the service call, if it has not yet returned, will also + # stop any non-background script runs that it may have started. + for pending_task in pending: + pending_task.cancel() + # Propagate any exceptions that might have happened. + for done_task in done: + done_task.result() + + +class _QueuedScriptRun(_ScriptRun): + """Manage queued Script sequence run.""" + + lock_acquired = False async def async_run(self) -> None: """Run script.""" - self._hass.async_create_task(self._async_run(False)) + # Wait for previous run, if any, to finish by attempting to acquire the script's + # shared lock. At the same time monitor if we've been told to stop. + lock_task = self._hass.async_create_task( + self._script._queue_lck.acquire() # pylint: disable=protected-access + ) + done, pending = await asyncio.wait( + {self._stop.wait(), lock_task}, return_when=asyncio.FIRST_COMPLETED + ) + for pending_task in pending: + pending_task.cancel() + self.lock_acquired = lock_task in done + # If we've been told to stop, then just finish up. Otherwise, we've acquired the + # lock so we can go ahead and start the run. + if self._stop.is_set(): + self._finish() + else: + await super().async_run() -class _BlockingScriptRun(_ScriptRun): - """Manage blocking Script sequence run.""" - - async def async_run(self) -> None: - """Run script.""" - try: - await asyncio.shield(self._async_run()) - except asyncio.CancelledError: - await self.async_stop() - raise + def _finish(self): + # pylint: disable=protected-access + self._script._queue_len -= 1 + if self.lock_acquired: + self._script._queue_lck.release() + self.lock_acquired = False + super()._finish() class _LegacyScriptRun(_ScriptRunBase): @@ -445,6 +524,7 @@ class _LegacyScriptRun(_ScriptRunBase): async def _async_run(self, propagate_exceptions=True): if self._cur == -1: + self._script.last_triggered = utcnow() self._log("Running script") self._cur = 0 @@ -457,7 +537,7 @@ class _LegacyScriptRun(_ScriptRunBase): for self._step, self._action in islice( enumerate(self._script.sequence), self._cur, None ): - await self._async_step(not propagate_exceptions) + await self._async_step(log_exceptions=not propagate_exceptions) except _StopScript: pass except _SuspendScript: @@ -469,11 +549,12 @@ class _LegacyScriptRun(_ScriptRunBase): if propagate_exceptions: raise finally: - if self._cur != -1: - self._changed() + _cur_was = self._cur if not suspended: self._script.last_action = None await self.async_stop() + if _cur_was != -1: + self._changed() async def async_stop(self) -> None: """Stop script run.""" @@ -512,9 +593,9 @@ class _LegacyScriptRun(_ScriptRunBase): @callback def async_script_timeout(now): - """Call after timeout is retrieve.""" + """Call after timeout has expired.""" with suppress(ValueError): - self._async_listener.remove(unsub) + self._async_listener.remove(unsub_timeout) # Check if we want to continue to execute # the script after the timeout @@ -530,13 +611,19 @@ class _LegacyScriptRun(_ScriptRunBase): self._async_listener.append(unsub_wait) if CONF_TIMEOUT in self._action: - unsub = async_track_point_in_utc_time( + unsub_timeout = async_track_point_in_utc_time( self._hass, async_script_timeout, utcnow() + self._action[CONF_TIMEOUT] ) - self._async_listener.append(unsub) + self._async_listener.append(unsub_timeout) raise _SuspendScript + async def _async_call_service_step(self): + """Call the service specified in the action.""" + await self._hass.services.async_call( + *self._prep_call_service_step(), blocking=True, context=self._context + ) + def _async_remove_listener(self): """Remove listeners, if any.""" for unsub in self._async_listener: @@ -553,47 +640,60 @@ class Script: sequence: Sequence[Dict[str, Any]], name: Optional[str] = None, change_listener: Optional[Callable[..., Any]] = None, - if_running: Optional[str] = None, - run_mode: Optional[str] = None, + script_mode: str = DEFAULT_SCRIPT_MODE, + queue_max: int = DEFAULT_QUEUE_MAX, logger: Optional[logging.Logger] = None, log_exceptions: bool = True, ) -> None: """Initialize the script.""" - self._logger = logger or logging.getLogger(__name__) self._hass = hass self.sequence = sequence template.attach(hass, self.sequence) self.name = name - self._change_listener = change_listener + self.change_listener = change_listener + self._script_mode = script_mode + if logger: + self._logger = logger + else: + logger_name = __name__ + if name: + logger_name = ".".join([logger_name, slugify(name)]) + self._logger = logging.getLogger(logger_name) + self._log_exceptions = log_exceptions + self.last_action = None self.last_triggered: Optional[datetime] = None - self.can_cancel = any( + self.can_cancel = not self.is_legacy or any( CONF_DELAY in action or CONF_WAIT_TEMPLATE in action for action in self.sequence ) - if not if_running and not run_mode: - self._if_running = IF_RUNNING_PARALLEL - self._run_mode = RUN_MODE_LEGACY - elif if_running and run_mode == RUN_MODE_LEGACY: - self._raise('Cannot use if_running if run_mode is "legacy"') - else: - self._if_running = if_running or IF_RUNNING_CHOICES[0] - self._run_mode = run_mode or RUN_MODE_CHOICES[0] + self._runs: List[_ScriptRunBase] = [] - self._log_exceptions = log_exceptions + if script_mode == SCRIPT_MODE_QUEUE: + self._queue_max = queue_max + self._queue_len = 0 + self._queue_lck = asyncio.Lock() self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {} self._referenced_entities: Optional[Set[str]] = None self._referenced_devices: Optional[Set[str]] = None def _changed(self): - if self._change_listener: - self._hass.async_add_job(self._change_listener) + if self.change_listener: + if is_callback(self.change_listener): + self.change_listener() + else: + self._hass.async_add_job(self.change_listener) @property def is_running(self) -> bool: """Return true if script is on.""" return len(self._runs) > 0 + @property + def is_legacy(self) -> bool: + """Return if using legacy mode.""" + return self._script_mode == SCRIPT_MODE_LEGACY + @property def referenced_devices(self): """Return a set of referenced devices.""" @@ -626,7 +726,7 @@ class Script: action = cv.determine_script_action(step) if action == cv.SCRIPT_ACTION_CALL_SERVICE: - data = step.get(service.CONF_SERVICE_DATA) + data = step.get(CONF_SERVICE_DATA) if not data: continue @@ -661,18 +761,26 @@ class Script: ) -> None: """Run script.""" if self.is_running: - if self._if_running == IF_RUNNING_IGNORE: + if self._script_mode == SCRIPT_MODE_IGNORE: self._log("Skipping script") return - if self._if_running == IF_RUNNING_ERROR: - self._raise("Already running") - if self._if_running == IF_RUNNING_RESTART: - self._log("Restarting script") - await self.async_stop() + if self._script_mode == SCRIPT_MODE_ERROR: + raise AlreadyRunning - self.last_triggered = utcnow() - if self._run_mode == RUN_MODE_LEGACY: + if self._script_mode == SCRIPT_MODE_RESTART: + self._log("Restarting script") + await self.async_stop(update_state=False) + elif self._script_mode == SCRIPT_MODE_QUEUE: + self._log( + "Queueing script behind %i run%s", + self._queue_len, + "s" if self._queue_len > 1 else "", + ) + if self._queue_len >= self._queue_max: + raise QueueFull + + if self.is_legacy: if self._runs: shared = cast(Optional[_LegacyScriptRun], self._runs[0]) else: @@ -681,23 +789,31 @@ class Script: self._hass, self, variables, context, self._log_exceptions, shared ) else: - if self._run_mode == RUN_MODE_BACKGROUND: - run = _BackgroundScriptRun( - self._hass, self, variables, context, self._log_exceptions - ) + if self._script_mode != SCRIPT_MODE_QUEUE: + cls = _ScriptRun else: - run = _BlockingScriptRun( - self._hass, self, variables, context, self._log_exceptions - ) + cls = _QueuedScriptRun + self._queue_len += 1 + run = cls(self._hass, self, variables, context, self._log_exceptions) self._runs.append(run) - await run.async_run() - async def async_stop(self) -> None: + try: + if self.is_legacy: + await run.async_run() + else: + await asyncio.shield(run.async_run()) + except asyncio.CancelledError: + await run.async_stop() + self._changed() + raise + + async def async_stop(self, update_state: bool = True) -> None: """Stop running script.""" if not self.is_running: return await asyncio.shield(asyncio.gather(*(run.async_stop() for run in self._runs))) - self._changed() + if update_state: + self._changed() def _log(self, msg, *args, level=logging.INFO): if self.name: @@ -706,9 +822,3 @@ class Script: self._logger.exception(msg, *args) else: self._logger.log(level, msg, *args) - - def _raise(self, msg, *args, exception=None): - if not exception: - exception = exceptions.HomeAssistantError - self._log(msg, *args, level=logging.ERROR) - raise exception(msg % args) diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 578d5368314..7a352b4e8d1 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -56,12 +56,27 @@ async def async_call_from_config( hass, config, blocking=False, variables=None, validate_config=True, context=None ): """Call a service based on a config hash.""" + try: + parms = async_prepare_call_from_config(hass, config, variables, validate_config) + except HomeAssistantError as ex: + if blocking: + raise + _LOGGER.error(ex) + else: + await hass.services.async_call(*parms, blocking, context) + + +@ha.callback +@bind_hass +def async_prepare_call_from_config(hass, config, variables=None, validate_config=False): + """Prepare to call a service based on a config hash.""" if validate_config: try: config = cv.SERVICE_SCHEMA(config) except vol.Invalid as ex: - _LOGGER.error("Invalid config for calling service: %s", ex) - return + raise HomeAssistantError( + f"Invalid config for calling service: {ex}" + ) from ex if CONF_SERVICE in config: domain_service = config[CONF_SERVICE] @@ -71,17 +86,15 @@ async def async_call_from_config( domain_service = config[CONF_SERVICE_TEMPLATE].async_render(variables) domain_service = cv.service(domain_service) except TemplateError as ex: - if blocking: - raise - _LOGGER.error("Error rendering service name template: %s", ex) - return - except vol.Invalid: - if blocking: - raise - _LOGGER.error("Template rendered invalid service: %s", domain_service) - return + raise HomeAssistantError( + f"Error rendering service name template: {ex}" + ) from ex + except vol.Invalid as ex: + raise HomeAssistantError( + f"Template rendered invalid service: {domain_service}" + ) from ex - domain, service_name = domain_service.split(".", 1) + domain, service = domain_service.split(".", 1) service_data = dict(config.get(CONF_SERVICE_DATA, {})) if CONF_SERVICE_DATA_TEMPLATE in config: @@ -91,15 +104,12 @@ async def async_call_from_config( template.render_complex(config[CONF_SERVICE_DATA_TEMPLATE], variables) ) except TemplateError as ex: - _LOGGER.error("Error rendering data template: %s", ex) - return + raise HomeAssistantError(f"Error rendering data template: {ex}") from ex if CONF_SERVICE_ENTITY_ID in config: service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID] - await hass.services.async_call( - domain, service_name, service_data, blocking=blocking, context=context - ) + return domain, service, service_data @bind_hass diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 443b131b2aa..eb1d5e15020 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -1,6 +1,7 @@ """The tests for the Script component.""" # pylint: disable=protected-access import asyncio +from contextlib import contextmanager from datetime import timedelta import logging from unittest import mock @@ -15,63 +16,106 @@ import homeassistant.components.scene as scene from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON from homeassistant.core import Context, callback from homeassistant.helpers import config_validation as cv, script +from homeassistant.helpers.event import async_call_later import homeassistant.util.dt as dt_util -from tests.common import async_fire_time_changed +from tests.common import ( + async_capture_events, + async_fire_time_changed, + async_mock_service, +) ENTITY_ID = "script.test" -_ALL_RUN_MODES = [None, "background", "blocking"] +_BASIC_SCRIPT_MODES = ("legacy", "parallel") -async def test_firing_event_basic(hass): +@pytest.fixture +def mock_timeout(hass, monkeypatch): + """Mock async_timeout.timeout.""" + + class MockTimeout: + def __init__(self, timeout): + self._timeout = timeout + self._loop = asyncio.get_event_loop() + self._task = None + self._cancelled = False + self._unsub = None + + async def __aenter__(self): + if self._timeout is None: + return self + self._task = asyncio.Task.current_task() + if self._timeout <= 0: + self._loop.call_soon(self._cancel_task) + return self + # Wait for a time_changed event instead of real time passing. + self._unsub = async_call_later(hass, self._timeout, self._cancel_task) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if exc_type is asyncio.CancelledError and self._cancelled: + self._unsub = None + self._task = None + raise asyncio.TimeoutError + if self._timeout is not None and self._unsub: + self._unsub() + self._unsub = None + self._task = None + return None + + @callback + def _cancel_task(self, now=None): + if self._task is not None: + self._task.cancel() + self._cancelled = True + + monkeypatch.setattr(script, "timeout", MockTimeout) + + +def async_watch_for_action(script_obj, message): + """Watch for message in last_action.""" + flag = asyncio.Event() + + @callback + def check_action(): + if script_obj.last_action and message in script_obj.last_action: + flag.set() + + script_obj.change_listener = check_action + return flag + + +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_firing_event_basic(hass, script_mode): """Test the firing of events.""" event = "test_event" context = Context() + events = async_capture_events(hass, event) - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) + sequence = cv.SCRIPT_SCHEMA({"event": event, "event_data": {"hello": "world"}}) + script_obj = script.Script(hass, sequence, script_mode=script_mode) - hass.bus.async_listen(event, record_event) + assert script_obj.is_legacy == (script_mode == "legacy") + assert script_obj.can_cancel == (script_mode != "legacy") - schema = cv.SCRIPT_SCHEMA({"event": event, "event_data": {"hello": "world"}}) + await script_obj.async_run(context=context) + await hass.async_block_till_done() - # For this one test we'll make sure "legacy" works the same as None. - for run_mode in _ALL_RUN_MODES + ["legacy"]: - events = [] - - if run_mode is None: - script_obj = script.Script(hass, schema) - else: - script_obj = script.Script(hass, schema, run_mode=run_mode) - - assert not script_obj.can_cancel - - await script_obj.async_run(context=context) - - await hass.async_block_till_done() - - assert len(events) == 1 - assert events[0].context is context - assert events[0].data.get("hello") == "world" - assert not script_obj.can_cancel + assert len(events) == 1 + assert events[0].context is context + assert events[0].data.get("hello") == "world" + assert script_obj.can_cancel == (script_mode != "legacy") -async def test_firing_event_template(hass): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_firing_event_template(hass, script_mode): """Test the firing of events.""" event = "test_event" context = Context() + events = async_capture_events(hass, event) - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) - - schema = cv.SCRIPT_SCHEMA( + sequence = cv.SCRIPT_SCHEMA( { "event": event, "event_data_template": { @@ -84,152 +128,47 @@ async def test_firing_event_template(hass): }, } ) + script_obj = script.Script(hass, sequence, script_mode=script_mode) - for run_mode in _ALL_RUN_MODES: - events = [] + assert script_obj.can_cancel == (script_mode != "legacy") - if run_mode is None: - script_obj = script.Script(hass, schema) - else: - script_obj = script.Script(hass, schema, run_mode=run_mode) + await script_obj.async_run({"is_world": "yes"}, context=context) + await hass.async_block_till_done() - assert not script_obj.can_cancel - - await script_obj.async_run({"is_world": "yes"}, context=context) - - await hass.async_block_till_done() - - assert len(events) == 1 - assert events[0].context is context - assert events[0].data == { - "dict": {1: "yes", 2: "yesyes", 3: "yesyesyes"}, - "list": ["yes", "yesyes"], - } + assert len(events) == 1 + assert events[0].context is context + assert events[0].data == { + "dict": {1: "yes", 2: "yesyes", 3: "yesyesyes"}, + "list": ["yes", "yesyes"], + } -async def test_calling_service_basic(hass): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_calling_service_basic(hass, script_mode): """Test the calling of a service.""" context = Context() + calls = async_mock_service(hass, "test", "script") - @callback - def record_call(service): - """Add recorded event to set.""" - calls.append(service) + sequence = cv.SCRIPT_SCHEMA({"service": "test.script", "data": {"hello": "world"}}) + script_obj = script.Script(hass, sequence, script_mode=script_mode) - hass.services.async_register("test", "script", record_call) + assert script_obj.can_cancel == (script_mode != "legacy") - schema = cv.SCRIPT_SCHEMA({"service": "test.script", "data": {"hello": "world"}}) + await script_obj.async_run(context=context) + await hass.async_block_till_done() - for run_mode in _ALL_RUN_MODES: - calls = [] - - if run_mode is None: - script_obj = script.Script(hass, schema) - else: - script_obj = script.Script(hass, schema, run_mode=run_mode) - - assert not script_obj.can_cancel - - await script_obj.async_run(context=context) - - await hass.async_block_till_done() - - assert len(calls) == 1 - assert calls[0].context is context - assert calls[0].data.get("hello") == "world" + assert len(calls) == 1 + assert calls[0].context is context + assert calls[0].data.get("hello") == "world" -async def test_cancel_no_wait(hass, caplog): - """Test stopping script.""" - event = "test_event" - - async def async_simulate_long_service(service): - """Simulate a service that takes a not insignificant time.""" - await asyncio.sleep(0.01) - - hass.services.async_register("test", "script", async_simulate_long_service) - - @callback - def monitor_event(event): - """Signal event happened.""" - event_sem.release() - - hass.bus.async_listen(event, monitor_event) - - schema = cv.SCRIPT_SCHEMA([{"event": event}, {"service": "test.script"}]) - - for run_mode in _ALL_RUN_MODES: - event_sem = asyncio.Semaphore(0) - - if run_mode is None: - script_obj = script.Script(hass, schema) - else: - script_obj = script.Script(hass, schema, run_mode=run_mode) - - tasks = [] - for _ in range(3): - if run_mode == "background": - await script_obj.async_run() - else: - hass.async_create_task(script_obj.async_run()) - tasks.append(hass.async_create_task(event_sem.acquire())) - await asyncio.wait_for(asyncio.gather(*tasks), 1) - - # Can't assert just yet because we haven't verified stopping works yet. - # If assert fails we can hang test if async_stop doesn't work. - script_was_runing = script_obj.is_running - - await script_obj.async_stop() - await hass.async_block_till_done() - - assert script_was_runing - assert not script_obj.is_running - - -async def test_activating_scene(hass): - """Test the activation of a scene.""" - context = Context() - - @callback - def record_call(service): - """Add recorded event to set.""" - calls.append(service) - - hass.services.async_register(scene.DOMAIN, SERVICE_TURN_ON, record_call) - - schema = cv.SCRIPT_SCHEMA({"scene": "scene.hello"}) - - for run_mode in _ALL_RUN_MODES: - calls = [] - - if run_mode is None: - script_obj = script.Script(hass, schema) - else: - script_obj = script.Script(hass, schema, run_mode=run_mode) - - assert not script_obj.can_cancel - - await script_obj.async_run(context=context) - - await hass.async_block_till_done() - - assert len(calls) == 1 - assert calls[0].context is context - assert calls[0].data.get(ATTR_ENTITY_ID) == "scene.hello" - - -async def test_calling_service_template(hass): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_calling_service_template(hass, script_mode): """Test the calling of a service.""" context = Context() + calls = async_mock_service(hass, "test", "script") - @callback - def record_call(service): - """Add recorded event to set.""" - calls.append(service) - - hass.services.async_register("test", "script", record_call) - - schema = cv.SCRIPT_SCHEMA( + sequence = cv.SCRIPT_SCHEMA( { "service_template": """ {% if True %} @@ -248,32 +187,30 @@ async def test_calling_service_template(hass): }, } ) + script_obj = script.Script(hass, sequence, script_mode=script_mode) - for run_mode in _ALL_RUN_MODES: - calls = [] + assert script_obj.can_cancel == (script_mode != "legacy") - if run_mode is None: - script_obj = script.Script(hass, schema) - else: - script_obj = script.Script(hass, schema, run_mode=run_mode) + await script_obj.async_run({"is_world": "yes"}, context=context) + await hass.async_block_till_done() - assert not script_obj.can_cancel - - await script_obj.async_run({"is_world": "yes"}, context=context) - - await hass.async_block_till_done() - - assert len(calls) == 1 - assert calls[0].context is context - assert calls[0].data.get("hello") == "world" + assert len(calls) == 1 + assert calls[0].context is context + assert calls[0].data.get("hello") == "world" -async def test_multiple_runs_no_wait(hass): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_multiple_runs_no_wait(hass, script_mode): """Test multiple runs with no wait in script.""" logger = logging.getLogger("TEST") + calls = [] + heard_event = asyncio.Event() async def async_simulate_long_service(service): """Simulate a service that takes a not insignificant time.""" + fire = service.data.get("fire") + listen = service.data.get("listen") + service_done = asyncio.Event() @callback def service_done_cb(event): @@ -281,29 +218,20 @@ async def test_multiple_runs_no_wait(hass): service_done.set() calls.append(service) - - fire = service.data.get("fire") - listen = service.data.get("listen") logger.debug("simulated service (%s:%s) started", fire, listen) - - service_done = asyncio.Event() unsub = hass.bus.async_listen(listen, service_done_cb) - hass.bus.async_fire(fire) - await service_done.wait() unsub() hass.services.async_register("test", "script", async_simulate_long_service) - heard_event = asyncio.Event() - @callback def heard_event_cb(event): logger.debug("heard: %s", event) heard_event.set() - schema = cv.SCRIPT_SCHEMA( + sequence = cv.SCRIPT_SCHEMA( [ { "service": "test.script", @@ -315,209 +243,199 @@ async def test_multiple_runs_no_wait(hass): }, ] ) + script_obj = script.Script(hass, sequence, script_mode=script_mode) - for run_mode in _ALL_RUN_MODES: - calls = [] - heard_event.clear() + # Start script twice in such a way that second run will be started while first run + # is in the middle of the first service call. - if run_mode is None: - script_obj = script.Script(hass, schema) - else: - script_obj = script.Script(hass, schema, run_mode=run_mode) - - # Start script twice in such a way that second run will be started while first - # run is in the middle of the first service call. - - unsub = hass.bus.async_listen("1", heard_event_cb) - - logger.debug("starting 1st script") - coro = script_obj.async_run( + unsub = hass.bus.async_listen("1", heard_event_cb) + logger.debug("starting 1st script") + hass.async_create_task( + script_obj.async_run( {"fire1": "1", "listen1": "2", "fire2": "3", "listen2": "4"} ) - if run_mode == "background": - await coro - else: - hass.async_create_task(coro) - await asyncio.wait_for(heard_event.wait(), 1) + ) + await asyncio.wait_for(heard_event.wait(), 1) + unsub() - unsub() + logger.debug("starting 2nd script") + await script_obj.async_run( + {"fire1": "2", "listen1": "3", "fire2": "4", "listen2": "4"} + ) + await hass.async_block_till_done() - logger.debug("starting 2nd script") - await script_obj.async_run( - {"fire1": "2", "listen1": "3", "fire2": "4", "listen2": "4"} - ) - - await hass.async_block_till_done() - - assert len(calls) == 4 + assert len(calls) == 4 -async def test_delay_basic(hass): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_activating_scene(hass, script_mode): + """Test the activation of a scene.""" + context = Context() + calls = async_mock_service(hass, scene.DOMAIN, SERVICE_TURN_ON) + + sequence = cv.SCRIPT_SCHEMA({"scene": "scene.hello"}) + script_obj = script.Script(hass, sequence, script_mode=script_mode) + + assert script_obj.can_cancel == (script_mode != "legacy") + + await script_obj.async_run(context=context) + await hass.async_block_till_done() + + assert len(calls) == 1 + assert calls[0].context is context + assert calls[0].data.get(ATTR_ENTITY_ID) == "scene.hello" + + +@pytest.mark.parametrize("count", [1, 3]) +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_stop_no_wait(hass, caplog, script_mode, count): + """Test stopping script.""" + service_started_sem = asyncio.Semaphore(0) + finish_service_event = asyncio.Event() + event = "test_event" + events = async_capture_events(hass, event) + + async def async_simulate_long_service(service): + """Simulate a service that takes a not insignificant time.""" + service_started_sem.release() + await finish_service_event.wait() + + hass.services.async_register("test", "script", async_simulate_long_service) + + sequence = cv.SCRIPT_SCHEMA([{"service": "test.script"}, {"event": event}]) + script_obj = script.Script(hass, sequence, script_mode=script_mode) + + # Get script started specified number of times and wait until the test.script + # service has started for each run. + tasks = [] + for _ in range(count): + hass.async_create_task(script_obj.async_run()) + tasks.append(hass.async_create_task(service_started_sem.acquire())) + await asyncio.wait_for(asyncio.gather(*tasks), 1) + + # Can't assert just yet because we haven't verified stopping works yet. + # If assert fails we can hang test if async_stop doesn't work. + script_was_runing = script_obj.is_running + were_no_events = len(events) == 0 + + # Begin the process of stopping the script (which should stop all runs), and then + # let the service calls complete. + hass.async_create_task(script_obj.async_stop()) + finish_service_event.set() + + await hass.async_block_till_done() + + assert script_was_runing + assert were_no_events + assert not script_obj.is_running + assert len(events) == (count if script_mode == "legacy" else 0) + + +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_delay_basic(hass, mock_timeout, script_mode): """Test the delay.""" delay_alias = "delay step" - delay_started_flag = asyncio.Event() + sequence = cv.SCRIPT_SCHEMA({"delay": {"seconds": 5}, "alias": delay_alias}) + script_obj = script.Script(hass, sequence, script_mode=script_mode) + delay_started_flag = async_watch_for_action(script_obj, delay_alias) - @callback - def delay_started_cb(): - delay_started_flag.set() + assert script_obj.can_cancel - delay = timedelta(milliseconds=10) - schema = cv.SCRIPT_SCHEMA({"delay": delay, "alias": delay_alias}) + try: + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(delay_started_flag.wait(), 1) - for run_mode in _ALL_RUN_MODES: - delay_started_flag.clear() + assert script_obj.is_running + assert script_obj.last_action == delay_alias + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=5)) + await hass.async_block_till_done() - if run_mode is None: - script_obj = script.Script(hass, schema, change_listener=delay_started_cb) - else: - script_obj = script.Script( - hass, schema, change_listener=delay_started_cb, run_mode=run_mode - ) - - assert script_obj.can_cancel - - try: - if run_mode == "background": - await script_obj.async_run() - else: - hass.async_create_task(script_obj.async_run()) - await asyncio.wait_for(delay_started_flag.wait(), 1) - - assert script_obj.is_running - assert script_obj.last_action == delay_alias - except (AssertionError, asyncio.TimeoutError): - await script_obj.async_stop() - raise - else: - if run_mode in (None, "legacy"): - future = dt_util.utcnow() + delay - async_fire_time_changed(hass, future) - await hass.async_block_till_done() - - assert not script_obj.is_running - assert script_obj.last_action is None + assert not script_obj.is_running + assert script_obj.last_action is None -async def test_multiple_runs_delay(hass): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_multiple_runs_delay(hass, mock_timeout, script_mode): """Test multiple runs with delay in script.""" event = "test_event" - delay_started_flag = asyncio.Event() - - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) - - @callback - def delay_started_cb(): - delay_started_flag.set() - - delay = timedelta(milliseconds=10) - schema = cv.SCRIPT_SCHEMA( + events = async_capture_events(hass, event) + delay = timedelta(seconds=5) + sequence = cv.SCRIPT_SCHEMA( [ {"event": event, "event_data": {"value": 1}}, {"delay": delay}, {"event": event, "event_data": {"value": 2}}, ] ) + script_obj = script.Script(hass, sequence, script_mode=script_mode) + delay_started_flag = async_watch_for_action(script_obj, "delay") - for run_mode in _ALL_RUN_MODES: - events = [] - delay_started_flag.clear() + try: + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(delay_started_flag.wait(), 1) - if run_mode is None: - script_obj = script.Script(hass, schema, change_listener=delay_started_cb) - else: - script_obj = script.Script( - hass, schema, change_listener=delay_started_cb, run_mode=run_mode - ) - - try: - if run_mode == "background": - await script_obj.async_run() - else: - hass.async_create_task(script_obj.async_run()) - await asyncio.wait_for(delay_started_flag.wait(), 1) - - assert script_obj.is_running - assert len(events) == 1 - assert events[-1].data["value"] == 1 - except (AssertionError, asyncio.TimeoutError): - await script_obj.async_stop() - raise - else: - # Start second run of script while first run is in a delay. + assert script_obj.is_running + assert len(events) == 1 + assert events[-1].data["value"] == 1 + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + # Start second run of script while first run is in a delay. + if script_mode == "legacy": await script_obj.async_run() - if run_mode in (None, "legacy"): - future = dt_util.utcnow() + delay - async_fire_time_changed(hass, future) - await hass.async_block_till_done() - - assert not script_obj.is_running - if run_mode in (None, "legacy"): - assert len(events) == 2 - else: - assert len(events) == 4 - assert events[-3].data["value"] == 1 - assert events[-2].data["value"] == 2 - assert events[-1].data["value"] == 2 - - -async def test_delay_template_ok(hass): - """Test the delay as a template.""" - delay_started_flag = asyncio.Event() - - @callback - def delay_started_cb(): - delay_started_flag.set() - - schema = cv.SCRIPT_SCHEMA({"delay": "00:00:{{ 1 }}"}) - - for run_mode in _ALL_RUN_MODES: - delay_started_flag.clear() - - if run_mode is None: - script_obj = script.Script(hass, schema, change_listener=delay_started_cb) else: - script_obj = script.Script( - hass, schema, change_listener=delay_started_cb, run_mode=run_mode - ) - - assert script_obj.can_cancel - - try: - if run_mode == "background": - await script_obj.async_run() - else: - hass.async_create_task(script_obj.async_run()) + script_obj.sequence[1]["alias"] = "delay run 2" + delay_started_flag = async_watch_for_action(script_obj, "delay run 2") + hass.async_create_task(script_obj.async_run()) await asyncio.wait_for(delay_started_flag.wait(), 1) - assert script_obj.is_running - except (AssertionError, asyncio.TimeoutError): - await script_obj.async_stop() - raise + async_fire_time_changed(hass, dt_util.utcnow() + delay) + await hass.async_block_till_done() + + assert not script_obj.is_running + if script_mode == "legacy": + assert len(events) == 2 else: - if run_mode in (None, "legacy"): - future = dt_util.utcnow() + timedelta(seconds=1) - async_fire_time_changed(hass, future) - await hass.async_block_till_done() - - assert not script_obj.is_running + assert len(events) == 4 + assert events[-3].data["value"] == 1 + assert events[-2].data["value"] == 2 + assert events[-1].data["value"] == 2 -async def test_delay_template_invalid(hass, caplog): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_delay_template_ok(hass, mock_timeout, script_mode): + """Test the delay as a template.""" + sequence = cv.SCRIPT_SCHEMA({"delay": "00:00:{{ 5 }}"}) + script_obj = script.Script(hass, sequence, script_mode=script_mode) + delay_started_flag = async_watch_for_action(script_obj, "delay") + + assert script_obj.can_cancel + + try: + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(delay_started_flag.wait(), 1) + + assert script_obj.is_running + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=5)) + await hass.async_block_till_done() + + assert not script_obj.is_running + + +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_delay_template_invalid(hass, caplog, script_mode): """Test the delay as a template that fails.""" event = "test_event" - - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) - - schema = cv.SCRIPT_SCHEMA( + events = async_capture_events(hass, event) + sequence = cv.SCRIPT_SCHEMA( [ {"event": event}, {"delay": "{{ invalid_delay }}"}, @@ -525,83 +443,50 @@ async def test_delay_template_invalid(hass, caplog): {"event": event}, ] ) + script_obj = script.Script(hass, sequence, script_mode=script_mode) + start_idx = len(caplog.records) - for run_mode in _ALL_RUN_MODES: - events = [] + await script_obj.async_run() + await hass.async_block_till_done() - if run_mode is None: - script_obj = script.Script(hass, schema) - else: - script_obj = script.Script(hass, schema, run_mode=run_mode) - start_idx = len(caplog.records) + assert any( + rec.levelname == "ERROR" and "Error rendering" in rec.message + for rec in caplog.records[start_idx:] + ) - await script_obj.async_run() + assert not script_obj.is_running + assert len(events) == 1 + + +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_delay_template_complex_ok(hass, mock_timeout, script_mode): + """Test the delay with a working complex template.""" + sequence = cv.SCRIPT_SCHEMA({"delay": {"seconds": "{{ 5 }}"}}) + script_obj = script.Script(hass, sequence, script_mode=script_mode) + delay_started_flag = async_watch_for_action(script_obj, "delay") + + assert script_obj.can_cancel + + try: + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(delay_started_flag.wait(), 1) + assert script_obj.is_running + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=5)) await hass.async_block_till_done() - assert any( - rec.levelname == "ERROR" and "Error rendering" in rec.message - for rec in caplog.records[start_idx:] - ) - assert not script_obj.is_running - assert len(events) == 1 -async def test_delay_template_complex_ok(hass): - """Test the delay with a working complex template.""" - delay_started_flag = asyncio.Event() - - @callback - def delay_started_cb(): - delay_started_flag.set() - - milliseconds = 10 - schema = cv.SCRIPT_SCHEMA({"delay": {"milliseconds": "{{ milliseconds }}"}}) - - for run_mode in _ALL_RUN_MODES: - delay_started_flag.clear() - - if run_mode is None: - script_obj = script.Script(hass, schema, change_listener=delay_started_cb) - else: - script_obj = script.Script( - hass, schema, change_listener=delay_started_cb, run_mode=run_mode - ) - - assert script_obj.can_cancel - - try: - coro = script_obj.async_run({"milliseconds": milliseconds}) - if run_mode == "background": - await coro - else: - hass.async_create_task(coro) - await asyncio.wait_for(delay_started_flag.wait(), 1) - assert script_obj.is_running - except (AssertionError, asyncio.TimeoutError): - await script_obj.async_stop() - raise - else: - if run_mode in (None, "legacy"): - future = dt_util.utcnow() + timedelta(milliseconds=milliseconds) - async_fire_time_changed(hass, future) - await hass.async_block_till_done() - - assert not script_obj.is_running - - -async def test_delay_template_complex_invalid(hass, caplog): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_delay_template_complex_invalid(hass, caplog, script_mode): """Test the delay with a complex template that fails.""" event = "test_event" - - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) - - schema = cv.SCRIPT_SCHEMA( + events = async_capture_events(hass, event) + sequence = cv.SCRIPT_SCHEMA( [ {"event": event}, {"delay": {"seconds": "{{ invalid_delay }}"}}, @@ -609,543 +494,260 @@ async def test_delay_template_complex_invalid(hass, caplog): {"event": event}, ] ) + script_obj = script.Script(hass, sequence, script_mode=script_mode) + start_idx = len(caplog.records) - for run_mode in _ALL_RUN_MODES: - events = [] + await script_obj.async_run() + await hass.async_block_till_done() - if run_mode is None: - script_obj = script.Script(hass, schema) - else: - script_obj = script.Script(hass, schema, run_mode=run_mode) - start_idx = len(caplog.records) + assert any( + rec.levelname == "ERROR" and "Error rendering" in rec.message + for rec in caplog.records[start_idx:] + ) - await script_obj.async_run() - await hass.async_block_till_done() + assert not script_obj.is_running + assert len(events) == 1 - assert any( - rec.levelname == "ERROR" and "Error rendering" in rec.message - for rec in caplog.records[start_idx:] - ) + +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_cancel_delay(hass, script_mode): + """Test the cancelling while the delay is present.""" + event = "test_event" + events = async_capture_events(hass, event) + sequence = cv.SCRIPT_SCHEMA([{"delay": {"seconds": 5}}, {"event": event}]) + script_obj = script.Script(hass, sequence, script_mode=script_mode) + delay_started_flag = async_watch_for_action(script_obj, "delay") + + try: + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(delay_started_flag.wait(), 1) + + assert script_obj.is_running + assert len(events) == 0 + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + await script_obj.async_stop() assert not script_obj.is_running - assert len(events) == 1 + + # Make sure the script is really stopped. + + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=5)) + await hass.async_block_till_done() + + assert not script_obj.is_running + assert len(events) == 0 -async def test_cancel_delay(hass): - """Test the cancelling while the delay is present.""" - delay_started_flag = asyncio.Event() - event = "test_event" - - @callback - def delay_started_cb(): - delay_started_flag.set() - - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) - - delay = timedelta(milliseconds=10) - schema = cv.SCRIPT_SCHEMA([{"delay": delay}, {"event": event}]) - - for run_mode in _ALL_RUN_MODES: - delay_started_flag.clear() - events = [] - - if run_mode is None: - script_obj = script.Script(hass, schema, change_listener=delay_started_cb) - else: - script_obj = script.Script( - hass, schema, change_listener=delay_started_cb, run_mode=run_mode - ) - - try: - if run_mode == "background": - await script_obj.async_run() - else: - hass.async_create_task(script_obj.async_run()) - await asyncio.wait_for(delay_started_flag.wait(), 1) - - assert script_obj.is_running - assert len(events) == 0 - except (AssertionError, asyncio.TimeoutError): - await script_obj.async_stop() - raise - else: - await script_obj.async_stop() - - assert not script_obj.is_running - - # Make sure the script is really stopped. - - if run_mode in (None, "legacy"): - future = dt_util.utcnow() + delay - async_fire_time_changed(hass, future) - await hass.async_block_till_done() - - assert not script_obj.is_running - assert len(events) == 0 - - -async def test_wait_template_basic(hass): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_wait_template_basic(hass, script_mode): """Test the wait template.""" wait_alias = "wait step" - wait_started_flag = asyncio.Event() - - @callback - def wait_started_cb(): - wait_started_flag.set() - - schema = cv.SCRIPT_SCHEMA( + sequence = cv.SCRIPT_SCHEMA( { "wait_template": "{{ states.switch.test.state == 'off' }}", "alias": wait_alias, } ) + script_obj = script.Script(hass, sequence, script_mode=script_mode) + wait_started_flag = async_watch_for_action(script_obj, wait_alias) - for run_mode in _ALL_RUN_MODES: - wait_started_flag.clear() + assert script_obj.can_cancel + + try: hass.states.async_set("switch.test", "on") + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(wait_started_flag.wait(), 1) - if run_mode is None: - script_obj = script.Script(hass, schema, change_listener=wait_started_cb) - else: - script_obj = script.Script( - hass, schema, change_listener=wait_started_cb, run_mode=run_mode - ) + assert script_obj.is_running + assert script_obj.last_action == wait_alias + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + hass.states.async_set("switch.test", "off") + await hass.async_block_till_done() - assert script_obj.can_cancel - - try: - if run_mode == "background": - await script_obj.async_run() - else: - hass.async_create_task(script_obj.async_run()) - await asyncio.wait_for(wait_started_flag.wait(), 1) - - assert script_obj.is_running - assert script_obj.last_action == wait_alias - except (AssertionError, asyncio.TimeoutError): - await script_obj.async_stop() - raise - else: - hass.states.async_set("switch.test", "off") - await hass.async_block_till_done() - - assert not script_obj.is_running - assert script_obj.last_action is None + assert not script_obj.is_running + assert script_obj.last_action is None -async def test_multiple_runs_wait_template(hass): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_multiple_runs_wait_template(hass, script_mode): """Test multiple runs with wait_template in script.""" event = "test_event" - wait_started_flag = asyncio.Event() - - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) - - @callback - def wait_started_cb(): - wait_started_flag.set() - - schema = cv.SCRIPT_SCHEMA( + events = async_capture_events(hass, event) + sequence = cv.SCRIPT_SCHEMA( [ {"event": event, "event_data": {"value": 1}}, {"wait_template": "{{ states.switch.test.state == 'off' }}"}, {"event": event, "event_data": {"value": 2}}, ] ) + script_obj = script.Script(hass, sequence, script_mode=script_mode) + wait_started_flag = async_watch_for_action(script_obj, "wait") - for run_mode in _ALL_RUN_MODES: - events = [] - wait_started_flag.clear() + try: hass.states.async_set("switch.test", "on") + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(wait_started_flag.wait(), 1) - if run_mode is None: - script_obj = script.Script(hass, schema, change_listener=wait_started_cb) + assert script_obj.is_running + assert len(events) == 1 + assert events[-1].data["value"] == 1 + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + # Start second run of script while first run is in wait_template. + if script_mode == "legacy": + await script_obj.async_run() else: - script_obj = script.Script( - hass, schema, change_listener=wait_started_cb, run_mode=run_mode - ) + hass.async_create_task(script_obj.async_run()) + hass.states.async_set("switch.test", "off") + await hass.async_block_till_done() - try: - if run_mode == "background": - await script_obj.async_run() - else: - hass.async_create_task(script_obj.async_run()) - await asyncio.wait_for(wait_started_flag.wait(), 1) - - assert script_obj.is_running - assert len(events) == 1 - assert events[-1].data["value"] == 1 - except (AssertionError, asyncio.TimeoutError): - await script_obj.async_stop() - raise + assert not script_obj.is_running + if script_mode == "legacy": + assert len(events) == 2 else: - # Start second run of script while first run is in wait_template. - if run_mode == "blocking": - hass.async_create_task(script_obj.async_run()) - else: - await script_obj.async_run() - hass.states.async_set("switch.test", "off") - await hass.async_block_till_done() - - assert not script_obj.is_running - if run_mode in (None, "legacy"): - assert len(events) == 2 - else: - assert len(events) == 4 - assert events[-3].data["value"] == 1 - assert events[-2].data["value"] == 2 - assert events[-1].data["value"] == 2 + assert len(events) == 4 + assert events[-3].data["value"] == 1 + assert events[-2].data["value"] == 2 + assert events[-1].data["value"] == 2 -async def test_cancel_wait_template(hass): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_cancel_wait_template(hass, script_mode): """Test the cancelling while wait_template is present.""" - wait_started_flag = asyncio.Event() event = "test_event" - - @callback - def wait_started_cb(): - wait_started_flag.set() - - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) - - schema = cv.SCRIPT_SCHEMA( + events = async_capture_events(hass, event) + sequence = cv.SCRIPT_SCHEMA( [ {"wait_template": "{{ states.switch.test.state == 'off' }}"}, {"event": event}, ] ) + script_obj = script.Script(hass, sequence, script_mode=script_mode) + wait_started_flag = async_watch_for_action(script_obj, "wait") - for run_mode in _ALL_RUN_MODES: - wait_started_flag.clear() - events = [] + try: hass.states.async_set("switch.test", "on") + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(wait_started_flag.wait(), 1) - if run_mode is None: - script_obj = script.Script(hass, schema, change_listener=wait_started_cb) - else: - script_obj = script.Script( - hass, schema, change_listener=wait_started_cb, run_mode=run_mode - ) + assert script_obj.is_running + assert len(events) == 0 + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + await script_obj.async_stop() - try: - if run_mode == "background": - await script_obj.async_run() - else: - hass.async_create_task(script_obj.async_run()) - await asyncio.wait_for(wait_started_flag.wait(), 1) + assert not script_obj.is_running - assert script_obj.is_running - assert len(events) == 0 - except (AssertionError, asyncio.TimeoutError): - await script_obj.async_stop() - raise - else: - await script_obj.async_stop() + # Make sure the script is really stopped. - assert not script_obj.is_running + hass.states.async_set("switch.test", "off") + await hass.async_block_till_done() - # Make sure the script is really stopped. - - hass.states.async_set("switch.test", "off") - await hass.async_block_till_done() - - assert not script_obj.is_running - assert len(events) == 0 + assert not script_obj.is_running + assert len(events) == 0 -async def test_wait_template_not_schedule(hass): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_wait_template_not_schedule(hass, script_mode): """Test the wait template with correct condition.""" event = "test_event" - - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) - - hass.states.async_set("switch.test", "on") - - schema = cv.SCRIPT_SCHEMA( + events = async_capture_events(hass, event) + sequence = cv.SCRIPT_SCHEMA( [ {"event": event}, {"wait_template": "{{ states.switch.test.state == 'on' }}"}, {"event": event}, ] ) + script_obj = script.Script(hass, sequence, script_mode=script_mode) - for run_mode in _ALL_RUN_MODES: - events = [] + hass.states.async_set("switch.test", "on") + await script_obj.async_run() + await hass.async_block_till_done() - if run_mode is None: - script_obj = script.Script(hass, schema) - else: - script_obj = script.Script(hass, schema, run_mode=run_mode) + assert not script_obj.is_running + assert len(events) == 2 - await script_obj.async_run() + +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +@pytest.mark.parametrize( + "continue_on_timeout,n_events", [(False, 0), (True, 1), (None, 1)] +) +async def test_wait_template_timeout( + hass, mock_timeout, continue_on_timeout, n_events, script_mode +): + """Test the wait template, halt on timeout.""" + event = "test_event" + events = async_capture_events(hass, event) + sequence = [ + {"wait_template": "{{ states.switch.test.state == 'off' }}", "timeout": 5}, + {"event": event}, + ] + if continue_on_timeout is not None: + sequence[0]["continue_on_timeout"] = continue_on_timeout + sequence = cv.SCRIPT_SCHEMA(sequence) + script_obj = script.Script(hass, sequence, script_mode=script_mode) + wait_started_flag = async_watch_for_action(script_obj, "wait") + + try: + hass.states.async_set("switch.test", "on") + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(wait_started_flag.wait(), 1) + + assert script_obj.is_running + assert len(events) == 0 + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=5)) await hass.async_block_till_done() assert not script_obj.is_running - assert len(events) == 2 + assert len(events) == n_events -async def test_wait_template_timeout_halt(hass): - """Test the wait template, halt on timeout.""" - event = "test_event" - wait_started_flag = asyncio.Event() - - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) - - @callback - def wait_started_cb(): - wait_started_flag.set() - - hass.states.async_set("switch.test", "on") - - timeout = timedelta(milliseconds=10) - schema = cv.SCRIPT_SCHEMA( - [ - { - "wait_template": "{{ states.switch.test.state == 'off' }}", - "continue_on_timeout": False, - "timeout": timeout, - }, - {"event": event}, - ] - ) - - for run_mode in _ALL_RUN_MODES: - events = [] - wait_started_flag.clear() - - if run_mode is None: - script_obj = script.Script(hass, schema, change_listener=wait_started_cb) - else: - script_obj = script.Script( - hass, schema, change_listener=wait_started_cb, run_mode=run_mode - ) - - try: - if run_mode == "background": - await script_obj.async_run() - else: - hass.async_create_task(script_obj.async_run()) - await asyncio.wait_for(wait_started_flag.wait(), 1) - - assert script_obj.is_running - assert len(events) == 0 - except (AssertionError, asyncio.TimeoutError): - await script_obj.async_stop() - raise - else: - if run_mode in (None, "legacy"): - future = dt_util.utcnow() + timeout - async_fire_time_changed(hass, future) - await hass.async_block_till_done() - - assert not script_obj.is_running - assert len(events) == 0 - - -async def test_wait_template_timeout_continue(hass): - """Test the wait template with continuing the script.""" - event = "test_event" - wait_started_flag = asyncio.Event() - - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) - - @callback - def wait_started_cb(): - wait_started_flag.set() - - hass.states.async_set("switch.test", "on") - - timeout = timedelta(milliseconds=10) - schema = cv.SCRIPT_SCHEMA( - [ - { - "wait_template": "{{ states.switch.test.state == 'off' }}", - "continue_on_timeout": True, - "timeout": timeout, - }, - {"event": event}, - ] - ) - - for run_mode in _ALL_RUN_MODES: - events = [] - wait_started_flag.clear() - - if run_mode is None: - script_obj = script.Script(hass, schema, change_listener=wait_started_cb) - else: - script_obj = script.Script( - hass, schema, change_listener=wait_started_cb, run_mode=run_mode - ) - - try: - if run_mode == "background": - await script_obj.async_run() - else: - hass.async_create_task(script_obj.async_run()) - await asyncio.wait_for(wait_started_flag.wait(), 1) - - assert script_obj.is_running - assert len(events) == 0 - except (AssertionError, asyncio.TimeoutError): - await script_obj.async_stop() - raise - else: - if run_mode in (None, "legacy"): - future = dt_util.utcnow() + timeout - async_fire_time_changed(hass, future) - await hass.async_block_till_done() - - assert not script_obj.is_running - assert len(events) == 1 - - -async def test_wait_template_timeout_default(hass): - """Test the wait template with default continue.""" - event = "test_event" - wait_started_flag = asyncio.Event() - - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) - - @callback - def wait_started_cb(): - wait_started_flag.set() - - hass.states.async_set("switch.test", "on") - - timeout = timedelta(milliseconds=10) - schema = cv.SCRIPT_SCHEMA( - [ - { - "wait_template": "{{ states.switch.test.state == 'off' }}", - "timeout": timeout, - }, - {"event": event}, - ] - ) - - for run_mode in _ALL_RUN_MODES: - events = [] - wait_started_flag.clear() - - if run_mode is None: - script_obj = script.Script(hass, schema, change_listener=wait_started_cb) - else: - script_obj = script.Script( - hass, schema, change_listener=wait_started_cb, run_mode=run_mode - ) - - try: - if run_mode == "background": - await script_obj.async_run() - else: - hass.async_create_task(script_obj.async_run()) - await asyncio.wait_for(wait_started_flag.wait(), 1) - - assert script_obj.is_running - assert len(events) == 0 - except (AssertionError, asyncio.TimeoutError): - await script_obj.async_stop() - raise - else: - if run_mode in (None, "legacy"): - future = dt_util.utcnow() + timeout - async_fire_time_changed(hass, future) - await hass.async_block_till_done() - - assert not script_obj.is_running - assert len(events) == 1 - - -async def test_wait_template_variables(hass): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_wait_template_variables(hass, script_mode): """Test the wait template with variables.""" - wait_started_flag = asyncio.Event() + sequence = cv.SCRIPT_SCHEMA({"wait_template": "{{ is_state(data, 'off') }}"}) + script_obj = script.Script(hass, sequence, script_mode=script_mode) + wait_started_flag = async_watch_for_action(script_obj, "wait") - @callback - def wait_started_cb(): - wait_started_flag.set() + assert script_obj.can_cancel - schema = cv.SCRIPT_SCHEMA({"wait_template": "{{ is_state(data, 'off') }}"}) - - for run_mode in _ALL_RUN_MODES: - wait_started_flag.clear() + try: hass.states.async_set("switch.test", "on") + hass.async_create_task(script_obj.async_run({"data": "switch.test"})) + await asyncio.wait_for(wait_started_flag.wait(), 1) - if run_mode is None: - script_obj = script.Script(hass, schema, change_listener=wait_started_cb) - else: - script_obj = script.Script( - hass, schema, change_listener=wait_started_cb, run_mode=run_mode - ) + assert script_obj.is_running + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + hass.states.async_set("switch.test", "off") + await hass.async_block_till_done() - assert script_obj.can_cancel - - try: - coro = script_obj.async_run({"data": "switch.test"}) - if run_mode == "background": - await coro - else: - hass.async_create_task(coro) - await asyncio.wait_for(wait_started_flag.wait(), 1) - - assert script_obj.is_running - except (AssertionError, asyncio.TimeoutError): - await script_obj.async_stop() - raise - else: - hass.states.async_set("switch.test", "off") - await hass.async_block_till_done() - - assert not script_obj.is_running + assert not script_obj.is_running -async def test_condition_basic(hass): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_condition_basic(hass, script_mode): """Test if we can use conditions in a script.""" event = "test_event" - events = [] - - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) - - schema = cv.SCRIPT_SCHEMA( + events = async_capture_events(hass, event) + sequence = cv.SCRIPT_SCHEMA( [ {"event": event}, { @@ -1155,208 +757,127 @@ async def test_condition_basic(hass): {"event": event}, ] ) + script_obj = script.Script(hass, sequence, script_mode=script_mode) - for run_mode in _ALL_RUN_MODES: - events = [] - hass.states.async_set("test.entity", "hello") - - if run_mode is None: - script_obj = script.Script(hass, schema) - else: - script_obj = script.Script(hass, schema, run_mode=run_mode) - - assert not script_obj.can_cancel - - await script_obj.async_run() - await hass.async_block_till_done() - - assert len(events) == 2 - - hass.states.async_set("test.entity", "goodbye") - - await script_obj.async_run() - await hass.async_block_till_done() - - assert len(events) == 3 - - -@asynctest.patch("homeassistant.helpers.script.condition.async_from_config") -async def test_condition_created_once(async_from_config, hass): - """Test that the conditions do not get created multiple times.""" - event = "test_event" - events = [] - - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) + assert script_obj.can_cancel == (script_mode != "legacy") hass.states.async_set("test.entity", "hello") + await script_obj.async_run() + await hass.async_block_till_done() - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event}, - { - "condition": "template", - "value_template": '{{ states.test.entity.state == "hello" }}', - }, - {"event": event}, - ] - ), + assert len(events) == 2 + + hass.states.async_set("test.entity", "goodbye") + + await script_obj.async_run() + await hass.async_block_till_done() + + assert len(events) == 3 + + +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +@asynctest.patch("homeassistant.helpers.script.condition.async_from_config") +async def test_condition_created_once(async_from_config, hass, script_mode): + """Test that the conditions do not get created multiple times.""" + sequence = cv.SCRIPT_SCHEMA( + { + "condition": "template", + "value_template": '{{ states.test.entity.state == "hello" }}', + } ) + script_obj = script.Script(hass, sequence, script_mode=script_mode) + async_from_config.reset_mock() + + hass.states.async_set("test.entity", "hello") await script_obj.async_run() await script_obj.async_run() await hass.async_block_till_done() - assert async_from_config.call_count == 1 + + async_from_config.assert_called_once() assert len(script_obj._config_cache) == 1 -async def test_condition_all_cached(hass): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_condition_all_cached(hass, script_mode): """Test that multiple conditions get cached.""" - event = "test_event" - events = [] - - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) + sequence = cv.SCRIPT_SCHEMA( + [ + { + "condition": "template", + "value_template": '{{ states.test.entity.state == "hello" }}', + }, + { + "condition": "template", + "value_template": '{{ states.test.entity.state != "hello" }}', + }, + ] + ) + script_obj = script.Script(hass, sequence, script_mode=script_mode) hass.states.async_set("test.entity", "hello") - - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event}, - { - "condition": "template", - "value_template": '{{ states.test.entity.state == "hello" }}', - }, - { - "condition": "template", - "value_template": '{{ states.test.entity.state != "hello" }}', - }, - {"event": event}, - ] - ), - ) - await script_obj.async_run() await hass.async_block_till_done() + assert len(script_obj._config_cache) == 2 -async def test_last_triggered(hass): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_last_triggered(hass, script_mode): """Test the last_triggered.""" event = "test_event" + sequence = cv.SCRIPT_SCHEMA({"event": event}) + script_obj = script.Script(hass, sequence, script_mode=script_mode) - schema = cv.SCRIPT_SCHEMA({"event": event}) + assert script_obj.last_triggered is None - for run_mode in _ALL_RUN_MODES: - if run_mode is None: - script_obj = script.Script(hass, schema) - else: - script_obj = script.Script(hass, schema, run_mode=run_mode) + time = dt_util.utcnow() + with mock.patch("homeassistant.helpers.script.utcnow", return_value=time): + await script_obj.async_run() + await hass.async_block_till_done() - assert script_obj.last_triggered is None - - time = dt_util.utcnow() - with mock.patch("homeassistant.helpers.script.utcnow", return_value=time): - await script_obj.async_run() - await hass.async_block_till_done() - - assert script_obj.last_triggered == time + assert script_obj.last_triggered == time -async def test_propagate_error_service_not_found(hass): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_propagate_error_service_not_found(hass, script_mode): """Test that a script aborts when a service is not found.""" event = "test_event" + events = async_capture_events(hass, event) + sequence = cv.SCRIPT_SCHEMA([{"service": "test.script"}, {"event": event}]) + script_obj = script.Script(hass, sequence, script_mode=script_mode) - @callback - def record_event(event): - events.append(event) + with pytest.raises(exceptions.ServiceNotFound): + await script_obj.async_run() - hass.bus.async_listen(event, record_event) - - schema = cv.SCRIPT_SCHEMA([{"service": "test.script"}, {"event": event}]) - - run_modes = _ALL_RUN_MODES - if "background" in run_modes: - run_modes.remove("background") - for run_mode in run_modes: - events = [] - - if run_mode is None: - script_obj = script.Script(hass, schema) - else: - script_obj = script.Script(hass, schema, run_mode=run_mode) - - with pytest.raises(exceptions.ServiceNotFound): - await script_obj.async_run() - - assert len(events) == 0 - assert not script_obj.is_running + assert len(events) == 0 + assert not script_obj.is_running -async def test_propagate_error_invalid_service_data(hass): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_propagate_error_invalid_service_data(hass, script_mode): """Test that a script aborts when we send invalid service data.""" event = "test_event" - - @callback - def record_event(event): - events.append(event) - - hass.bus.async_listen(event, record_event) - - @callback - def record_call(service): - """Add recorded event to set.""" - calls.append(service) - - hass.services.async_register( - "test", "script", record_call, schema=vol.Schema({"text": str}) - ) - - schema = cv.SCRIPT_SCHEMA( + events = async_capture_events(hass, event) + calls = async_mock_service(hass, "test", "script", vol.Schema({"text": str})) + sequence = cv.SCRIPT_SCHEMA( [{"service": "test.script", "data": {"text": 1}}, {"event": event}] ) + script_obj = script.Script(hass, sequence, script_mode=script_mode) - run_modes = _ALL_RUN_MODES - if "background" in run_modes: - run_modes.remove("background") - for run_mode in run_modes: - events = [] - calls = [] + with pytest.raises(vol.Invalid): + await script_obj.async_run() - if run_mode is None: - script_obj = script.Script(hass, schema) - else: - script_obj = script.Script(hass, schema, run_mode=run_mode) - - with pytest.raises(vol.Invalid): - await script_obj.async_run() - - assert len(events) == 0 - assert len(calls) == 0 - assert not script_obj.is_running + assert len(events) == 0 + assert len(calls) == 0 + assert not script_obj.is_running -async def test_propagate_error_service_exception(hass): +@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) +async def test_propagate_error_service_exception(hass, script_mode): """Test that a script aborts when a service throws an exception.""" event = "test_event" - - @callback - def record_event(event): - events.append(event) - - hass.bus.async_listen(event, record_event) + events = async_capture_events(hass, event) @callback def record_call(service): @@ -1365,24 +886,14 @@ async def test_propagate_error_service_exception(hass): hass.services.async_register("test", "script", record_call) - schema = cv.SCRIPT_SCHEMA([{"service": "test.script"}, {"event": event}]) + sequence = cv.SCRIPT_SCHEMA([{"service": "test.script"}, {"event": event}]) + script_obj = script.Script(hass, sequence, script_mode=script_mode) - run_modes = _ALL_RUN_MODES - if "background" in run_modes: - run_modes.remove("background") - for run_mode in run_modes: - events = [] + with pytest.raises(ValueError): + await script_obj.async_run() - if run_mode is None: - script_obj = script.Script(hass, schema) - else: - script_obj = script.Script(hass, schema, run_mode=run_mode) - - with pytest.raises(ValueError): - await script_obj.async_run() - - assert len(events) == 0 - assert not script_obj.is_running + assert len(events) == 0 + assert not script_obj.is_running async def test_referenced_entities(): @@ -1441,68 +952,37 @@ async def test_referenced_devices(): assert script_obj.referenced_devices is script_obj.referenced_devices -async def test_if_running_with_legacy_run_mode(hass, caplog): - """Test using if_running with run_mode='legacy'.""" - # TODO: REMOVE - if _ALL_RUN_MODES == [None]: - return - - with pytest.raises(exceptions.HomeAssistantError): - script.Script( - hass, - [], - if_running="ignore", - run_mode="legacy", - logger=logging.getLogger("TEST"), - ) - assert any( - rec.levelname == "ERROR" - and rec.name == "TEST" - and all(text in rec.message for text in ("if_running", "legacy")) - for rec in caplog.records - ) +@contextmanager +def does_not_raise(): + """Indicate no exception is expected.""" + yield -async def test_if_running_ignore(hass, caplog): - """Test overlapping runs with if_running='ignore'.""" - # TODO: REMOVE - if _ALL_RUN_MODES == [None]: - return - +@pytest.mark.parametrize( + "script_mode,expectation,messages", + [ + ("ignore", does_not_raise(), ["Skipping"]), + ("error", pytest.raises(exceptions.HomeAssistantError), []), + ], +) +async def test_script_mode_1(hass, caplog, script_mode, expectation, messages): + """Test overlapping runs with script_mode='ignore'.""" event = "test_event" - events = [] - wait_started_flag = asyncio.Event() - - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) - - @callback - def wait_started_cb(): - wait_started_flag.set() - - hass.states.async_set("switch.test", "on") - - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event, "event_data": {"value": 1}}, - {"wait_template": "{{ states.switch.test.state == 'off' }}"}, - {"event": event, "event_data": {"value": 2}}, - ] - ), - change_listener=wait_started_cb, - if_running="ignore", - run_mode="background", - logger=logging.getLogger("TEST"), + events = async_capture_events(hass, event) + sequence = cv.SCRIPT_SCHEMA( + [ + {"event": event, "event_data": {"value": 1}}, + {"wait_template": "{{ states.switch.test.state == 'off' }}"}, + {"event": event, "event_data": {"value": 2}}, + ] ) + logger = logging.getLogger("TEST") + script_obj = script.Script(hass, sequence, script_mode=script_mode, logger=logger) + wait_started_flag = async_watch_for_action(script_obj, "wait") try: - await script_obj.async_run() + hass.states.async_set("switch.test", "on") + hass.async_create_task(script_obj.async_run()) await asyncio.wait_for(wait_started_flag.wait(), 1) assert script_obj.is_running @@ -1510,85 +990,19 @@ async def test_if_running_ignore(hass, caplog): assert events[0].data["value"] == 1 # Start second run of script while first run is suspended in wait_template. - # This should ignore second run. - await script_obj.async_run() - - assert script_obj.is_running - assert any( - rec.levelname == "INFO" and rec.name == "TEST" and "Skipping" in rec.message - for rec in caplog.records - ) - except (AssertionError, asyncio.TimeoutError): - await script_obj.async_stop() - raise - else: - hass.states.async_set("switch.test", "off") - await hass.async_block_till_done() - - assert not script_obj.is_running - assert len(events) == 2 - assert events[1].data["value"] == 2 - - -async def test_if_running_error(hass, caplog): - """Test overlapping runs with if_running='error'.""" - # TODO: REMOVE - if _ALL_RUN_MODES == [None]: - return - - event = "test_event" - events = [] - wait_started_flag = asyncio.Event() - - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) - - @callback - def wait_started_cb(): - wait_started_flag.set() - - hass.states.async_set("switch.test", "on") - - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event, "event_data": {"value": 1}}, - {"wait_template": "{{ states.switch.test.state == 'off' }}"}, - {"event": event, "event_data": {"value": 2}}, - ] - ), - change_listener=wait_started_cb, - if_running="error", - run_mode="background", - logger=logging.getLogger("TEST"), - ) - - try: - await script_obj.async_run() - await asyncio.wait_for(wait_started_flag.wait(), 1) - - assert script_obj.is_running - assert len(events) == 1 - assert events[0].data["value"] == 1 - - # Start second run of script while first run is suspended in wait_template. - # This should cause an error. - - with pytest.raises(exceptions.HomeAssistantError): + with expectation: await script_obj.async_run() assert script_obj.is_running - assert any( - rec.levelname == "ERROR" - and rec.name == "TEST" - and "Already running" in rec.message - for rec in caplog.records + assert all( + any( + rec.levelname == "INFO" + and rec.name == "TEST" + and message in rec.message + for rec in caplog.records + ) + for message in messages ) except (AssertionError, asyncio.TimeoutError): await script_obj.async_stop() @@ -1602,46 +1016,28 @@ async def test_if_running_error(hass, caplog): assert events[1].data["value"] == 2 -async def test_if_running_restart(hass, caplog): - """Test overlapping runs with if_running='restart'.""" - # TODO: REMOVE - if _ALL_RUN_MODES == [None]: - return - +@pytest.mark.parametrize( + "script_mode,messages,last_events", + [("restart", ["Restarting"], [2]), ("parallel", [], [2, 2])], +) +async def test_script_mode_2(hass, caplog, script_mode, messages, last_events): + """Test overlapping runs with script_mode='restart'.""" event = "test_event" - events = [] - wait_started_flag = asyncio.Event() - - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) - - @callback - def wait_started_cb(): - wait_started_flag.set() - - hass.states.async_set("switch.test", "on") - - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event, "event_data": {"value": 1}}, - {"wait_template": "{{ states.switch.test.state == 'off' }}"}, - {"event": event, "event_data": {"value": 2}}, - ] - ), - change_listener=wait_started_cb, - if_running="restart", - run_mode="background", - logger=logging.getLogger("TEST"), + events = async_capture_events(hass, event) + sequence = cv.SCRIPT_SCHEMA( + [ + {"event": event, "event_data": {"value": 1}}, + {"wait_template": "{{ states.switch.test.state == 'off' }}"}, + {"event": event, "event_data": {"value": 2}}, + ] ) + logger = logging.getLogger("TEST") + script_obj = script.Script(hass, sequence, script_mode=script_mode, logger=logger) + wait_started_flag = async_watch_for_action(script_obj, "wait") try: - await script_obj.async_run() + hass.states.async_set("switch.test", "on") + hass.async_create_task(script_obj.async_run()) await asyncio.wait_for(wait_started_flag.wait(), 1) assert script_obj.is_running @@ -1652,17 +1048,20 @@ async def test_if_running_restart(hass, caplog): # This should stop first run then start a new run. wait_started_flag.clear() - await script_obj.async_run() + hass.async_create_task(script_obj.async_run()) await asyncio.wait_for(wait_started_flag.wait(), 1) assert script_obj.is_running assert len(events) == 2 assert events[1].data["value"] == 1 - assert any( - rec.levelname == "INFO" - and rec.name == "TEST" - and "Restarting" in rec.message - for rec in caplog.records + assert all( + any( + rec.levelname == "INFO" + and rec.name == "TEST" + and message in rec.message + for rec in caplog.records + ) + for message in messages ) except (AssertionError, asyncio.TimeoutError): await script_obj.async_stop() @@ -1672,50 +1071,30 @@ async def test_if_running_restart(hass, caplog): await hass.async_block_till_done() assert not script_obj.is_running - assert len(events) == 3 - assert events[2].data["value"] == 2 + assert len(events) == 2 + len(last_events) + for idx, value in enumerate(last_events, start=2): + assert events[idx].data["value"] == value -async def test_if_running_parallel(hass): - """Test overlapping runs with if_running='parallel'.""" - # TODO: REMOVE - if _ALL_RUN_MODES == [None]: - return - +async def test_script_mode_queue(hass): + """Test overlapping runs with script_mode='queue'.""" event = "test_event" - events = [] - wait_started_flag = asyncio.Event() - - @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) - - hass.bus.async_listen(event, record_event) - - @callback - def wait_started_cb(): - wait_started_flag.set() - - hass.states.async_set("switch.test", "on") - - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event, "event_data": {"value": 1}}, - {"wait_template": "{{ states.switch.test.state == 'off' }}"}, - {"event": event, "event_data": {"value": 2}}, - ] - ), - change_listener=wait_started_cb, - if_running="parallel", - run_mode="background", - logger=logging.getLogger("TEST"), + events = async_capture_events(hass, event) + sequence = cv.SCRIPT_SCHEMA( + [ + {"event": event, "event_data": {"value": 1}}, + {"wait_template": "{{ states.switch.test.state == 'off' }}"}, + {"event": event, "event_data": {"value": 2}}, + {"wait_template": "{{ states.switch.test.state == 'on' }}"}, + ] ) + logger = logging.getLogger("TEST") + script_obj = script.Script(hass, sequence, script_mode="queue", logger=logger) + wait_started_flag = async_watch_for_action(script_obj, "wait") try: - await script_obj.async_run() + hass.states.async_set("switch.test", "on") + hass.async_create_task(script_obj.async_run()) await asyncio.wait_for(wait_started_flag.wait(), 1) assert script_obj.is_running @@ -1723,23 +1102,39 @@ async def test_if_running_parallel(hass): assert events[0].data["value"] == 1 # Start second run of script while first run is suspended in wait_template. - # This should start a new, independent run. + # This second run should not start until the first run has finished. + + hass.async_create_task(script_obj.async_run()) + + await asyncio.sleep(0) + assert script_obj.is_running + assert len(events) == 1 wait_started_flag.clear() - await script_obj.async_run() + hass.states.async_set("switch.test", "off") await asyncio.wait_for(wait_started_flag.wait(), 1) assert script_obj.is_running assert len(events) == 2 - assert events[1].data["value"] == 1 + assert events[1].data["value"] == 2 + + wait_started_flag.clear() + hass.states.async_set("switch.test", "on") + await asyncio.wait_for(wait_started_flag.wait(), 1) + + await asyncio.sleep(0) + assert script_obj.is_running + assert len(events) == 3 + assert events[2].data["value"] == 1 except (AssertionError, asyncio.TimeoutError): await script_obj.async_stop() raise else: hass.states.async_set("switch.test", "off") + await asyncio.sleep(0) + hass.states.async_set("switch.test", "on") await hass.async_block_till_done() assert not script_obj.is_running assert len(events) == 4 - assert events[2].data["value"] == 2 assert events[3].data["value"] == 2