diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index b1d6c24d303..6f06eeb0094 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -393,10 +393,8 @@ class AutomationEntity(ToggleEntity, RestoreEntity): try: await self.action_script.async_run(variables, trigger_context) - except Exception as err: # pylint: disable=broad-except - self.action_script.async_log_exception( - _LOGGER, f"Error while executing automation {self.entity_id}", err - ) + except Exception: # pylint: disable=broad-except + pass self._last_triggered = utcnow() await self.async_update_ha_state() @@ -504,7 +502,9 @@ async def _async_process_config(hass, config, component): hidden = config_block[CONF_HIDE_ENTITY] initial_state = config_block.get(CONF_INITIAL_STATE) - action_script = script.Script(hass, config_block.get(CONF_ACTION, {}), name) + action_script = script.Script( + hass, config_block.get(CONF_ACTION, {}), name, logger=_LOGGER + ) if CONF_CONDITION in config_block: cond_func = await _async_process_if(hass, config, config_block) diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index 0a7b8596248..9384c58db81 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -242,7 +242,9 @@ class ScriptEntity(ToggleEntity): self.object_id = object_id self.icon = icon self.entity_id = ENTITY_ID_FORMAT.format(object_id) - self.script = Script(hass, sequence, name, self.async_update_ha_state) + self.script = Script( + hass, sequence, name, self.async_update_ha_state, logger=_LOGGER + ) @property def should_poll(self): @@ -279,22 +281,15 @@ class ScriptEntity(ToggleEntity): {ATTR_NAME: self.script.name, ATTR_ENTITY_ID: self.entity_id}, context=context, ) - try: - await self.script.async_run(kwargs.get(ATTR_VARIABLES), context) - except Exception as err: - self.script.async_log_exception( - _LOGGER, f"Error executing script {self.entity_id}", err - ) - raise err + await self.script.async_run(kwargs.get(ATTR_VARIABLES), context) async def async_turn_off(self, **kwargs): """Turn script off.""" - self.script.async_stop() + await self.script.async_stop() async def async_will_remove_from_hass(self): """Stop script and remove service when it will be removed from Home Assistant.""" - if self.script.is_running: - self.script.async_stop() + await self.script.async_stop() # remove service self.hass.services.async_remove(DOMAIN, self.object_id) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 1cac4679d82..1ce9d2b87bb 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -1,10 +1,11 @@ """Helpers to execute scripts.""" +from abc import ABC, abstractmethod import asyncio from contextlib import suppress from datetime import datetime from itertools import islice import logging -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast import voluptuous as vol @@ -31,13 +32,10 @@ from homeassistant.helpers.event import ( async_track_template, ) from homeassistant.helpers.typing import ConfigType -from homeassistant.util.async_ import run_callback_threadsafe -import homeassistant.util.dt as date_util +from homeassistant.util.dt import utcnow # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs -_LOGGER = logging.getLogger(__name__) - CONF_ALIAS = "alias" CONF_SERVICE = "service" CONF_SERVICE_DATA = "data" @@ -50,7 +48,6 @@ CONF_WAIT_TEMPLATE = "wait_template" CONF_CONTINUE = "continue_on_timeout" CONF_SCENE = "scene" - ACTION_DELAY = "delay" ACTION_WAIT_TEMPLATE = "wait_template" ACTION_CHECK_CONDITION = "condition" @@ -59,6 +56,31 @@ ACTION_CALL_SERVICE = "call_service" ACTION_DEVICE_AUTOMATION = "device" ACTION_ACTIVATE_SCENE = "scene" +IF_RUNNING_ERROR = "error" +IF_RUNNING_IGNORE = "ignore" +IF_RUNNING_PARALLEL = "parallel" +IF_RUNNING_RESTART = "restart" +# First choice is default +IF_RUNNING_CHOICES = [ + IF_RUNNING_PARALLEL, + IF_RUNNING_ERROR, + IF_RUNNING_IGNORE, + IF_RUNNING_RESTART, +] + +RUN_MODE_BACKGROUND = "background" +RUN_MODE_BLOCKING = "blocking" +RUN_MODE_LEGACY = "legacy" +# First choice is default +RUN_MODE_CHOICES = [ + RUN_MODE_BLOCKING, + RUN_MODE_BACKGROUND, + RUN_MODE_LEGACY, +] + +_LOG_EXCEPTION = logging.ERROR + 1 +_TIMEOUT_MSG = "Timeout reached, abort script." + def _determine_action(action): """Determine action type.""" @@ -83,16 +105,6 @@ def _determine_action(action): return ACTION_CALL_SERVICE -def call_from_config( - hass: HomeAssistant, - config: ConfigType, - variables: Optional[Sequence] = None, - context: Optional[Context] = None, -) -> None: - """Call a script based on a config entry.""" - Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables, context) - - async def async_validate_action_config( hass: HomeAssistant, config: ConfigType ) -> ConfigType: @@ -121,6 +133,446 @@ class _SuspendScript(Exception): """Throw if script needs to suspend.""" +class _ScriptRunBase(ABC): + """Common data & methods for managing Script sequence run.""" + + def __init__( + self, + hass: HomeAssistant, + script: "Script", + variables: Optional[Sequence], + context: Optional[Context], + log_exceptions: bool, + ) -> None: + self._hass = hass + self._script = script + self._variables = variables + self._context = context + self._log_exceptions = log_exceptions + self._step = -1 + self._action: Optional[Dict[str, Any]] = None + + def _changed(self): + self._script._changed() # pylint: disable=protected-access + + @property + def _config_cache(self): + return self._script._config_cache # pylint: disable=protected-access + + @abstractmethod + async def async_run(self) -> None: + """Run script.""" + + async def _async_step(self, log_exceptions): + try: + await getattr(self, f"_async_{_determine_action(self._action)}_step")() + except Exception as err: + if not isinstance(err, (_SuspendScript, _StopScript)) and ( + self._log_exceptions or log_exceptions + ): + self._log_exception(err) + raise + + @abstractmethod + async def async_stop(self) -> None: + """Stop script run.""" + + def _log_exception(self, exception): + action_type = _determine_action(self._action) + + error = str(exception) + level = logging.ERROR + + if isinstance(exception, vol.Invalid): + error_desc = "Invalid data" + + elif isinstance(exception, exceptions.TemplateError): + error_desc = "Error rendering template" + + elif isinstance(exception, exceptions.Unauthorized): + error_desc = "Unauthorized" + + elif isinstance(exception, exceptions.ServiceNotFound): + error_desc = "Service not found" + + else: + error_desc = "Unexpected error" + level = _LOG_EXCEPTION + + self._log( + "Error executing script. %s for %s at pos %s: %s", + error_desc, + action_type, + self._step + 1, + error, + level=level, + ) + + @abstractmethod + async def _async_delay_step(self): + """Handle delay.""" + + def _prep_delay_step(self): + try: + delay = vol.All(cv.time_period, cv.positive_timedelta)( + template.render_complex(self._action[CONF_DELAY], self._variables) + ) + except (exceptions.TemplateError, vol.Invalid) as ex: + self._raise( + "Error rendering %s delay template: %s", + self._script.name, + ex, + exception=_StopScript, + ) + + self._script.last_action = self._action.get(CONF_ALIAS, f"delay {delay}") + self._log("Executing step %s", self._script.last_action) + + return delay + + @abstractmethod + async def _async_wait_template_step(self): + """Handle a wait template.""" + + def _prep_wait_template_step(self, async_script_wait): + wait_template = self._action[CONF_WAIT_TEMPLATE] + wait_template.hass = self._hass + + self._script.last_action = self._action.get(CONF_ALIAS, "wait template") + self._log("Executing step %s", self._script.last_action) + + # check if condition already okay + if condition.async_template(self._hass, wait_template, self._variables): + return None + + return async_track_template( + self._hass, wait_template, async_script_wait, self._variables + ) + + async def _async_call_service_step(self): + """Call the service specified in the action.""" + self._script.last_action = self._action.get(CONF_ALIAS, "call service") + self._log("Executing step %s", self._script.last_action) + await service.async_call_from_config( + self._hass, + self._action, + blocking=True, + variables=self._variables, + validate_config=False, + context=self._context, + ) + + async def _async_device_step(self): + """Perform the device automation specified in the action.""" + self._script.last_action = self._action.get(CONF_ALIAS, "device automation") + self._log("Executing step %s", self._script.last_action) + platform = await device_automation.async_get_device_automation_platform( + self._hass, self._action[CONF_DOMAIN], "action" + ) + await platform.async_call_action_from_config( + self._hass, self._action, self._variables, self._context + ) + + async def _async_scene_step(self): + """Activate the scene specified in the action.""" + self._script.last_action = self._action.get(CONF_ALIAS, "activate scene") + self._log("Executing step %s", self._script.last_action) + await self._hass.services.async_call( + scene.DOMAIN, + SERVICE_TURN_ON, + {ATTR_ENTITY_ID: self._action[CONF_SCENE]}, + blocking=True, + context=self._context, + ) + + async def _async_event_step(self): + """Fire an event.""" + self._script.last_action = self._action.get( + CONF_ALIAS, self._action[CONF_EVENT] + ) + self._log("Executing step %s", self._script.last_action) + event_data = dict(self._action.get(CONF_EVENT_DATA, {})) + if CONF_EVENT_DATA_TEMPLATE in self._action: + try: + event_data.update( + template.render_complex( + self._action[CONF_EVENT_DATA_TEMPLATE], self._variables + ) + ) + except exceptions.TemplateError as ex: + self._log( + "Error rendering event data template: %s", ex, level=logging.ERROR + ) + + self._hass.bus.async_fire( + self._action[CONF_EVENT], event_data, context=self._context + ) + + async def _async_condition_step(self): + """Test if condition is matching.""" + config_cache_key = frozenset((k, str(v)) for k, v in self._action.items()) + config = self._config_cache.get(config_cache_key) + if not config: + config = await condition.async_from_config(self._hass, self._action, False) + self._config_cache[config_cache_key] = config + + self._script.last_action = self._action.get( + CONF_ALIAS, self._action[CONF_CONDITION] + ) + check = config(self._hass, self._variables) + self._log("Test condition %s: %s", self._script.last_action, check) + if not check: + raise _StopScript + + def _log(self, msg, *args, level=logging.INFO): + self._script._log(msg, *args, level=level) # pylint: disable=protected-access + + def _raise(self, msg, *args, exception=None): + # pylint: disable=protected-access + self._script._raise(msg, *args, exception=exception) + + +class _ScriptRun(_ScriptRunBase): + """Manage Script sequence run.""" + + def __init__( + self, + hass: HomeAssistant, + script: "Script", + variables: Optional[Sequence], + context: Optional[Context], + log_exceptions: bool, + ) -> None: + super().__init__(hass, script, variables, context, log_exceptions) + self._stop = asyncio.Event() + self._stopped = asyncio.Event() + + async def _async_run(self, propagate_exceptions=True): + self._log("Running script") + try: + for self._step, self._action in enumerate(self._script.sequence): + if self._stop.is_set(): + break + await self._async_step(not propagate_exceptions) + except _StopScript: + pass + except Exception: # pylint: disable=broad-except + if propagate_exceptions: + raise + finally: + if not self._stop.is_set(): + self._changed() + self._script.last_action = None + self._script._runs.remove(self) # pylint: disable=protected-access + self._stopped.set() + + async def async_stop(self) -> None: + """Stop script run.""" + self._stop.set() + await self._stopped.wait() + + async def _async_delay_step(self): + """Handle delay.""" + timeout = self._prep_delay_step().total_seconds() + if not self._stop.is_set(): + self._changed() + await asyncio.wait({self._stop.wait()}, timeout=timeout) + + async def _async_wait_template_step(self): + """Handle a wait template.""" + + @callback + def async_script_wait(entity_id, from_s, to_s): + """Handle script after template condition is true.""" + done.set() + + unsub = self._prep_wait_template_step(async_script_wait) + if not unsub: + return + + if not self._stop.is_set(): + self._changed() + try: + timeout = self._action[CONF_TIMEOUT].total_seconds() + except KeyError: + timeout = None + done = asyncio.Event() + try: + await asyncio.wait_for( + asyncio.wait( + {self._stop.wait(), done.wait()}, + return_when=asyncio.FIRST_COMPLETED, + ), + timeout, + ) + except asyncio.TimeoutError: + if not self._action.get(CONF_CONTINUE, True): + self._log(_TIMEOUT_MSG) + raise _StopScript + finally: + unsub() + + +class _BackgroundScriptRun(_ScriptRun): + """Manage background Script sequence run.""" + + async def async_run(self) -> None: + """Run script.""" + self._hass.async_create_task(self._async_run(False)) + + +class _BlockingScriptRun(_ScriptRun): + """Manage blocking Script sequence run.""" + + async def async_run(self) -> None: + """Run script.""" + try: + await asyncio.shield(self._async_run()) + except asyncio.CancelledError: + await self.async_stop() + raise + + +class _LegacyScriptRun(_ScriptRunBase): + """Manage legacy Script sequence run.""" + + def __init__( + self, + hass: HomeAssistant, + script: "Script", + variables: Optional[Sequence], + context: Optional[Context], + log_exceptions: bool, + shared: Optional["_LegacyScriptRun"], + ) -> None: + super().__init__(hass, script, variables, context, log_exceptions) + if shared: + self._shared = shared + else: + # To implement legacy behavior we need to share the following "run state" + # amongst all runs, so it will only exist in the first instantiation of + # concurrent runs, and the rest will use it, too. + self._current = -1 + self._async_listeners: List[CALLBACK_TYPE] = [] + self._shared = self + + @property + def _cur(self): + return self._shared._current # pylint: disable=protected-access + + @_cur.setter + def _cur(self, value): + self._shared._current = value # pylint: disable=protected-access + + @property + def _async_listener(self): + return self._shared._async_listeners # pylint: disable=protected-access + + async def async_run(self) -> None: + """Run script.""" + await self._async_run() + + async def _async_run(self, propagate_exceptions=True): + if self._cur == -1: + self._log("Running script") + self._cur = 0 + + # Unregister callback if we were in a delay or wait but turn on is + # called again. In that case we just continue execution. + self._async_remove_listener() + + suspended = False + try: + for self._step, self._action in islice( + enumerate(self._script.sequence), self._cur, None + ): + await self._async_step(not propagate_exceptions) + except _StopScript: + pass + except _SuspendScript: + # Store next step to take and notify change listeners + self._cur = self._step + 1 + suspended = True + return + except Exception: # pylint: disable=broad-except + if propagate_exceptions: + raise + finally: + if self._cur != -1: + self._changed() + if not suspended: + self._script.last_action = None + await self.async_stop() + + async def async_stop(self) -> None: + """Stop script run.""" + if self._cur == -1: + return + + self._cur = -1 + self._async_remove_listener() + self._script._runs.clear() # pylint: disable=protected-access + + async def _async_delay_step(self): + """Handle delay.""" + delay = self._prep_delay_step() + + @callback + def async_script_delay(now): + """Handle delay.""" + with suppress(ValueError): + self._async_listener.remove(unsub) + self._hass.async_create_task(self._async_run(False)) + + unsub = async_track_point_in_utc_time( + self._hass, async_script_delay, utcnow() + delay + ) + self._async_listener.append(unsub) + raise _SuspendScript + + async def _async_wait_template_step(self): + """Handle a wait template.""" + + @callback + def async_script_wait(entity_id, from_s, to_s): + """Handle script after template condition is true.""" + self._async_remove_listener() + self._hass.async_create_task(self._async_run(False)) + + @callback + def async_script_timeout(now): + """Call after timeout is retrieve.""" + with suppress(ValueError): + self._async_listener.remove(unsub) + + # Check if we want to continue to execute + # the script after the timeout + if self._action.get(CONF_CONTINUE, True): + self._hass.async_create_task(self._async_run(False)) + else: + self._log(_TIMEOUT_MSG) + self._hass.async_create_task(self.async_stop()) + + unsub_wait = self._prep_wait_template_step(async_script_wait) + if not unsub_wait: + return + self._async_listener.append(unsub_wait) + + if CONF_TIMEOUT in self._action: + unsub = async_track_point_in_utc_time( + self._hass, async_script_timeout, utcnow() + self._action[CONF_TIMEOUT] + ) + self._async_listener.append(unsub) + + raise _SuspendScript + + def _async_remove_listener(self): + """Remove listeners, if any.""" + for unsub in self._async_listener: + unsub() + self._async_listener.clear() + + class Script: """Representation of a script.""" @@ -130,39 +582,46 @@ class Script: sequence: Sequence[Dict[str, Any]], name: Optional[str] = None, change_listener: Optional[Callable[..., Any]] = None, + if_running: Optional[str] = None, + run_mode: Optional[str] = None, + logger: Optional[logging.Logger] = None, + log_exceptions: bool = True, ) -> None: """Initialize the script.""" - self.hass = hass + self._logger = logger or logging.getLogger(__name__) + self._hass = hass self.sequence = sequence template.attach(hass, self.sequence) self.name = name self._change_listener = change_listener - self._cur = -1 - self._exception_step: Optional[int] = None self.last_action = None self.last_triggered: Optional[datetime] = None self.can_cancel = any( CONF_DELAY in action or CONF_WAIT_TEMPLATE in action for action in self.sequence ) - self._async_listener: List[CALLBACK_TYPE] = [] + if not if_running and not run_mode: + self._if_running = IF_RUNNING_PARALLEL + self._run_mode = RUN_MODE_LEGACY + elif if_running and run_mode == RUN_MODE_LEGACY: + self._raise('Cannot use if_running if run_mode is "legacy"') + else: + self._if_running = if_running or IF_RUNNING_CHOICES[0] + self._run_mode = run_mode or RUN_MODE_CHOICES[0] + self._runs: List[_ScriptRunBase] = [] + self._log_exceptions = log_exceptions self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {} - self._actions = { - ACTION_DELAY: self._async_delay, - ACTION_WAIT_TEMPLATE: self._async_wait_template, - ACTION_CHECK_CONDITION: self._async_check_condition, - ACTION_FIRE_EVENT: self._async_fire_event, - ACTION_CALL_SERVICE: self._async_call_service, - ACTION_DEVICE_AUTOMATION: self._async_device_automation, - ACTION_ACTIVATE_SCENE: self._async_activate_scene, - } self._referenced_entities: Optional[Set[str]] = None self._referenced_devices: Optional[Set[str]] = None + def _changed(self): + if self._change_listener: + self._hass.async_add_job(self._change_listener) + @property def is_running(self) -> bool: """Return true if script is on.""" - return self._cur != -1 + return len(self._runs) > 0 @property def referenced_devices(self): @@ -223,288 +682,62 @@ class Script: def run(self, variables=None, context=None): """Run script.""" asyncio.run_coroutine_threadsafe( - self.async_run(variables, context), self.hass.loop + self.async_run(variables, context), self._hass.loop ).result() async def async_run( self, variables: Optional[Sequence] = None, context: Optional[Context] = None ) -> None: - """Run script. - - This method is a coroutine. - """ - self.last_triggered = date_util.utcnow() - if self._cur == -1: - self._log("Running script") - self._cur = 0 - - # Unregister callback if we were in a delay or wait but turn on is - # called again. In that case we just continue execution. - self._async_remove_listener() - - for cur, action in islice(enumerate(self.sequence), self._cur, None): - try: - await self._handle_action(action, variables, context) - except _SuspendScript: - # Store next step to take and notify change listeners - self._cur = cur + 1 - if self._change_listener: - self.hass.async_add_job(self._change_listener) + """Run script.""" + if self.is_running: + if self._if_running == IF_RUNNING_IGNORE: + self._log("Skipping script") return - except _StopScript: - break - except Exception: - # Store the step that had an exception - self._exception_step = cur - # Set script to not running - self._cur = -1 - self.last_action = None - # Pass exception on. - raise - # Set script to not-running. - self._cur = -1 - self.last_action = None - if self._change_listener: - self.hass.async_add_job(self._change_listener) + if self._if_running == IF_RUNNING_ERROR: + self._raise("Already running") + if self._if_running == IF_RUNNING_RESTART: + self._log("Restarting script") + await self.async_stop() - def stop(self) -> None: - """Stop running script.""" - run_callback_threadsafe(self.hass.loop, self.async_stop).result() - - @callback - def async_stop(self) -> None: - """Stop running script.""" - if self._cur == -1: - return - - self._cur = -1 - self._async_remove_listener() - if self._change_listener: - self.hass.async_add_job(self._change_listener) - - @callback - def async_log_exception(self, logger, message_base, exception): - """Log an exception for this script. - - Should only be called on exceptions raised by this scripts async_run. - """ - step = self._exception_step - action = self.sequence[step] - action_type = _determine_action(action) - - error = None - meth = logger.error - - if isinstance(exception, vol.Invalid): - error_desc = "Invalid data" - - elif isinstance(exception, exceptions.TemplateError): - error_desc = "Error rendering template" - - elif isinstance(exception, exceptions.Unauthorized): - error_desc = "Unauthorized" - - elif isinstance(exception, exceptions.ServiceNotFound): - error_desc = "Service not found" - - else: - # Print the full stack trace, unknown error - error_desc = "Unknown error" - meth = logger.exception - error = "" - - if error is None: - error = str(exception) - - meth( - "%s. %s for %s at pos %s: %s", - message_base, - error_desc, - action_type, - step + 1, - error, - ) - - async def _handle_action(self, action, variables, context): - """Handle an action.""" - await self._actions[_determine_action(action)](action, variables, context) - - async def _async_delay(self, action, variables, context): - """Handle delay.""" - # Call ourselves in the future to continue work - unsub = None - - @callback - def async_script_delay(now): - """Handle delay.""" - with suppress(ValueError): - self._async_listener.remove(unsub) - - self.hass.async_create_task(self.async_run(variables, context)) - - delay = action[CONF_DELAY] - - try: - if isinstance(delay, template.Template): - delay = vol.All(cv.time_period, cv.positive_timedelta)( - delay.async_render(variables) - ) - elif isinstance(delay, dict): - delay_data = {} - delay_data.update(template.render_complex(delay, variables)) - delay = cv.time_period(delay_data) - except (exceptions.TemplateError, vol.Invalid) as ex: - _LOGGER.error("Error rendering '%s' delay template: %s", self.name, ex) - raise _StopScript - - self.last_action = action.get(CONF_ALIAS, f"delay {delay}") - self._log("Executing step %s" % self.last_action) - - unsub = async_track_point_in_utc_time( - self.hass, async_script_delay, date_util.utcnow() + delay - ) - self._async_listener.append(unsub) - raise _SuspendScript - - async def _async_wait_template(self, action, variables, context): - """Handle a wait template.""" - # Call ourselves in the future to continue work - wait_template = action[CONF_WAIT_TEMPLATE] - wait_template.hass = self.hass - - self.last_action = action.get(CONF_ALIAS, "wait template") - self._log("Executing step %s" % self.last_action) - - # check if condition already okay - if condition.async_template(self.hass, wait_template, variables): - return - - @callback - def async_script_wait(entity_id, from_s, to_s): - """Handle script after template condition is true.""" - self._async_remove_listener() - self.hass.async_create_task(self.async_run(variables, context)) - - self._async_listener.append( - async_track_template(self.hass, wait_template, async_script_wait, variables) - ) - - if CONF_TIMEOUT in action: - self._async_set_timeout( - action, variables, context, action.get(CONF_CONTINUE, True) - ) - - raise _SuspendScript - - async def _async_call_service(self, action, variables, context): - """Call the service specified in the action. - - This method is a coroutine. - """ - self.last_action = action.get(CONF_ALIAS, "call service") - self._log("Executing step %s" % self.last_action) - await service.async_call_from_config( - self.hass, - action, - blocking=True, - variables=variables, - validate_config=False, - context=context, - ) - - async def _async_device_automation(self, action, variables, context): - """Perform the device automation specified in the action. - - This method is a coroutine. - """ - self.last_action = action.get(CONF_ALIAS, "device automation") - self._log("Executing step %s" % self.last_action) - platform = await device_automation.async_get_device_automation_platform( - self.hass, action[CONF_DOMAIN], "action" - ) - await platform.async_call_action_from_config( - self.hass, action, variables, context - ) - - async def _async_activate_scene(self, action, variables, context): - """Activate the scene specified in the action. - - This method is a coroutine. - """ - self.last_action = action.get(CONF_ALIAS, "activate scene") - self._log("Executing step %s" % self.last_action) - await self.hass.services.async_call( - scene.DOMAIN, - SERVICE_TURN_ON, - {ATTR_ENTITY_ID: action[CONF_SCENE]}, - blocking=True, - context=context, - ) - - async def _async_fire_event(self, action, variables, context): - """Fire an event.""" - self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT]) - self._log("Executing step %s" % self.last_action) - event_data = dict(action.get(CONF_EVENT_DATA, {})) - if CONF_EVENT_DATA_TEMPLATE in action: - try: - event_data.update( - template.render_complex(action[CONF_EVENT_DATA_TEMPLATE], variables) - ) - except exceptions.TemplateError as ex: - _LOGGER.error("Error rendering event data template: %s", ex) - - self.hass.bus.async_fire(action[CONF_EVENT], event_data, context=context) - - async def _async_check_condition(self, action, variables, context): - """Test if condition is matching.""" - config_cache_key = frozenset((k, str(v)) for k, v in action.items()) - config = self._config_cache.get(config_cache_key) - if not config: - config = await condition.async_from_config(self.hass, action, False) - self._config_cache[config_cache_key] = config - - self.last_action = action.get(CONF_ALIAS, action[CONF_CONDITION]) - check = config(self.hass, variables) - self._log(f"Test condition {self.last_action}: {check}") - - if not check: - raise _StopScript - - def _async_set_timeout(self, action, variables, context, continue_on_timeout): - """Schedule a timeout to abort or continue script.""" - timeout = action[CONF_TIMEOUT] - unsub = None - - @callback - def async_script_timeout(now): - """Call after timeout is retrieve.""" - with suppress(ValueError): - self._async_listener.remove(unsub) - - # Check if we want to continue to execute - # the script after the timeout - if continue_on_timeout: - self.hass.async_create_task(self.async_run(variables, context)) + self.last_triggered = utcnow() + if self._run_mode == RUN_MODE_LEGACY: + if self._runs: + shared = cast(Optional[_LegacyScriptRun], self._runs[0]) else: - self._log("Timeout reached, abort script.") - self.async_stop() + shared = None + run: _ScriptRunBase = _LegacyScriptRun( + self._hass, self, variables, context, self._log_exceptions, shared + ) + else: + if self._run_mode == RUN_MODE_BACKGROUND: + run = _BackgroundScriptRun( + self._hass, self, variables, context, self._log_exceptions + ) + else: + run = _BlockingScriptRun( + self._hass, self, variables, context, self._log_exceptions + ) + self._runs.append(run) + await run.async_run() - unsub = async_track_point_in_utc_time( - self.hass, async_script_timeout, date_util.utcnow() + timeout - ) - self._async_listener.append(unsub) + async def async_stop(self) -> None: + """Stop running script.""" + if not self.is_running: + return + await asyncio.shield(asyncio.gather(*(run.async_stop() for run in self._runs))) + self._changed() - def _async_remove_listener(self): - """Remove point in time listener, if any.""" - for unsub in self._async_listener: - unsub() - self._async_listener.clear() + def _log(self, msg, *args, level=logging.INFO): + if self.name: + msg = f"{self.name}: {msg}" + if level == _LOG_EXCEPTION: + self._logger.exception(msg, *args) + else: + self._logger.log(level, msg, *args) - def _log(self, msg): - """Logger helper.""" - if self.name is not None: - msg = f"Script {self.name}: {msg}" - - _LOGGER.info(msg) + def _raise(self, msg, *args, exception=None): + if not exception: + exception = exceptions.HomeAssistantError + self._log(msg, *args, level=logging.ERROR) + raise exception(msg % args) diff --git a/tests/components/demo/test_notify.py b/tests/components/demo/test_notify.py index 30fb49be47d..e30d65112e8 100644 --- a/tests/components/demo/test_notify.py +++ b/tests/components/demo/test_notify.py @@ -8,7 +8,7 @@ import voluptuous as vol import homeassistant.components.demo.notify as demo import homeassistant.components.notify as notify from homeassistant.core import callback -from homeassistant.helpers import discovery, script +from homeassistant.helpers import discovery from homeassistant.setup import setup_component from tests.common import assert_setup_component, get_test_home_assistant @@ -121,7 +121,7 @@ class TestNotifyDemo(unittest.TestCase): def test_calling_notify_from_script_loaded_from_yaml_without_title(self): """Test if we can call a notify from a script.""" self._setup_notify() - conf = { + step = { "service": "notify.notify", "data": { "data": { @@ -130,8 +130,8 @@ class TestNotifyDemo(unittest.TestCase): }, "data_template": {"message": "Test 123 {{ 2 + 2 }}\n"}, } - - script.call_from_config(self.hass, conf) + setup_component(self.hass, "script", {"script": {"test": {"sequence": step}}}) + self.hass.services.call("script", "test") self.hass.block_till_done() assert len(self.events) == 1 assert { @@ -144,7 +144,7 @@ class TestNotifyDemo(unittest.TestCase): def test_calling_notify_from_script_loaded_from_yaml_with_title(self): """Test if we can call a notify from a script.""" self._setup_notify() - conf = { + step = { "service": "notify.notify", "data": { "data": { @@ -153,8 +153,8 @@ class TestNotifyDemo(unittest.TestCase): }, "data_template": {"message": "Test 123 {{ 2 + 2 }}\n", "title": "Test"}, } - - script.call_from_config(self.hass, conf) + setup_component(self.hass, "script", {"script": {"test": {"sequence": step}}}) + self.hass.services.call("script", "test") self.hass.block_till_done() assert len(self.events) == 1 assert { diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 5e748e3adfe..443b131b2aa 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -1,11 +1,11 @@ """The tests for the Script component.""" # pylint: disable=protected-access +import asyncio from datetime import timedelta -import functools as ft +import logging from unittest import mock import asynctest -import jinja2 import pytest import voluptuous as vol @@ -21,80 +21,94 @@ from tests.common import async_fire_time_changed ENTITY_ID = "script.test" +_ALL_RUN_MODES = [None, "background", "blocking"] -async def test_firing_event(hass): + +async def test_firing_event_basic(hass): """Test the firing of events.""" event = "test_event" context = Context() - calls = [] @callback def record_event(event): """Add recorded event to set.""" - calls.append(event) + events.append(event) hass.bus.async_listen(event, record_event) - script_obj = script.Script( - hass, cv.SCRIPT_SCHEMA({"event": event, "event_data": {"hello": "world"}}) - ) + schema = cv.SCRIPT_SCHEMA({"event": event, "event_data": {"hello": "world"}}) - await script_obj.async_run(context=context) + # For this one test we'll make sure "legacy" works the same as None. + for run_mode in _ALL_RUN_MODES + ["legacy"]: + events = [] - await hass.async_block_till_done() + if run_mode is None: + script_obj = script.Script(hass, schema) + else: + script_obj = script.Script(hass, schema, run_mode=run_mode) - assert len(calls) == 1 - assert calls[0].context is context - assert calls[0].data.get("hello") == "world" - assert not script_obj.can_cancel + assert not script_obj.can_cancel + + await script_obj.async_run(context=context) + + await hass.async_block_till_done() + + assert len(events) == 1 + assert events[0].context is context + assert events[0].data.get("hello") == "world" + assert not script_obj.can_cancel async def test_firing_event_template(hass): """Test the firing of events.""" event = "test_event" context = Context() - calls = [] @callback def record_event(event): """Add recorded event to set.""" - calls.append(event) + events.append(event) hass.bus.async_listen(event, record_event) - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - { - "event": event, - "event_data_template": { - "dict": { - 1: "{{ is_world }}", - 2: "{{ is_world }}{{ is_world }}", - 3: "{{ is_world }}{{ is_world }}{{ is_world }}", - }, - "list": ["{{ is_world }}", "{{ is_world }}{{ is_world }}"], + schema = cv.SCRIPT_SCHEMA( + { + "event": event, + "event_data_template": { + "dict": { + 1: "{{ is_world }}", + 2: "{{ is_world }}{{ is_world }}", + 3: "{{ is_world }}{{ is_world }}{{ is_world }}", }, - } - ), + "list": ["{{ is_world }}", "{{ is_world }}{{ is_world }}"], + }, + } ) - await script_obj.async_run({"is_world": "yes"}, context=context) + for run_mode in _ALL_RUN_MODES: + events = [] - await hass.async_block_till_done() + if run_mode is None: + script_obj = script.Script(hass, schema) + else: + script_obj = script.Script(hass, schema, run_mode=run_mode) - assert len(calls) == 1 - assert calls[0].context is context - assert calls[0].data == { - "dict": {1: "yes", 2: "yesyes", 3: "yesyesyes"}, - "list": ["yes", "yesyes"], - } - assert not script_obj.can_cancel + assert not script_obj.can_cancel + + await script_obj.async_run({"is_world": "yes"}, context=context) + + await hass.async_block_till_done() + + assert len(events) == 1 + assert events[0].context is context + assert events[0].data == { + "dict": {1: "yes", 2: "yesyes", 3: "yesyesyes"}, + "list": ["yes", "yesyes"], + } -async def test_calling_service(hass): +async def test_calling_service_basic(hass): """Test the calling of a service.""" - calls = [] context = Context() @callback @@ -104,25 +118,76 @@ async def test_calling_service(hass): hass.services.async_register("test", "script", record_call) - hass.async_add_job( - ft.partial( - script.call_from_config, - hass, - {"service": "test.script", "data": {"hello": "world"}}, - context=context, - ) - ) + schema = cv.SCRIPT_SCHEMA({"service": "test.script", "data": {"hello": "world"}}) - await hass.async_block_till_done() + for run_mode in _ALL_RUN_MODES: + calls = [] - assert len(calls) == 1 - assert calls[0].context is context - assert calls[0].data.get("hello") == "world" + if run_mode is None: + script_obj = script.Script(hass, schema) + else: + script_obj = script.Script(hass, schema, run_mode=run_mode) + + assert not script_obj.can_cancel + + await script_obj.async_run(context=context) + + await hass.async_block_till_done() + + assert len(calls) == 1 + assert calls[0].context is context + assert calls[0].data.get("hello") == "world" + + +async def test_cancel_no_wait(hass, caplog): + """Test stopping script.""" + event = "test_event" + + async def async_simulate_long_service(service): + """Simulate a service that takes a not insignificant time.""" + await asyncio.sleep(0.01) + + hass.services.async_register("test", "script", async_simulate_long_service) + + @callback + def monitor_event(event): + """Signal event happened.""" + event_sem.release() + + hass.bus.async_listen(event, monitor_event) + + schema = cv.SCRIPT_SCHEMA([{"event": event}, {"service": "test.script"}]) + + for run_mode in _ALL_RUN_MODES: + event_sem = asyncio.Semaphore(0) + + if run_mode is None: + script_obj = script.Script(hass, schema) + else: + script_obj = script.Script(hass, schema, run_mode=run_mode) + + tasks = [] + for _ in range(3): + if run_mode == "background": + await script_obj.async_run() + else: + hass.async_create_task(script_obj.async_run()) + tasks.append(hass.async_create_task(event_sem.acquire())) + await asyncio.wait_for(asyncio.gather(*tasks), 1) + + # Can't assert just yet because we haven't verified stopping works yet. + # If assert fails we can hang test if async_stop doesn't work. + script_was_runing = script_obj.is_running + + await script_obj.async_stop() + await hass.async_block_till_done() + + assert script_was_runing + assert not script_obj.is_running async def test_activating_scene(hass): """Test the activation of a scene.""" - calls = [] context = Context() @callback @@ -132,22 +197,29 @@ async def test_activating_scene(hass): hass.services.async_register(scene.DOMAIN, SERVICE_TURN_ON, record_call) - hass.async_add_job( - ft.partial( - script.call_from_config, hass, {"scene": "scene.hello"}, context=context - ) - ) + schema = cv.SCRIPT_SCHEMA({"scene": "scene.hello"}) - await hass.async_block_till_done() + for run_mode in _ALL_RUN_MODES: + calls = [] - assert len(calls) == 1 - assert calls[0].context is context - assert calls[0].data.get(ATTR_ENTITY_ID) == "scene.hello" + if run_mode is None: + script_obj = script.Script(hass, schema) + else: + script_obj = script.Script(hass, schema, run_mode=run_mode) + + assert not script_obj.can_cancel + + await script_obj.async_run(context=context) + + await hass.async_block_till_done() + + assert len(calls) == 1 + assert calls[0].context is context + assert calls[0].data.get(ATTR_ENTITY_ID) == "scene.hello" async def test_calling_service_template(hass): """Test the calling of a service.""" - calls = [] context = Context() @callback @@ -157,45 +229,179 @@ async def test_calling_service_template(hass): hass.services.async_register("test", "script", record_call) - hass.async_add_job( - ft.partial( - script.call_from_config, - hass, - { - "service_template": """ - {% if True %} - test.script + schema = cv.SCRIPT_SCHEMA( + { + "service_template": """ + {% if True %} + test.script + {% else %} + test.not_script + {% endif %}""", + "data_template": { + "hello": """ + {% if is_world == 'yes' %} + world {% else %} - test.not_script - {% endif %}""", - "data_template": { - "hello": """ - {% if is_world == 'yes' %} - world - {% else %} - not world - {% endif %} - """ - }, + not world + {% endif %} + """ }, - {"is_world": "yes"}, - context=context, + } + ) + + for run_mode in _ALL_RUN_MODES: + calls = [] + + if run_mode is None: + script_obj = script.Script(hass, schema) + else: + script_obj = script.Script(hass, schema, run_mode=run_mode) + + assert not script_obj.can_cancel + + await script_obj.async_run({"is_world": "yes"}, context=context) + + await hass.async_block_till_done() + + assert len(calls) == 1 + assert calls[0].context is context + assert calls[0].data.get("hello") == "world" + + +async def test_multiple_runs_no_wait(hass): + """Test multiple runs with no wait in script.""" + logger = logging.getLogger("TEST") + + async def async_simulate_long_service(service): + """Simulate a service that takes a not insignificant time.""" + + @callback + def service_done_cb(event): + logger.debug("simulated service (%s:%s) done", fire, listen) + service_done.set() + + calls.append(service) + + fire = service.data.get("fire") + listen = service.data.get("listen") + logger.debug("simulated service (%s:%s) started", fire, listen) + + service_done = asyncio.Event() + unsub = hass.bus.async_listen(listen, service_done_cb) + + hass.bus.async_fire(fire) + + await service_done.wait() + unsub() + + hass.services.async_register("test", "script", async_simulate_long_service) + + heard_event = asyncio.Event() + + @callback + def heard_event_cb(event): + logger.debug("heard: %s", event) + heard_event.set() + + schema = cv.SCRIPT_SCHEMA( + [ + { + "service": "test.script", + "data_template": {"fire": "{{ fire1 }}", "listen": "{{ listen1 }}"}, + }, + { + "service": "test.script", + "data_template": {"fire": "{{ fire2 }}", "listen": "{{ listen2 }}"}, + }, + ] + ) + + for run_mode in _ALL_RUN_MODES: + calls = [] + heard_event.clear() + + if run_mode is None: + script_obj = script.Script(hass, schema) + else: + script_obj = script.Script(hass, schema, run_mode=run_mode) + + # Start script twice in such a way that second run will be started while first + # run is in the middle of the first service call. + + unsub = hass.bus.async_listen("1", heard_event_cb) + + logger.debug("starting 1st script") + coro = script_obj.async_run( + {"fire1": "1", "listen1": "2", "fire2": "3", "listen2": "4"} ) - ) + if run_mode == "background": + await coro + else: + hass.async_create_task(coro) + await asyncio.wait_for(heard_event.wait(), 1) - await hass.async_block_till_done() + unsub() - assert len(calls) == 1 - assert calls[0].context is context - assert calls[0].data.get("hello") == "world" + logger.debug("starting 2nd script") + await script_obj.async_run( + {"fire1": "2", "listen1": "3", "fire2": "4", "listen2": "4"} + ) + + await hass.async_block_till_done() + + assert len(calls) == 4 -async def test_delay(hass): +async def test_delay_basic(hass): """Test the delay.""" - event = "test_event" - events = [] - context = Context() delay_alias = "delay step" + delay_started_flag = asyncio.Event() + + @callback + def delay_started_cb(): + delay_started_flag.set() + + delay = timedelta(milliseconds=10) + schema = cv.SCRIPT_SCHEMA({"delay": delay, "alias": delay_alias}) + + for run_mode in _ALL_RUN_MODES: + delay_started_flag.clear() + + if run_mode is None: + script_obj = script.Script(hass, schema, change_listener=delay_started_cb) + else: + script_obj = script.Script( + hass, schema, change_listener=delay_started_cb, run_mode=run_mode + ) + + assert script_obj.can_cancel + + try: + if run_mode == "background": + await script_obj.async_run() + else: + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(delay_started_flag.wait(), 1) + + assert script_obj.is_running + assert script_obj.last_action == delay_alias + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + if run_mode in (None, "legacy"): + future = dt_util.utcnow() + delay + async_fire_time_changed(hass, future) + await hass.async_block_till_done() + + assert not script_obj.is_running + assert script_obj.last_action is None + + +async def test_multiple_runs_delay(hass): + """Test multiple runs with delay in script.""" + event = "test_event" + delay_started_flag = asyncio.Event() @callback def record_event(event): @@ -204,79 +410,105 @@ async def test_delay(hass): hass.bus.async_listen(event, record_event) - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event}, - {"delay": {"seconds": 5}, "alias": delay_alias}, - {"event": event}, - ] - ), + @callback + def delay_started_cb(): + delay_started_flag.set() + + delay = timedelta(milliseconds=10) + schema = cv.SCRIPT_SCHEMA( + [ + {"event": event, "event_data": {"value": 1}}, + {"delay": delay}, + {"event": event, "event_data": {"value": 2}}, + ] ) - await script_obj.async_run(context=context) - await hass.async_block_till_done() + for run_mode in _ALL_RUN_MODES: + events = [] + delay_started_flag.clear() - assert script_obj.is_running - assert script_obj.can_cancel - assert script_obj.last_action == delay_alias - assert len(events) == 1 + if run_mode is None: + script_obj = script.Script(hass, schema, change_listener=delay_started_cb) + else: + script_obj = script.Script( + hass, schema, change_listener=delay_started_cb, run_mode=run_mode + ) - future = dt_util.utcnow() + timedelta(seconds=5) - async_fire_time_changed(hass, future) - await hass.async_block_till_done() + try: + if run_mode == "background": + await script_obj.async_run() + else: + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(delay_started_flag.wait(), 1) - assert not script_obj.is_running - assert len(events) == 2 - assert events[0].context is context - assert events[1].context is context + assert script_obj.is_running + assert len(events) == 1 + assert events[-1].data["value"] == 1 + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + # Start second run of script while first run is in a delay. + await script_obj.async_run() + if run_mode in (None, "legacy"): + future = dt_util.utcnow() + delay + async_fire_time_changed(hass, future) + await hass.async_block_till_done() + + assert not script_obj.is_running + if run_mode in (None, "legacy"): + assert len(events) == 2 + else: + assert len(events) == 4 + assert events[-3].data["value"] == 1 + assert events[-2].data["value"] == 2 + assert events[-1].data["value"] == 2 -async def test_delay_template(hass): +async def test_delay_template_ok(hass): """Test the delay as a template.""" - event = "test_event" - events = [] - delay_alias = "delay step" + delay_started_flag = asyncio.Event() @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) + def delay_started_cb(): + delay_started_flag.set() - hass.bus.async_listen(event, record_event) + schema = cv.SCRIPT_SCHEMA({"delay": "00:00:{{ 1 }}"}) - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event}, - {"delay": "00:00:{{ 5 }}", "alias": delay_alias}, - {"event": event}, - ] - ), - ) + for run_mode in _ALL_RUN_MODES: + delay_started_flag.clear() - await script_obj.async_run() - await hass.async_block_till_done() + if run_mode is None: + script_obj = script.Script(hass, schema, change_listener=delay_started_cb) + else: + script_obj = script.Script( + hass, schema, change_listener=delay_started_cb, run_mode=run_mode + ) - assert script_obj.is_running - assert script_obj.can_cancel - assert script_obj.last_action == delay_alias - assert len(events) == 1 + assert script_obj.can_cancel - future = dt_util.utcnow() + timedelta(seconds=5) - async_fire_time_changed(hass, future) - await hass.async_block_till_done() + try: + if run_mode == "background": + await script_obj.async_run() + else: + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(delay_started_flag.wait(), 1) + assert script_obj.is_running + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + if run_mode in (None, "legacy"): + future = dt_util.utcnow() + timedelta(seconds=1) + async_fire_time_changed(hass, future) + await hass.async_block_till_done() - assert not script_obj.is_running - assert len(events) == 2 + assert not script_obj.is_running -async def test_delay_invalid_template(hass): +async def test_delay_template_invalid(hass, caplog): """Test the delay as a template that fails.""" event = "test_event" - events = [] @callback def record_event(event): @@ -285,71 +517,82 @@ async def test_delay_invalid_template(hass): hass.bus.async_listen(event, record_event) - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event}, - {"delay": "{{ invalid_delay }}"}, - {"delay": {"seconds": 5}}, - {"event": event}, - ] - ), + schema = cv.SCRIPT_SCHEMA( + [ + {"event": event}, + {"delay": "{{ invalid_delay }}"}, + {"delay": {"seconds": 5}}, + {"event": event}, + ] ) - with mock.patch.object(script, "_LOGGER") as mock_logger: + for run_mode in _ALL_RUN_MODES: + events = [] + + if run_mode is None: + script_obj = script.Script(hass, schema) + else: + script_obj = script.Script(hass, schema, run_mode=run_mode) + start_idx = len(caplog.records) + await script_obj.async_run() await hass.async_block_till_done() - assert mock_logger.error.called - assert not script_obj.is_running - assert len(events) == 1 + assert any( + rec.levelname == "ERROR" and "Error rendering" in rec.message + for rec in caplog.records[start_idx:] + ) + + assert not script_obj.is_running + assert len(events) == 1 -async def test_delay_complex_template(hass): +async def test_delay_template_complex_ok(hass): """Test the delay with a working complex template.""" - event = "test_event" - events = [] - delay_alias = "delay step" + delay_started_flag = asyncio.Event() @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) + def delay_started_cb(): + delay_started_flag.set() - hass.bus.async_listen(event, record_event) + milliseconds = 10 + schema = cv.SCRIPT_SCHEMA({"delay": {"milliseconds": "{{ milliseconds }}"}}) - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event}, - {"delay": {"seconds": "{{ 5 }}"}, "alias": delay_alias}, - {"event": event}, - ] - ), - ) + for run_mode in _ALL_RUN_MODES: + delay_started_flag.clear() - await script_obj.async_run() - await hass.async_block_till_done() + if run_mode is None: + script_obj = script.Script(hass, schema, change_listener=delay_started_cb) + else: + script_obj = script.Script( + hass, schema, change_listener=delay_started_cb, run_mode=run_mode + ) - assert script_obj.is_running - assert script_obj.can_cancel - assert script_obj.last_action == delay_alias - assert len(events) == 1 + assert script_obj.can_cancel - future = dt_util.utcnow() + timedelta(seconds=5) - async_fire_time_changed(hass, future) - await hass.async_block_till_done() + try: + coro = script_obj.async_run({"milliseconds": milliseconds}) + if run_mode == "background": + await coro + else: + hass.async_create_task(coro) + await asyncio.wait_for(delay_started_flag.wait(), 1) + assert script_obj.is_running + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + if run_mode in (None, "legacy"): + future = dt_util.utcnow() + timedelta(milliseconds=milliseconds) + async_fire_time_changed(hass, future) + await hass.async_block_till_done() - assert not script_obj.is_running - assert len(events) == 2 + assert not script_obj.is_running -async def test_delay_complex_invalid_template(hass): +async def test_delay_template_complex_invalid(hass, caplog): """Test the delay with a complex template that fails.""" event = "test_event" - events = [] @callback def record_event(event): @@ -358,31 +601,44 @@ async def test_delay_complex_invalid_template(hass): hass.bus.async_listen(event, record_event) - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event}, - {"delay": {"seconds": "{{ invalid_delay }}"}}, - {"delay": {"seconds": "{{ 5 }}"}}, - {"event": event}, - ] - ), + schema = cv.SCRIPT_SCHEMA( + [ + {"event": event}, + {"delay": {"seconds": "{{ invalid_delay }}"}}, + {"delay": {"seconds": 5}}, + {"event": event}, + ] ) - with mock.patch.object(script, "_LOGGER") as mock_logger: + for run_mode in _ALL_RUN_MODES: + events = [] + + if run_mode is None: + script_obj = script.Script(hass, schema) + else: + script_obj = script.Script(hass, schema, run_mode=run_mode) + start_idx = len(caplog.records) + await script_obj.async_run() await hass.async_block_till_done() - assert mock_logger.error.called - assert not script_obj.is_running - assert len(events) == 1 + assert any( + rec.levelname == "ERROR" and "Error rendering" in rec.message + for rec in caplog.records[start_idx:] + ) + + assert not script_obj.is_running + assert len(events) == 1 -async def test_cancel_while_delay(hass): +async def test_cancel_delay(hass): """Test the cancelling while the delay is present.""" + delay_started_flag = asyncio.Event() event = "test_event" - events = [] + + @callback + def delay_started_cb(): + delay_started_flag.set() @callback def record_event(event): @@ -391,35 +647,101 @@ async def test_cancel_while_delay(hass): hass.bus.async_listen(event, record_event) - script_obj = script.Script( - hass, cv.SCRIPT_SCHEMA([{"delay": {"seconds": 5}}, {"event": event}]) - ) + delay = timedelta(milliseconds=10) + schema = cv.SCRIPT_SCHEMA([{"delay": delay}, {"event": event}]) - await script_obj.async_run() - await hass.async_block_till_done() + for run_mode in _ALL_RUN_MODES: + delay_started_flag.clear() + events = [] - assert script_obj.is_running - assert len(events) == 0 + if run_mode is None: + script_obj = script.Script(hass, schema, change_listener=delay_started_cb) + else: + script_obj = script.Script( + hass, schema, change_listener=delay_started_cb, run_mode=run_mode + ) - script_obj.async_stop() + try: + if run_mode == "background": + await script_obj.async_run() + else: + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(delay_started_flag.wait(), 1) - assert not script_obj.is_running + assert script_obj.is_running + assert len(events) == 0 + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + await script_obj.async_stop() - # Make sure the script is really stopped. - future = dt_util.utcnow() + timedelta(seconds=5) - async_fire_time_changed(hass, future) - await hass.async_block_till_done() + assert not script_obj.is_running - assert not script_obj.is_running - assert len(events) == 0 + # Make sure the script is really stopped. + + if run_mode in (None, "legacy"): + future = dt_util.utcnow() + delay + async_fire_time_changed(hass, future) + await hass.async_block_till_done() + + assert not script_obj.is_running + assert len(events) == 0 -async def test_wait_template(hass): +async def test_wait_template_basic(hass): """Test the wait template.""" - event = "test_event" - events = [] - context = Context() wait_alias = "wait step" + wait_started_flag = asyncio.Event() + + @callback + def wait_started_cb(): + wait_started_flag.set() + + schema = cv.SCRIPT_SCHEMA( + { + "wait_template": "{{ states.switch.test.state == 'off' }}", + "alias": wait_alias, + } + ) + + for run_mode in _ALL_RUN_MODES: + wait_started_flag.clear() + hass.states.async_set("switch.test", "on") + + if run_mode is None: + script_obj = script.Script(hass, schema, change_listener=wait_started_cb) + else: + script_obj = script.Script( + hass, schema, change_listener=wait_started_cb, run_mode=run_mode + ) + + assert script_obj.can_cancel + + try: + if run_mode == "background": + await script_obj.async_run() + else: + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(wait_started_flag.wait(), 1) + + assert script_obj.is_running + assert script_obj.last_action == wait_alias + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + hass.states.async_set("switch.test", "off") + await hass.async_block_till_done() + + assert not script_obj.is_running + assert script_obj.last_action is None + + +async def test_multiple_runs_wait_template(hass): + """Test multiple runs with wait_template in script.""" + event = "test_event" + wait_started_flag = asyncio.Event() @callback def record_event(event): @@ -428,44 +750,70 @@ async def test_wait_template(hass): hass.bus.async_listen(event, record_event) - hass.states.async_set("switch.test", "on") + @callback + def wait_started_cb(): + wait_started_flag.set() - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event}, - { - "wait_template": "{{states.switch.test.state == 'off'}}", - "alias": wait_alias, - }, - {"event": event}, - ] - ), + schema = cv.SCRIPT_SCHEMA( + [ + {"event": event, "event_data": {"value": 1}}, + {"wait_template": "{{ states.switch.test.state == 'off' }}"}, + {"event": event, "event_data": {"value": 2}}, + ] ) - await script_obj.async_run(context=context) - await hass.async_block_till_done() + for run_mode in _ALL_RUN_MODES: + events = [] + wait_started_flag.clear() + hass.states.async_set("switch.test", "on") - assert script_obj.is_running - assert script_obj.can_cancel - assert script_obj.last_action == wait_alias - assert len(events) == 1 + if run_mode is None: + script_obj = script.Script(hass, schema, change_listener=wait_started_cb) + else: + script_obj = script.Script( + hass, schema, change_listener=wait_started_cb, run_mode=run_mode + ) - hass.states.async_set("switch.test", "off") - await hass.async_block_till_done() + try: + if run_mode == "background": + await script_obj.async_run() + else: + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(wait_started_flag.wait(), 1) - assert not script_obj.is_running - assert len(events) == 2 - assert events[0].context is context - assert events[1].context is context + assert script_obj.is_running + assert len(events) == 1 + assert events[-1].data["value"] == 1 + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + # Start second run of script while first run is in wait_template. + if run_mode == "blocking": + hass.async_create_task(script_obj.async_run()) + else: + await script_obj.async_run() + hass.states.async_set("switch.test", "off") + await hass.async_block_till_done() + + assert not script_obj.is_running + if run_mode in (None, "legacy"): + assert len(events) == 2 + else: + assert len(events) == 4 + assert events[-3].data["value"] == 1 + assert events[-2].data["value"] == 2 + assert events[-1].data["value"] == 2 -async def test_wait_template_cancel(hass): - """Test the wait template cancel action.""" +async def test_cancel_wait_template(hass): + """Test the cancelling while wait_template is present.""" + wait_started_flag = asyncio.Event() event = "test_event" - events = [] - wait_alias = "wait step" + + @callback + def wait_started_cb(): + wait_started_flag.set() @callback def record_event(event): @@ -474,46 +822,54 @@ async def test_wait_template_cancel(hass): hass.bus.async_listen(event, record_event) - hass.states.async_set("switch.test", "on") - - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event}, - { - "wait_template": "{{states.switch.test.state == 'off'}}", - "alias": wait_alias, - }, - {"event": event}, - ] - ), + schema = cv.SCRIPT_SCHEMA( + [ + {"wait_template": "{{ states.switch.test.state == 'off' }}"}, + {"event": event}, + ] ) - await script_obj.async_run() - await hass.async_block_till_done() + for run_mode in _ALL_RUN_MODES: + wait_started_flag.clear() + events = [] + hass.states.async_set("switch.test", "on") - assert script_obj.is_running - assert script_obj.can_cancel - assert script_obj.last_action == wait_alias - assert len(events) == 1 + if run_mode is None: + script_obj = script.Script(hass, schema, change_listener=wait_started_cb) + else: + script_obj = script.Script( + hass, schema, change_listener=wait_started_cb, run_mode=run_mode + ) - script_obj.async_stop() + try: + if run_mode == "background": + await script_obj.async_run() + else: + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(wait_started_flag.wait(), 1) - assert not script_obj.is_running - assert len(events) == 1 + assert script_obj.is_running + assert len(events) == 0 + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + await script_obj.async_stop() - hass.states.async_set("switch.test", "off") - await hass.async_block_till_done() + assert not script_obj.is_running - assert not script_obj.is_running - assert len(events) == 1 + # Make sure the script is really stopped. + + hass.states.async_set("switch.test", "off") + await hass.async_block_till_done() + + assert not script_obj.is_running + assert len(events) == 0 async def test_wait_template_not_schedule(hass): """Test the wait template with correct condition.""" event = "test_event" - events = [] @callback def record_event(event): @@ -524,30 +880,33 @@ async def test_wait_template_not_schedule(hass): hass.states.async_set("switch.test", "on") - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event}, - {"wait_template": "{{states.switch.test.state == 'on'}}"}, - {"event": event}, - ] - ), + schema = cv.SCRIPT_SCHEMA( + [ + {"event": event}, + {"wait_template": "{{ states.switch.test.state == 'on' }}"}, + {"event": event}, + ] ) - await script_obj.async_run() - await hass.async_block_till_done() + for run_mode in _ALL_RUN_MODES: + events = [] - assert not script_obj.is_running - assert script_obj.can_cancel - assert len(events) == 2 + if run_mode is None: + script_obj = script.Script(hass, schema) + else: + script_obj = script.Script(hass, schema, run_mode=run_mode) + + await script_obj.async_run() + await hass.async_block_till_done() + + assert not script_obj.is_running + assert len(events) == 2 async def test_wait_template_timeout_halt(hass): """Test the wait template, halt on timeout.""" event = "test_event" - events = [] - wait_alias = "wait step" + wait_started_flag = asyncio.Event() @callback def record_event(event): @@ -556,45 +915,61 @@ async def test_wait_template_timeout_halt(hass): hass.bus.async_listen(event, record_event) + @callback + def wait_started_cb(): + wait_started_flag.set() + hass.states.async_set("switch.test", "on") - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event}, - { - "wait_template": "{{states.switch.test.state == 'off'}}", - "continue_on_timeout": False, - "timeout": 5, - "alias": wait_alias, - }, - {"event": event}, - ] - ), + timeout = timedelta(milliseconds=10) + schema = cv.SCRIPT_SCHEMA( + [ + { + "wait_template": "{{ states.switch.test.state == 'off' }}", + "continue_on_timeout": False, + "timeout": timeout, + }, + {"event": event}, + ] ) - await script_obj.async_run() - await hass.async_block_till_done() + for run_mode in _ALL_RUN_MODES: + events = [] + wait_started_flag.clear() - assert script_obj.is_running - assert script_obj.can_cancel - assert script_obj.last_action == wait_alias - assert len(events) == 1 + if run_mode is None: + script_obj = script.Script(hass, schema, change_listener=wait_started_cb) + else: + script_obj = script.Script( + hass, schema, change_listener=wait_started_cb, run_mode=run_mode + ) - future = dt_util.utcnow() + timedelta(seconds=5) - async_fire_time_changed(hass, future) - await hass.async_block_till_done() + try: + if run_mode == "background": + await script_obj.async_run() + else: + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(wait_started_flag.wait(), 1) - assert not script_obj.is_running - assert len(events) == 1 + assert script_obj.is_running + assert len(events) == 0 + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + if run_mode in (None, "legacy"): + future = dt_util.utcnow() + timeout + async_fire_time_changed(hass, future) + await hass.async_block_till_done() + + assert not script_obj.is_running + assert len(events) == 0 async def test_wait_template_timeout_continue(hass): """Test the wait template with continuing the script.""" event = "test_event" - events = [] - wait_alias = "wait step" + wait_started_flag = asyncio.Event() @callback def record_event(event): @@ -603,45 +978,61 @@ async def test_wait_template_timeout_continue(hass): hass.bus.async_listen(event, record_event) + @callback + def wait_started_cb(): + wait_started_flag.set() + hass.states.async_set("switch.test", "on") - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event}, - { - "wait_template": "{{states.switch.test.state == 'off'}}", - "timeout": 5, - "continue_on_timeout": True, - "alias": wait_alias, - }, - {"event": event}, - ] - ), + timeout = timedelta(milliseconds=10) + schema = cv.SCRIPT_SCHEMA( + [ + { + "wait_template": "{{ states.switch.test.state == 'off' }}", + "continue_on_timeout": True, + "timeout": timeout, + }, + {"event": event}, + ] ) - await script_obj.async_run() - await hass.async_block_till_done() + for run_mode in _ALL_RUN_MODES: + events = [] + wait_started_flag.clear() - assert script_obj.is_running - assert script_obj.can_cancel - assert script_obj.last_action == wait_alias - assert len(events) == 1 + if run_mode is None: + script_obj = script.Script(hass, schema, change_listener=wait_started_cb) + else: + script_obj = script.Script( + hass, schema, change_listener=wait_started_cb, run_mode=run_mode + ) - future = dt_util.utcnow() + timedelta(seconds=5) - async_fire_time_changed(hass, future) - await hass.async_block_till_done() + try: + if run_mode == "background": + await script_obj.async_run() + else: + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(wait_started_flag.wait(), 1) - assert not script_obj.is_running - assert len(events) == 2 + assert script_obj.is_running + assert len(events) == 0 + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + if run_mode in (None, "legacy"): + future = dt_util.utcnow() + timeout + async_fire_time_changed(hass, future) + await hass.async_block_till_done() + + assert not script_obj.is_running + assert len(events) == 1 async def test_wait_template_timeout_default(hass): - """Test the wait template with default contiune.""" + """Test the wait template with default continue.""" event = "test_event" - events = [] - wait_alias = "wait step" + wait_started_flag = asyncio.Event() @callback def record_event(event): @@ -650,128 +1041,99 @@ async def test_wait_template_timeout_default(hass): hass.bus.async_listen(event, record_event) + @callback + def wait_started_cb(): + wait_started_flag.set() + hass.states.async_set("switch.test", "on") - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event}, - { - "wait_template": "{{states.switch.test.state == 'off'}}", - "timeout": 5, - "alias": wait_alias, - }, - {"event": event}, - ] - ), + timeout = timedelta(milliseconds=10) + schema = cv.SCRIPT_SCHEMA( + [ + { + "wait_template": "{{ states.switch.test.state == 'off' }}", + "timeout": timeout, + }, + {"event": event}, + ] ) - await script_obj.async_run() - await hass.async_block_till_done() + for run_mode in _ALL_RUN_MODES: + events = [] + wait_started_flag.clear() - assert script_obj.is_running - assert script_obj.can_cancel - assert script_obj.last_action == wait_alias - assert len(events) == 1 + if run_mode is None: + script_obj = script.Script(hass, schema, change_listener=wait_started_cb) + else: + script_obj = script.Script( + hass, schema, change_listener=wait_started_cb, run_mode=run_mode + ) - future = dt_util.utcnow() + timedelta(seconds=5) - async_fire_time_changed(hass, future) - await hass.async_block_till_done() + try: + if run_mode == "background": + await script_obj.async_run() + else: + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(wait_started_flag.wait(), 1) - assert not script_obj.is_running - assert len(events) == 2 + assert script_obj.is_running + assert len(events) == 0 + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + if run_mode in (None, "legacy"): + future = dt_util.utcnow() + timeout + async_fire_time_changed(hass, future) + await hass.async_block_till_done() + + assert not script_obj.is_running + assert len(events) == 1 async def test_wait_template_variables(hass): """Test the wait template with variables.""" - event = "test_event" - events = [] - wait_alias = "wait step" + wait_started_flag = asyncio.Event() @callback - def record_event(event): - """Add recorded event to set.""" - events.append(event) + def wait_started_cb(): + wait_started_flag.set() - hass.bus.async_listen(event, record_event) + schema = cv.SCRIPT_SCHEMA({"wait_template": "{{ is_state(data, 'off') }}"}) - hass.states.async_set("switch.test", "on") + for run_mode in _ALL_RUN_MODES: + wait_started_flag.clear() + hass.states.async_set("switch.test", "on") - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event}, - {"wait_template": "{{is_state(data, 'off')}}", "alias": wait_alias}, - {"event": event}, - ] - ), - ) + if run_mode is None: + script_obj = script.Script(hass, schema, change_listener=wait_started_cb) + else: + script_obj = script.Script( + hass, schema, change_listener=wait_started_cb, run_mode=run_mode + ) - await script_obj.async_run({"data": "switch.test"}) - await hass.async_block_till_done() + assert script_obj.can_cancel - assert script_obj.is_running - assert script_obj.can_cancel - assert script_obj.last_action == wait_alias - assert len(events) == 1 + try: + coro = script_obj.async_run({"data": "switch.test"}) + if run_mode == "background": + await coro + else: + hass.async_create_task(coro) + await asyncio.wait_for(wait_started_flag.wait(), 1) - hass.states.async_set("switch.test", "off") - await hass.async_block_till_done() + assert script_obj.is_running + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + hass.states.async_set("switch.test", "off") + await hass.async_block_till_done() - assert not script_obj.is_running - assert len(events) == 2 + assert not script_obj.is_running -async def test_passing_variables_to_script(hass): - """Test if we can pass variables to script.""" - calls = [] - - @callback - def record_call(service): - """Add recorded event to set.""" - calls.append(service) - - hass.services.async_register("test", "script", record_call) - - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - { - "service": "test.script", - "data_template": {"hello": "{{ greeting }}"}, - }, - {"delay": "{{ delay_period }}"}, - { - "service": "test.script", - "data_template": {"hello": "{{ greeting2 }}"}, - }, - ] - ), - ) - - await script_obj.async_run( - {"greeting": "world", "greeting2": "universe", "delay_period": "00:00:05"} - ) - - await hass.async_block_till_done() - - assert script_obj.is_running - assert len(calls) == 1 - assert calls[-1].data["hello"] == "world" - - future = dt_util.utcnow() + timedelta(seconds=5) - async_fire_time_changed(hass, future) - await hass.async_block_till_done() - - assert not script_obj.is_running - assert len(calls) == 2 - assert calls[-1].data["hello"] == "universe" - - -async def test_condition(hass): +async def test_condition_basic(hass): """Test if we can use conditions in a script.""" event = "test_event" events = [] @@ -783,31 +1145,39 @@ async def test_condition(hass): hass.bus.async_listen(event, record_event) - hass.states.async_set("test.entity", "hello") - - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [ - {"event": event}, - { - "condition": "template", - "value_template": '{{ states.test.entity.state == "hello" }}', - }, - {"event": event}, - ] - ), + schema = cv.SCRIPT_SCHEMA( + [ + {"event": event}, + { + "condition": "template", + "value_template": "{{ states.test.entity.state == 'hello' }}", + }, + {"event": event}, + ] ) - await script_obj.async_run() - await hass.async_block_till_done() - assert len(events) == 2 + for run_mode in _ALL_RUN_MODES: + events = [] + hass.states.async_set("test.entity", "hello") - hass.states.async_set("test.entity", "goodbye") + if run_mode is None: + script_obj = script.Script(hass, schema) + else: + script_obj = script.Script(hass, schema, run_mode=run_mode) - await script_obj.async_run() - await hass.async_block_till_done() - assert len(events) == 3 + assert not script_obj.can_cancel + + await script_obj.async_run() + await hass.async_block_till_done() + + assert len(events) == 2 + + hass.states.async_set("test.entity", "goodbye") + + await script_obj.async_run() + await hass.async_block_till_done() + + assert len(events) == 3 @asynctest.patch("homeassistant.helpers.script.condition.async_from_config") @@ -846,7 +1216,7 @@ async def test_condition_created_once(async_from_config, hass): assert len(script_obj._config_cache) == 1 -async def test_all_conditions_cached(hass): +async def test_condition_all_cached(hass): """Test that multiple conditions get cached.""" event = "test_event" events = [] @@ -887,55 +1257,63 @@ async def test_last_triggered(hass): """Test the last_triggered.""" event = "test_event" - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [{"event": event}, {"delay": {"seconds": 5}}, {"event": event}] - ), - ) + schema = cv.SCRIPT_SCHEMA({"event": event}) - assert script_obj.last_triggered is None + for run_mode in _ALL_RUN_MODES: + if run_mode is None: + script_obj = script.Script(hass, schema) + else: + script_obj = script.Script(hass, schema, run_mode=run_mode) - time = dt_util.utcnow() - with mock.patch("homeassistant.helpers.script.date_util.utcnow", return_value=time): - await script_obj.async_run() - await hass.async_block_till_done() + assert script_obj.last_triggered is None - assert script_obj.last_triggered == time + time = dt_util.utcnow() + with mock.patch("homeassistant.helpers.script.utcnow", return_value=time): + await script_obj.async_run() + await hass.async_block_till_done() + + assert script_obj.last_triggered == time async def test_propagate_error_service_not_found(hass): """Test that a script aborts when a service is not found.""" - events = [] + event = "test_event" @callback def record_event(event): events.append(event) - hass.bus.async_listen("test_event", record_event) + hass.bus.async_listen(event, record_event) - script_obj = script.Script( - hass, cv.SCRIPT_SCHEMA([{"service": "test.script"}, {"event": "test_event"}]) - ) + schema = cv.SCRIPT_SCHEMA([{"service": "test.script"}, {"event": event}]) - with pytest.raises(exceptions.ServiceNotFound): - await script_obj.async_run() + run_modes = _ALL_RUN_MODES + if "background" in run_modes: + run_modes.remove("background") + for run_mode in run_modes: + events = [] - assert len(events) == 0 - assert script_obj._cur == -1 + if run_mode is None: + script_obj = script.Script(hass, schema) + else: + script_obj = script.Script(hass, schema, run_mode=run_mode) + + with pytest.raises(exceptions.ServiceNotFound): + await script_obj.async_run() + + assert len(events) == 0 + assert not script_obj.is_running async def test_propagate_error_invalid_service_data(hass): """Test that a script aborts when we send invalid service data.""" - events = [] + event = "test_event" @callback def record_event(event): events.append(event) - hass.bus.async_listen("test_event", record_event) - - calls = [] + hass.bus.async_listen(event, record_event) @callback def record_call(service): @@ -946,32 +1324,39 @@ async def test_propagate_error_invalid_service_data(hass): "test", "script", record_call, schema=vol.Schema({"text": str}) ) - script_obj = script.Script( - hass, - cv.SCRIPT_SCHEMA( - [{"service": "test.script", "data": {"text": 1}}, {"event": "test_event"}] - ), + schema = cv.SCRIPT_SCHEMA( + [{"service": "test.script", "data": {"text": 1}}, {"event": event}] ) - with pytest.raises(vol.Invalid): - await script_obj.async_run() + run_modes = _ALL_RUN_MODES + if "background" in run_modes: + run_modes.remove("background") + for run_mode in run_modes: + events = [] + calls = [] - assert len(events) == 0 - assert len(calls) == 0 - assert script_obj._cur == -1 + if run_mode is None: + script_obj = script.Script(hass, schema) + else: + script_obj = script.Script(hass, schema, run_mode=run_mode) + + with pytest.raises(vol.Invalid): + await script_obj.async_run() + + assert len(events) == 0 + assert len(calls) == 0 + assert not script_obj.is_running async def test_propagate_error_service_exception(hass): """Test that a script aborts when a service throws an exception.""" - events = [] + event = "test_event" @callback def record_event(event): events.append(event) - hass.bus.async_listen("test_event", record_event) - - calls = [] + hass.bus.async_listen(event, record_event) @callback def record_call(service): @@ -980,48 +1365,24 @@ async def test_propagate_error_service_exception(hass): hass.services.async_register("test", "script", record_call) - script_obj = script.Script( - hass, cv.SCRIPT_SCHEMA([{"service": "test.script"}, {"event": "test_event"}]) - ) + schema = cv.SCRIPT_SCHEMA([{"service": "test.script"}, {"event": event}]) - with pytest.raises(ValueError): - await script_obj.async_run() + run_modes = _ALL_RUN_MODES + if "background" in run_modes: + run_modes.remove("background") + for run_mode in run_modes: + events = [] - assert len(events) == 0 - assert len(calls) == 0 - assert script_obj._cur == -1 - - -def test_log_exception(): - """Test logged output.""" - script_obj = script.Script( - None, cv.SCRIPT_SCHEMA([{"service": "test.script"}, {"event": "test_event"}]) - ) - script_obj._exception_step = 1 - - for exc, msg in ( - (vol.Invalid("Invalid number"), "Invalid data"), - ( - exceptions.TemplateError(jinja2.TemplateError("Unclosed bracket")), - "Error rendering template", - ), - (exceptions.Unauthorized(), "Unauthorized"), - (exceptions.ServiceNotFound("light", "turn_on"), "Service not found"), - (ValueError("Cannot parse JSON"), "Unknown error"), - ): - logger = mock.Mock() - script_obj.async_log_exception(logger, "Test error", exc) - - assert len(logger.mock_calls) == 1 - _, _, p_error_desc, p_action_type, p_step, p_error = logger.mock_calls[0][1] - - assert p_error_desc == msg - assert p_action_type == script.ACTION_FIRE_EVENT - assert p_step == 2 - if isinstance(exc, ValueError): - assert p_error == "" + if run_mode is None: + script_obj = script.Script(hass, schema) else: - assert p_error == str(exc) + script_obj = script.Script(hass, schema, run_mode=run_mode) + + with pytest.raises(ValueError): + await script_obj.async_run() + + assert len(events) == 0 + assert not script_obj.is_running async def test_referenced_entities(): @@ -1078,3 +1439,307 @@ async def test_referenced_devices(): assert script_obj.referenced_devices == {"script-dev-id", "condition-dev-id"} # Test we cache results. assert script_obj.referenced_devices is script_obj.referenced_devices + + +async def test_if_running_with_legacy_run_mode(hass, caplog): + """Test using if_running with run_mode='legacy'.""" + # TODO: REMOVE + if _ALL_RUN_MODES == [None]: + return + + with pytest.raises(exceptions.HomeAssistantError): + script.Script( + hass, + [], + if_running="ignore", + run_mode="legacy", + logger=logging.getLogger("TEST"), + ) + assert any( + rec.levelname == "ERROR" + and rec.name == "TEST" + and all(text in rec.message for text in ("if_running", "legacy")) + for rec in caplog.records + ) + + +async def test_if_running_ignore(hass, caplog): + """Test overlapping runs with if_running='ignore'.""" + # TODO: REMOVE + if _ALL_RUN_MODES == [None]: + return + + event = "test_event" + events = [] + wait_started_flag = asyncio.Event() + + @callback + def record_event(event): + """Add recorded event to set.""" + events.append(event) + + hass.bus.async_listen(event, record_event) + + @callback + def wait_started_cb(): + wait_started_flag.set() + + hass.states.async_set("switch.test", "on") + + script_obj = script.Script( + hass, + cv.SCRIPT_SCHEMA( + [ + {"event": event, "event_data": {"value": 1}}, + {"wait_template": "{{ states.switch.test.state == 'off' }}"}, + {"event": event, "event_data": {"value": 2}}, + ] + ), + change_listener=wait_started_cb, + if_running="ignore", + run_mode="background", + logger=logging.getLogger("TEST"), + ) + + try: + await script_obj.async_run() + await asyncio.wait_for(wait_started_flag.wait(), 1) + + assert script_obj.is_running + assert len(events) == 1 + assert events[0].data["value"] == 1 + + # Start second run of script while first run is suspended in wait_template. + # This should ignore second run. + + await script_obj.async_run() + + assert script_obj.is_running + assert any( + rec.levelname == "INFO" and rec.name == "TEST" and "Skipping" in rec.message + for rec in caplog.records + ) + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + hass.states.async_set("switch.test", "off") + await hass.async_block_till_done() + + assert not script_obj.is_running + assert len(events) == 2 + assert events[1].data["value"] == 2 + + +async def test_if_running_error(hass, caplog): + """Test overlapping runs with if_running='error'.""" + # TODO: REMOVE + if _ALL_RUN_MODES == [None]: + return + + event = "test_event" + events = [] + wait_started_flag = asyncio.Event() + + @callback + def record_event(event): + """Add recorded event to set.""" + events.append(event) + + hass.bus.async_listen(event, record_event) + + @callback + def wait_started_cb(): + wait_started_flag.set() + + hass.states.async_set("switch.test", "on") + + script_obj = script.Script( + hass, + cv.SCRIPT_SCHEMA( + [ + {"event": event, "event_data": {"value": 1}}, + {"wait_template": "{{ states.switch.test.state == 'off' }}"}, + {"event": event, "event_data": {"value": 2}}, + ] + ), + change_listener=wait_started_cb, + if_running="error", + run_mode="background", + logger=logging.getLogger("TEST"), + ) + + try: + await script_obj.async_run() + await asyncio.wait_for(wait_started_flag.wait(), 1) + + assert script_obj.is_running + assert len(events) == 1 + assert events[0].data["value"] == 1 + + # Start second run of script while first run is suspended in wait_template. + # This should cause an error. + + with pytest.raises(exceptions.HomeAssistantError): + await script_obj.async_run() + + assert script_obj.is_running + assert any( + rec.levelname == "ERROR" + and rec.name == "TEST" + and "Already running" in rec.message + for rec in caplog.records + ) + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + hass.states.async_set("switch.test", "off") + await hass.async_block_till_done() + + assert not script_obj.is_running + assert len(events) == 2 + assert events[1].data["value"] == 2 + + +async def test_if_running_restart(hass, caplog): + """Test overlapping runs with if_running='restart'.""" + # TODO: REMOVE + if _ALL_RUN_MODES == [None]: + return + + event = "test_event" + events = [] + wait_started_flag = asyncio.Event() + + @callback + def record_event(event): + """Add recorded event to set.""" + events.append(event) + + hass.bus.async_listen(event, record_event) + + @callback + def wait_started_cb(): + wait_started_flag.set() + + hass.states.async_set("switch.test", "on") + + script_obj = script.Script( + hass, + cv.SCRIPT_SCHEMA( + [ + {"event": event, "event_data": {"value": 1}}, + {"wait_template": "{{ states.switch.test.state == 'off' }}"}, + {"event": event, "event_data": {"value": 2}}, + ] + ), + change_listener=wait_started_cb, + if_running="restart", + run_mode="background", + logger=logging.getLogger("TEST"), + ) + + try: + await script_obj.async_run() + await asyncio.wait_for(wait_started_flag.wait(), 1) + + assert script_obj.is_running + assert len(events) == 1 + assert events[0].data["value"] == 1 + + # Start second run of script while first run is suspended in wait_template. + # This should stop first run then start a new run. + + wait_started_flag.clear() + await script_obj.async_run() + await asyncio.wait_for(wait_started_flag.wait(), 1) + + assert script_obj.is_running + assert len(events) == 2 + assert events[1].data["value"] == 1 + assert any( + rec.levelname == "INFO" + and rec.name == "TEST" + and "Restarting" in rec.message + for rec in caplog.records + ) + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + hass.states.async_set("switch.test", "off") + await hass.async_block_till_done() + + assert not script_obj.is_running + assert len(events) == 3 + assert events[2].data["value"] == 2 + + +async def test_if_running_parallel(hass): + """Test overlapping runs with if_running='parallel'.""" + # TODO: REMOVE + if _ALL_RUN_MODES == [None]: + return + + event = "test_event" + events = [] + wait_started_flag = asyncio.Event() + + @callback + def record_event(event): + """Add recorded event to set.""" + events.append(event) + + hass.bus.async_listen(event, record_event) + + @callback + def wait_started_cb(): + wait_started_flag.set() + + hass.states.async_set("switch.test", "on") + + script_obj = script.Script( + hass, + cv.SCRIPT_SCHEMA( + [ + {"event": event, "event_data": {"value": 1}}, + {"wait_template": "{{ states.switch.test.state == 'off' }}"}, + {"event": event, "event_data": {"value": 2}}, + ] + ), + change_listener=wait_started_cb, + if_running="parallel", + run_mode="background", + logger=logging.getLogger("TEST"), + ) + + try: + await script_obj.async_run() + await asyncio.wait_for(wait_started_flag.wait(), 1) + + assert script_obj.is_running + assert len(events) == 1 + assert events[0].data["value"] == 1 + + # Start second run of script while first run is suspended in wait_template. + # This should start a new, independent run. + + wait_started_flag.clear() + await script_obj.async_run() + await asyncio.wait_for(wait_started_flag.wait(), 1) + + assert script_obj.is_running + assert len(events) == 2 + assert events[1].data["value"] == 1 + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + hass.states.async_set("switch.test", "off") + await hass.async_block_till_done() + + assert not script_obj.is_running + assert len(events) == 4 + assert events[2].data["value"] == 2 + assert events[3].data["value"] == 2