diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index e890766cf2c..29d1acf0316 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -35,6 +35,7 @@ from homeassistant.const import ( CONF_UNTIL, CONF_WAIT_TEMPLATE, CONF_WHILE, + EVENT_HOMEASSISTANT_STOP, SERVICE_TURN_ON, ) from homeassistant.core import SERVICE_CALL_LIMIT, Context, HomeAssistant, callback @@ -43,7 +44,7 @@ from homeassistant.helpers import ( config_validation as cv, template as template, ) -from homeassistant.helpers.event import async_track_template +from homeassistant.helpers.event import async_call_later, async_track_template from homeassistant.helpers.service import ( CONF_SERVICE_DATA, async_prepare_call_from_config, @@ -73,9 +74,15 @@ ATTR_CUR = "current" ATTR_MAX = "max" ATTR_MODE = "mode" +DATA_SCRIPTS = "helpers.script" + +_LOGGER = logging.getLogger(__name__) + _LOG_EXCEPTION = logging.ERROR + 1 _TIMEOUT_MSG = "Timeout reached, abort script." +_SHUTDOWN_MAX_WAIT = 60 + def make_script_schema(schema, default_script_mode, extra=vol.PREVENT_EXTRA): """Make a schema for a component that uses the script helper.""" @@ -545,6 +552,41 @@ class _QueuedScriptRun(_ScriptRun): super()._finish() +async def _async_stop_scripts_after_shutdown(hass, point_in_time): + """Stop running Script objects started after shutdown.""" + running_scripts = [ + script for script in hass.data[DATA_SCRIPTS] if script["instance"].is_running + ] + if running_scripts: + names = ", ".join([script["instance"].name for script in running_scripts]) + _LOGGER.warning("Stopping scripts running too long after shutdown: %s", names) + await asyncio.gather( + *[ + script["instance"].async_stop(update_state=False) + for script in running_scripts + ] + ) + + +async def _async_stop_scripts_at_shutdown(hass, event): + """Stop running Script objects started before shutdown.""" + async_call_later( + hass, _SHUTDOWN_MAX_WAIT, partial(_async_stop_scripts_after_shutdown, hass) + ) + + running_scripts = [ + script + for script in hass.data[DATA_SCRIPTS] + if script["instance"].is_running and script["started_before_shutdown"] + ] + if running_scripts: + names = ", ".join([script["instance"].name for script in running_scripts]) + _LOGGER.debug("Stopping scripts running at shutdown: %s", names) + await asyncio.gather( + *[script["instance"].async_stop() for script in running_scripts] + ) + + class Script: """Representation of a script.""" @@ -558,8 +600,20 @@ class Script: max_runs: int = DEFAULT_MAX, logger: Optional[logging.Logger] = None, log_exceptions: bool = True, + top_level: bool = True, ) -> None: """Initialize the script.""" + all_scripts = hass.data.get(DATA_SCRIPTS) + if not all_scripts: + all_scripts = hass.data[DATA_SCRIPTS] = [] + hass.bus.async_listen_once( + EVENT_HOMEASSISTANT_STOP, partial(_async_stop_scripts_at_shutdown, hass) + ) + if top_level: + all_scripts.append( + {"instance": self, "started_before_shutdown": not hass.is_stopping} + ) + self._hass = hass self.sequence = sequence template.attach(hass, self.sequence) @@ -732,6 +786,7 @@ class Script: f"{self.name}: {step_name}", script_mode=SCRIPT_MODE_PARALLEL, logger=self._logger, + top_level=False, ) sub_script.change_listener = partial(self._chain_change_listener, sub_script) return sub_script @@ -758,6 +813,7 @@ class Script: f"{self.name}: {step_name}: choice {idx}", script_mode=SCRIPT_MODE_PARALLEL, logger=self._logger, + top_level=False, ) sub_script.change_listener = partial( self._chain_change_listener, sub_script @@ -771,6 +827,7 @@ class Script: f"{self.name}: {step_name}: default", script_mode=SCRIPT_MODE_PARALLEL, logger=self._logger, + top_level=False, ) default_script.change_listener = partial( self._chain_change_listener, default_script