Remove legacy script mode and simplify remaining modes (#37729)

This commit is contained in:
Phil Bruckner 2020-07-10 19:00:57 -05:00 committed by GitHub
parent 8a8289b1a4
commit 63e55bff52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 407 additions and 959 deletions

View File

@ -15,7 +15,6 @@ from homeassistant.const import (
CONF_ID,
CONF_MODE,
CONF_PLATFORM,
CONF_QUEUE_SIZE,
CONF_ZONE,
EVENT_HOMEASSISTANT_STARTED,
SERVICE_RELOAD,
@ -31,7 +30,12 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import ToggleEntity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.script import SCRIPT_BASE_SCHEMA, Script, validate_queue_size
from homeassistant.helpers.script import (
CONF_MAX,
SCRIPT_MODE_PARALLEL,
Script,
make_script_schema,
)
from homeassistant.helpers.service import async_register_admin_service
from homeassistant.helpers.typing import TemplateVarsType
from homeassistant.loader import bind_hass
@ -99,7 +103,7 @@ _CONDITION_SCHEMA = vol.All(cv.ensure_list, [cv.CONDITION_SCHEMA])
PLATFORM_SCHEMA = vol.All(
cv.deprecated(CONF_HIDE_ENTITY, invalidation_version="0.110"),
SCRIPT_BASE_SCHEMA.extend(
make_script_schema(
{
# str on purpose
CONF_ID: str,
@ -110,9 +114,9 @@ PLATFORM_SCHEMA = vol.All(
vol.Required(CONF_TRIGGER): _TRIGGER_SCHEMA,
vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA,
vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA,
}
},
SCRIPT_MODE_PARALLEL,
),
validate_queue_size,
)
@ -507,7 +511,7 @@ async def _async_process_config(hass, config, component):
config_block[CONF_ACTION],
name,
script_mode=config_block[CONF_MODE],
queue_size=config_block.get(CONF_QUEUE_SIZE, 0),
max_runs=config_block[CONF_MAX],
logger=_LOGGER,
)

View File

@ -1,7 +1,6 @@
"""Config validation helper for the automation integration."""
import asyncio
import importlib
import logging
import voluptuous as vol
@ -9,21 +8,14 @@ from homeassistant.components.device_automation.exceptions import (
InvalidDeviceAutomationConfig,
)
from homeassistant.config import async_log_exception, config_without_domain
from homeassistant.const import CONF_ALIAS, CONF_ID, CONF_MODE, CONF_PLATFORM
from homeassistant.const import CONF_PLATFORM
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import condition, config_per_platform
from homeassistant.helpers.script import (
SCRIPT_MODE_LEGACY,
async_validate_action_config,
validate_legacy_mode_actions,
warn_deprecated_legacy,
)
from homeassistant.helpers.script import async_validate_action_config
from homeassistant.loader import IntegrationNotFound
from . import CONF_ACTION, CONF_CONDITION, CONF_TRIGGER, DOMAIN, PLATFORM_SCHEMA
_LOGGER = logging.getLogger(__name__)
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
@ -56,9 +48,6 @@ async def async_validate_config_item(hass, config, full_config=None):
*[async_validate_action_config(hass, action) for action in config[CONF_ACTION]]
)
if config.get(CONF_MODE, SCRIPT_MODE_LEGACY) == SCRIPT_MODE_LEGACY:
validate_legacy_mode_actions(config[CONF_ACTION])
return config
@ -78,35 +67,6 @@ async def _try_async_validate_config_item(hass, config, full_config=None):
return config
def _deprecated_legacy_mode(config):
legacy_names = []
legacy_unnamed_found = False
for cfg in config[DOMAIN]:
mode = cfg.get(CONF_MODE)
if mode is None:
cfg[CONF_MODE] = SCRIPT_MODE_LEGACY
name = cfg.get(CONF_ID) or cfg.get(CONF_ALIAS)
if name:
legacy_names.append(name)
else:
legacy_unnamed_found = True
if legacy_names or legacy_unnamed_found:
msgs = []
if legacy_unnamed_found:
msgs.append("unnamed automations")
if legacy_names:
if len(legacy_names) == 1:
base_msg = "this automation"
else:
base_msg = "these automations"
msgs.append(f"{base_msg}: {', '.join(legacy_names)}")
warn_deprecated_legacy(_LOGGER, " and ".join(msgs))
return config
async def async_validate_config(hass, config):
"""Validate config."""
automations = list(
@ -126,6 +86,4 @@ async def async_validate_config(hass, config):
config = config_without_domain(config, DOMAIN)
config[DOMAIN] = automations
_deprecated_legacy_mode(config)
return config

View File

@ -11,7 +11,6 @@ from homeassistant.const import (
CONF_ALIAS,
CONF_ICON,
CONF_MODE,
CONF_QUEUE_SIZE,
CONF_SEQUENCE,
SERVICE_RELOAD,
SERVICE_TOGGLE,
@ -25,12 +24,10 @@ from homeassistant.helpers.config_validation import make_entity_service_schema
from homeassistant.helpers.entity import ToggleEntity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.script import (
SCRIPT_BASE_SCHEMA,
SCRIPT_MODE_LEGACY,
CONF_MAX,
SCRIPT_MODE_SINGLE,
Script,
validate_legacy_mode_actions,
validate_queue_size,
warn_deprecated_legacy,
make_script_schema,
)
from homeassistant.helpers.service import async_set_service_schema
from homeassistant.loader import bind_hass
@ -52,51 +49,24 @@ ENTITY_ID_FORMAT = DOMAIN + ".{}"
EVENT_SCRIPT_STARTED = "script_started"
def _deprecated_legacy_mode(config):
legacy_scripts = []
for object_id, cfg in config.items():
mode = cfg.get(CONF_MODE)
if mode is None:
legacy_scripts.append(object_id)
cfg[CONF_MODE] = SCRIPT_MODE_LEGACY
if legacy_scripts:
warn_deprecated_legacy(_LOGGER, f"script(s): {', '.join(legacy_scripts)}")
return config
def _not_supported_in_legacy_mode(config):
if config.get(CONF_MODE, SCRIPT_MODE_LEGACY) == SCRIPT_MODE_LEGACY:
validate_legacy_mode_actions(config[CONF_SEQUENCE])
return config
SCRIPT_ENTRY_SCHEMA = vol.All(
SCRIPT_BASE_SCHEMA.extend(
{
vol.Optional(CONF_ALIAS): cv.string,
vol.Optional(CONF_ICON): cv.icon,
vol.Required(CONF_SEQUENCE): cv.SCRIPT_SCHEMA,
vol.Optional(CONF_DESCRIPTION, default=""): cv.string,
vol.Optional(CONF_FIELDS, default={}): {
cv.string: {
vol.Optional(CONF_DESCRIPTION): cv.string,
vol.Optional(CONF_EXAMPLE): cv.string,
}
},
}
),
validate_queue_size,
_not_supported_in_legacy_mode,
SCRIPT_ENTRY_SCHEMA = make_script_schema(
{
vol.Optional(CONF_ALIAS): cv.string,
vol.Optional(CONF_ICON): cv.icon,
vol.Required(CONF_SEQUENCE): cv.SCRIPT_SCHEMA,
vol.Optional(CONF_DESCRIPTION, default=""): cv.string,
vol.Optional(CONF_FIELDS, default={}): {
cv.string: {
vol.Optional(CONF_DESCRIPTION): cv.string,
vol.Optional(CONF_EXAMPLE): cv.string,
}
},
},
SCRIPT_MODE_SINGLE,
)
CONFIG_SCHEMA = vol.Schema(
{
DOMAIN: vol.All(
cv.schema_with_slug_keys(SCRIPT_ENTRY_SCHEMA), _deprecated_legacy_mode
)
},
extra=vol.ALLOW_EXTRA,
{DOMAIN: cv.schema_with_slug_keys(SCRIPT_ENTRY_SCHEMA)}, extra=vol.ALLOW_EXTRA
)
SCRIPT_SERVICE_SCHEMA = vol.Schema(dict)
@ -192,29 +162,26 @@ async def async_setup(hass, config):
"""Call a service to turn script on."""
variables = service.data.get(ATTR_VARIABLES)
for script_entity in await component.async_extract_from_service(service):
if script_entity.script.is_legacy:
await hass.services.async_call(
DOMAIN, script_entity.object_id, variables, context=service.context
)
else:
await script_entity.async_turn_on(
variables=variables, context=service.context, wait=False
)
await script_entity.async_turn_on(
variables=variables, context=service.context, wait=False
)
async def turn_off_service(service):
"""Cancel a script."""
# Stopping a script is ok to be done in parallel
scripts = await component.async_extract_from_service(service)
script_entities = await component.async_extract_from_service(service)
if not scripts:
if not script_entities:
return
await asyncio.wait([script.async_turn_off() for script in scripts])
await asyncio.wait(
[script_entity.async_turn_off() for script_entity in script_entities]
)
async def toggle_service(service):
"""Toggle a script."""
for script_entity in await component.async_extract_from_service(service):
await script_entity.async_toggle(context=service.context)
await script_entity.async_toggle(context=service.context, wait=False)
hass.services.async_register(
DOMAIN, SERVICE_RELOAD, reload_service, schema=RELOAD_SERVICE_SCHEMA
@ -239,27 +206,14 @@ async def _async_process_config(hass, config, component):
"""Execute a service call to script.<script name>."""
entity_id = ENTITY_ID_FORMAT.format(service.service)
script_entity = component.get_entity(entity_id)
if script_entity.script.is_legacy and script_entity.is_on:
_LOGGER.warning("Script %s already running", entity_id)
return
await script_entity.async_turn_on(
variables=service.data, context=service.context
)
script_entities = []
for object_id, cfg in config.get(DOMAIN, {}).items():
script_entities.append(
ScriptEntity(
hass,
object_id,
cfg.get(CONF_ALIAS, object_id),
cfg.get(CONF_ICON),
cfg[CONF_SEQUENCE],
cfg[CONF_MODE],
cfg.get(CONF_QUEUE_SIZE, 0),
)
)
script_entities = [
ScriptEntity(hass, object_id, cfg)
for object_id, cfg in config.get(DOMAIN, {}).items()
]
await component.async_add_entities(script_entities)
@ -289,18 +243,18 @@ class ScriptEntity(ToggleEntity):
icon = None
def __init__(self, hass, object_id, name, icon, sequence, mode, queue_size):
def __init__(self, hass, object_id, cfg):
"""Initialize the script."""
self.object_id = object_id
self.icon = icon
self.icon = cfg.get(CONF_ICON)
self.entity_id = ENTITY_ID_FORMAT.format(object_id)
self.script = Script(
hass,
sequence,
name,
cfg[CONF_SEQUENCE],
cfg.get(CONF_ALIAS, object_id),
self.async_change_listener,
mode,
queue_size,
cfg[CONF_MODE],
cfg[CONF_MAX],
logging.getLogger(f"{__name__}.{object_id}"),
)
self._changed = asyncio.Event()
@ -354,13 +308,10 @@ class ScriptEntity(ToggleEntity):
# Caller does not want to wait for called script to finish so let script run in
# separate Task. However, wait for first state change so we can guarantee that
# it is written to the State Machine before we return. Only do this for
# non-legacy scripts, since legacy scripts don't necessarily change state
# immediately.
# it is written to the State Machine before we return.
self._changed.clear()
self.hass.async_create_task(coro)
if not self.script.is_legacy:
await self._changed.wait()
await self._changed.wait()
async def async_turn_off(self, **kwargs):
"""Turn script off."""

View File

@ -135,7 +135,6 @@ CONF_PREFIX = "prefix"
CONF_PROFILE_NAME = "profile_name"
CONF_PROTOCOL = "protocol"
CONF_PROXY_SSL = "proxy_ssl"
CONF_QUEUE_SIZE = "queue_size"
CONF_QUOTE = "quote"
CONF_RADIUS = "radius"
CONF_RECIPIENT = "recipient"

View File

@ -1,12 +1,10 @@
"""Helpers to execute scripts."""
from abc import ABC, abstractmethod
import asyncio
from contextlib import suppress
from datetime import datetime
from functools import partial
import itertools
import logging
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from async_timeout import timeout
import voluptuous as vol
@ -27,7 +25,6 @@ from homeassistant.const import (
CONF_EVENT_DATA,
CONF_EVENT_DATA_TEMPLATE,
CONF_MODE,
CONF_QUEUE_SIZE,
CONF_REPEAT,
CONF_SCENE,
CONF_SEQUENCE,
@ -35,25 +32,15 @@ from homeassistant.const import (
CONF_UNTIL,
CONF_WAIT_TEMPLATE,
CONF_WHILE,
SERVICE_TURN_OFF,
SERVICE_TURN_ON,
)
from homeassistant.core import (
CALLBACK_TYPE,
SERVICE_CALL_LIMIT,
Context,
HomeAssistant,
callback,
)
from homeassistant.core import SERVICE_CALL_LIMIT, Context, HomeAssistant, callback
from homeassistant.helpers import (
condition,
config_validation as cv,
template as template,
)
from homeassistant.helpers.event import (
async_track_point_in_utc_time,
async_track_template,
)
from homeassistant.helpers.event import async_track_template
from homeassistant.helpers.service import (
CONF_SERVICE_DATA,
async_prepare_call_from_config,
@ -64,74 +51,41 @@ from homeassistant.util.dt import utcnow
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
SCRIPT_MODE_ERROR = "error"
SCRIPT_MODE_IGNORE = "ignore"
SCRIPT_MODE_LEGACY = "legacy"
SCRIPT_MODE_PARALLEL = "parallel"
SCRIPT_MODE_QUEUE = "queue"
SCRIPT_MODE_RESTART = "restart"
SCRIPT_MODE_SINGLE = "single"
SCRIPT_MODE_CHOICES = [
SCRIPT_MODE_ERROR,
SCRIPT_MODE_IGNORE,
SCRIPT_MODE_LEGACY,
SCRIPT_MODE_PARALLEL,
SCRIPT_MODE_QUEUE,
SCRIPT_MODE_RESTART,
SCRIPT_MODE_SINGLE,
]
DEFAULT_SCRIPT_MODE = SCRIPT_MODE_LEGACY
DEFAULT_SCRIPT_MODE = SCRIPT_MODE_SINGLE
DEFAULT_QUEUE_SIZE = 10
CONF_MAX = "max"
DEFAULT_MAX = 10
_LOG_EXCEPTION = logging.ERROR + 1
_TIMEOUT_MSG = "Timeout reached, abort script."
SCRIPT_BASE_SCHEMA = vol.Schema(
{
vol.Optional(CONF_MODE): vol.In(SCRIPT_MODE_CHOICES),
vol.Optional(CONF_QUEUE_SIZE): vol.All(vol.Coerce(int), vol.Range(min=2)),
}
)
_UNSUPPORTED_IN_LEGACY = {
cv.SCRIPT_ACTION_REPEAT: CONF_REPEAT,
}
def warn_deprecated_legacy(logger, msg):
"""Warn about deprecated legacy mode."""
logger.warning(
"Script behavior has changed. "
"To continue using previous behavior, which is now deprecated, "
"add '%s: %s' to %s.",
CONF_MODE,
SCRIPT_MODE_LEGACY,
msg,
def make_script_schema(schema, default_script_mode, extra=vol.PREVENT_EXTRA):
"""Make a schema for a component that uses the script helper."""
return vol.Schema(
{
**schema,
vol.Optional(CONF_MODE, default=default_script_mode): vol.In(
SCRIPT_MODE_CHOICES
),
vol.Optional(CONF_MAX, default=DEFAULT_MAX): vol.All(
vol.Coerce(int), vol.Range(min=2)
),
},
extra=extra,
)
def validate_legacy_mode_actions(sequence):
"""Check for actions not supported in legacy mode."""
for action in sequence:
script_action = cv.determine_script_action(action)
if script_action in _UNSUPPORTED_IN_LEGACY:
raise vol.Invalid(
f"{_UNSUPPORTED_IN_LEGACY[script_action]} action not supported in "
f"{SCRIPT_MODE_LEGACY} mode"
)
def validate_queue_size(config):
"""Validate queue_size option."""
mode = config.get(CONF_MODE, DEFAULT_SCRIPT_MODE)
queue_size = config.get(CONF_QUEUE_SIZE)
if mode == SCRIPT_MODE_QUEUE:
if queue_size is None:
config[CONF_QUEUE_SIZE] = DEFAULT_QUEUE_SIZE
elif queue_size is not None:
raise vol.Invalid(f"{CONF_QUEUE_SIZE} not valid with {mode} {CONF_MODE}")
return config
async def async_validate_action_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
@ -159,20 +113,8 @@ class _StopScript(Exception):
"""Throw if script needs to stop."""
class _SuspendScript(Exception):
"""Throw if script needs to suspend."""
class AlreadyRunning(exceptions.HomeAssistantError):
"""Throw if script already running and user wants error."""
class QueueFull(exceptions.HomeAssistantError):
"""Throw if script already running, user wants new run queued, but queue is full."""
class _ScriptRunBase(ABC):
"""Common data & methods for managing Script sequence run."""
class _ScriptRun:
"""Manage Script sequence run."""
def __init__(
self,
@ -189,17 +131,36 @@ class _ScriptRunBase(ABC):
self._log_exceptions = log_exceptions
self._step = -1
self._action: Optional[Dict[str, Any]] = None
self._stop = asyncio.Event()
self._stopped = asyncio.Event()
def _changed(self):
self._script._changed() # pylint: disable=protected-access
if not self._stop.is_set():
self._script._changed() # pylint: disable=protected-access
@property
def _config_cache(self):
return self._script._config_cache # pylint: disable=protected-access
@abstractmethod
def _log(self, msg, *args, level=logging.INFO):
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
async def async_run(self) -> None:
"""Run script."""
try:
if self._stop.is_set():
return
self._script.last_triggered = utcnow()
self._changed()
self._log("Running script")
for self._step, self._action in enumerate(self._script.sequence):
if self._stop.is_set():
break
await self._async_step(log_exceptions=False)
except _StopScript:
pass
finally:
self._finish()
async def _async_step(self, log_exceptions):
try:
@ -207,15 +168,23 @@ class _ScriptRunBase(ABC):
self, f"_async_{cv.determine_script_action(self._action)}_step"
)()
except Exception as ex:
if not isinstance(
ex, (_SuspendScript, _StopScript, asyncio.CancelledError)
) and (self._log_exceptions or log_exceptions):
if not isinstance(ex, (_StopScript, asyncio.CancelledError)) and (
self._log_exceptions or log_exceptions
):
self._log_exception(ex)
raise
@abstractmethod
def _finish(self):
self._script._runs.remove(self) # pylint: disable=protected-access
if not self._script.is_running:
self._script.last_action = None
self._changed()
self._stopped.set()
async def async_stop(self) -> None:
"""Stop script run."""
self._stop.set()
await self._stopped.wait()
def _log_exception(self, exception):
action_type = cv.determine_script_action(self._action)
@ -235,12 +204,6 @@ class _ScriptRunBase(ABC):
elif isinstance(exception, exceptions.ServiceNotFound):
error_desc = "Service not found"
elif isinstance(exception, AlreadyRunning):
error_desc = "Already running"
elif isinstance(exception, QueueFull):
error_desc = "Run queue is full"
else:
error_desc = "Unexpected error"
level = _LOG_EXCEPTION
@ -254,11 +217,8 @@ class _ScriptRunBase(ABC):
level=level,
)
@abstractmethod
async def _async_delay_step(self):
"""Handle delay."""
def _prep_delay_step(self):
try:
delay = vol.All(cv.time_period, cv.positive_timedelta)(
template.render_complex(self._action[CONF_DELAY], self._variables)
@ -275,35 +235,128 @@ class _ScriptRunBase(ABC):
self._script.last_action = self._action.get(CONF_ALIAS, f"delay {delay}")
self._log("Executing step %s", self._script.last_action)
return delay
delay = delay.total_seconds()
self._changed()
try:
async with timeout(delay):
await self._stop.wait()
except asyncio.TimeoutError:
pass
@abstractmethod
async def _async_wait_template_step(self):
"""Handle a wait template."""
def _prep_wait_template_step(self, async_script_wait):
wait_template = self._action[CONF_WAIT_TEMPLATE]
wait_template.hass = self._hass
self._script.last_action = self._action.get(CONF_ALIAS, "wait template")
self._log("Executing step %s", self._script.last_action)
wait_template = self._action[CONF_WAIT_TEMPLATE]
wait_template.hass = self._hass
# check if condition already okay
if condition.async_template(self._hass, wait_template, self._variables):
return None
return
return async_track_template(
@callback
def async_script_wait(entity_id, from_s, to_s):
"""Handle script after template condition is true."""
done.set()
unsub = async_track_template(
self._hass, wait_template, async_script_wait, self._variables
)
@abstractmethod
self._changed()
try:
delay = self._action[CONF_TIMEOUT].total_seconds()
except KeyError:
delay = None
done = asyncio.Event()
tasks = [
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
]
try:
async with timeout(delay):
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
except asyncio.TimeoutError:
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
self._log(_TIMEOUT_MSG)
raise _StopScript
finally:
for task in tasks:
task.cancel()
unsub()
async def _async_run_long_action(self, long_task):
"""Run a long task while monitoring for stop request."""
async def async_cancel_long_task():
# Stop long task and wait for it to finish.
long_task.cancel()
try:
await long_task
except Exception: # pylint: disable=broad-except
pass
# Wait for long task while monitoring for a stop request.
stop_task = self._hass.async_create_task(self._stop.wait())
try:
await asyncio.wait(
{long_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
)
# If our task is cancelled, then cancel long task, too. Note that if long task
# is cancelled otherwise the CancelledError exception will not be raised to
# here due to the call to asyncio.wait(). Rather we'll check for that below.
except asyncio.CancelledError:
await async_cancel_long_task()
raise
finally:
stop_task.cancel()
if long_task.cancelled():
raise asyncio.CancelledError
if long_task.done():
# Propagate any exceptions that occurred.
long_task.result()
else:
# Stopped before long task completed, so cancel it.
await async_cancel_long_task()
async def _async_call_service_step(self):
"""Call the service specified in the action."""
def _prep_call_service_step(self):
self._script.last_action = self._action.get(CONF_ALIAS, "call service")
self._log("Executing step %s", self._script.last_action)
return async_prepare_call_from_config(self._hass, self._action, self._variables)
domain, service, service_data = async_prepare_call_from_config(
self._hass, self._action, self._variables
)
running_script = (
domain == "automation"
and service == "trigger"
or domain in ("python_script", "script")
)
# If this might start a script then disable the call timeout.
# Otherwise use the normal service call limit.
if running_script:
limit = None
else:
limit = SERVICE_CALL_LIMIT
service_task = self._hass.async_create_task(
self._hass.services.async_call(
domain,
service,
service_data,
blocking=True,
context=self._context,
limit=limit,
)
)
if limit is not None:
# There is a call limit, so just wait for it to finish.
await service_task
return
await self._async_run_long_action(service_task)
async def _async_device_step(self):
"""Perform the device automation specified in the action."""
@ -370,175 +423,6 @@ class _ScriptRunBase(ABC):
if not check:
raise _StopScript
@abstractmethod
async def _async_repeat_step(self):
"""Repeat a sequence."""
def _log(self, msg, *args, level=logging.INFO):
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
class _ScriptRun(_ScriptRunBase):
"""Manage Script sequence run."""
def __init__(
self,
hass: HomeAssistant,
script: "Script",
variables: Optional[Sequence],
context: Optional[Context],
log_exceptions: bool,
) -> None:
super().__init__(hass, script, variables, context, log_exceptions)
self._stop = asyncio.Event()
self._stopped = asyncio.Event()
def _changed(self):
if not self._stop.is_set():
super()._changed()
async def async_run(self) -> None:
"""Run script."""
try:
if self._stop.is_set():
return
self._script.last_triggered = utcnow()
self._changed()
self._log("Running script")
for self._step, self._action in enumerate(self._script.sequence):
if self._stop.is_set():
break
await self._async_step(log_exceptions=False)
except _StopScript:
pass
finally:
self._finish()
def _finish(self):
self._script._runs.remove(self) # pylint: disable=protected-access
if not self._script.is_running:
self._script.last_action = None
self._changed()
self._stopped.set()
async def async_stop(self) -> None:
"""Stop script run."""
self._stop.set()
await self._stopped.wait()
async def _async_delay_step(self):
"""Handle delay."""
delay = self._prep_delay_step().total_seconds()
self._changed()
try:
async with timeout(delay):
await self._stop.wait()
except asyncio.TimeoutError:
pass
async def _async_wait_template_step(self):
"""Handle a wait template."""
@callback
def async_script_wait(entity_id, from_s, to_s):
"""Handle script after template condition is true."""
done.set()
unsub = self._prep_wait_template_step(async_script_wait)
if not unsub:
return
self._changed()
try:
delay = self._action[CONF_TIMEOUT].total_seconds()
except KeyError:
delay = None
done = asyncio.Event()
tasks = [
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
]
try:
async with timeout(delay):
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
except asyncio.TimeoutError:
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
self._log(_TIMEOUT_MSG)
raise _StopScript
finally:
for task in tasks:
task.cancel()
unsub()
async def _async_run_long_action(self, long_task):
"""Run a long task while monitoring for stop request."""
async def async_cancel_long_task():
# Stop long task and wait for it to finish.
long_task.cancel()
try:
await long_task
except Exception: # pylint: disable=broad-except
pass
# Wait for long task while monitoring for a stop request.
stop_task = self._hass.async_create_task(self._stop.wait())
try:
await asyncio.wait(
{long_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
)
# If our task is cancelled, then cancel long task, too. Note that if long task
# is cancelled otherwise the CancelledError exception will not be raised to
# here due to the call to asyncio.wait(). Rather we'll check for that below.
except asyncio.CancelledError:
await async_cancel_long_task()
raise
finally:
stop_task.cancel()
if long_task.cancelled():
raise asyncio.CancelledError
if long_task.done():
# Propagate any exceptions that occurred.
long_task.result()
else:
# Stopped before long task completed, so cancel it.
await async_cancel_long_task()
async def _async_call_service_step(self):
"""Call the service specified in the action."""
domain, service, service_data = self._prep_call_service_step()
running_script = (
domain == "automation"
and service == "trigger"
or domain == "python_script"
or domain == "script"
and service != SERVICE_TURN_OFF
)
# If this might start a script then disable the call timeout.
# Otherwise use the normal service call limit.
if running_script:
limit = None
else:
limit = SERVICE_CALL_LIMIT
service_task = self._hass.async_create_task(
self._hass.services.async_call(
domain,
service,
service_data,
blocking=True,
context=self._context,
limit=limit,
)
)
if limit is not None:
# There is a call limit, so just wait for it to finish.
await service_task
return
await self._async_run_long_action(service_task)
async def _async_repeat_step(self):
"""Repeat a sequence."""
@ -638,165 +522,12 @@ class _QueuedScriptRun(_ScriptRun):
def _finish(self):
# pylint: disable=protected-access
self._script._queue_len -= 1
if self.lock_acquired:
self._script._queue_lck.release()
self.lock_acquired = False
super()._finish()
class _LegacyScriptRun(_ScriptRunBase):
"""Manage legacy Script sequence run."""
def __init__(
self,
hass: HomeAssistant,
script: "Script",
variables: Optional[Sequence],
context: Optional[Context],
log_exceptions: bool,
shared: Optional["_LegacyScriptRun"],
) -> None:
super().__init__(hass, script, variables, context, log_exceptions)
if shared:
self._shared = shared
else:
# To implement legacy behavior we need to share the following "run state"
# amongst all runs, so it will only exist in the first instantiation of
# concurrent runs, and the rest will use it, too.
self._current = -1
self._async_listeners: List[CALLBACK_TYPE] = []
self._shared = self
@property
def _cur(self):
return self._shared._current # pylint: disable=protected-access
@_cur.setter
def _cur(self, value):
self._shared._current = value # pylint: disable=protected-access
@property
def _async_listener(self):
return self._shared._async_listeners # pylint: disable=protected-access
async def async_run(self) -> None:
"""Run script."""
await self._async_run()
async def _async_run(self, propagate_exceptions=True):
if self._cur == -1:
self._script.last_triggered = utcnow()
self._log("Running script")
self._cur = 0
# Unregister callback if we were in a delay or wait but turn on is
# called again. In that case we just continue execution.
self._async_remove_listener()
suspended = False
try:
for self._step, self._action in itertools.islice(
enumerate(self._script.sequence), self._cur, None
):
await self._async_step(log_exceptions=not propagate_exceptions)
except _StopScript:
pass
except _SuspendScript:
# Store next step to take and notify change listeners
self._cur = self._step + 1
suspended = True
return
except Exception: # pylint: disable=broad-except
if propagate_exceptions:
raise
finally:
_cur_was = self._cur
if not suspended:
self._script.last_action = None
await self.async_stop()
if _cur_was != -1:
self._changed()
async def async_stop(self) -> None:
"""Stop script run."""
if self._cur == -1:
return
self._cur = -1
self._async_remove_listener()
self._script._runs.clear() # pylint: disable=protected-access
async def _async_delay_step(self):
"""Handle delay."""
delay = self._prep_delay_step()
@callback
def async_script_delay(now):
"""Handle delay."""
with suppress(ValueError):
self._async_listener.remove(unsub)
self._hass.async_create_task(self._async_run(False))
unsub = async_track_point_in_utc_time(
self._hass, async_script_delay, utcnow() + delay
)
self._async_listener.append(unsub)
raise _SuspendScript
async def _async_wait_template_step(self):
"""Handle a wait template."""
@callback
def async_script_wait(entity_id, from_s, to_s):
"""Handle script after template condition is true."""
self._async_remove_listener()
self._hass.async_create_task(self._async_run(False))
@callback
def async_script_timeout(now):
"""Call after timeout has expired."""
with suppress(ValueError):
self._async_listener.remove(unsub_timeout)
# Check if we want to continue to execute
# the script after the timeout
if self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
self._hass.async_create_task(self._async_run(False))
else:
self._log(_TIMEOUT_MSG)
self._hass.async_create_task(self.async_stop())
unsub_wait = self._prep_wait_template_step(async_script_wait)
if not unsub_wait:
return
self._async_listener.append(unsub_wait)
if CONF_TIMEOUT in self._action:
unsub_timeout = async_track_point_in_utc_time(
self._hass, async_script_timeout, utcnow() + self._action[CONF_TIMEOUT]
)
self._async_listener.append(unsub_timeout)
raise _SuspendScript
async def _async_call_service_step(self):
"""Call the service specified in the action."""
await self._hass.services.async_call(
*self._prep_call_service_step(), blocking=True, context=self._context
)
async def _async_repeat_step(self):
"""Repeat a sequence."""
# Not supported in legacy mode.
def _async_remove_listener(self):
"""Remove listeners, if any."""
for unsub in self._async_listener:
unsub()
self._async_listener.clear()
class Script:
"""Representation of a script."""
@ -807,7 +538,7 @@ class Script:
name: Optional[str] = None,
change_listener: Optional[Callable[..., Any]] = None,
script_mode: str = DEFAULT_SCRIPT_MODE,
queue_size: int = DEFAULT_QUEUE_SIZE,
max_runs: int = DEFAULT_MAX,
logger: Optional[logging.Logger] = None,
log_exceptions: bool = True,
) -> None:
@ -829,10 +560,7 @@ class Script:
self.last_action = None
self.last_triggered: Optional[datetime] = None
self.can_cancel = not self.is_legacy or any(
CONF_DELAY in action or CONF_WAIT_TEMPLATE in action
for action in self.sequence
)
self.can_cancel = True
self._repeat_script = {}
for step, action in enumerate(sequence):
@ -850,10 +578,9 @@ class Script:
)
self._repeat_script[step] = sub_script
self._runs: List[_ScriptRunBase] = []
self._runs: List[_ScriptRun] = []
self._max_runs = max_runs
if script_mode == SCRIPT_MODE_QUEUE:
self._queue_size = queue_size
self._queue_len = 0
self._queue_lck = asyncio.Lock()
self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {}
self._referenced_entities: Optional[Set[str]] = None
@ -873,11 +600,6 @@ class Script:
"""Return true if script is on."""
return len(self._runs) > 0
@property
def is_legacy(self) -> bool:
"""Return if using legacy mode."""
return self._script_mode == SCRIPT_MODE_LEGACY
@property
def referenced_devices(self):
"""Return a set of referenced devices."""
@ -945,60 +667,40 @@ class Script:
) -> None:
"""Run script."""
if self.is_running:
if self._script_mode == SCRIPT_MODE_IGNORE:
self._log("Skipping script")
if self._script_mode == SCRIPT_MODE_SINGLE:
self._log("Already running", level=logging.WARNING)
return
if self._script_mode == SCRIPT_MODE_RESTART:
self._log("Restarting")
await self.async_stop(update_state=False)
elif len(self._runs) == self._max_runs:
self._log("Maximum number of runs exceeded", level=logging.WARNING)
return
if self._script_mode == SCRIPT_MODE_ERROR:
raise AlreadyRunning
if self._script_mode == SCRIPT_MODE_RESTART:
self._log("Restarting script")
await self.async_stop(update_state=False)
elif self._script_mode == SCRIPT_MODE_QUEUE:
self._log(
"Queueing script behind %i run%s",
self._queue_len,
"s" if self._queue_len > 1 else "",
)
if self._queue_len >= self._queue_size:
raise QueueFull
if self.is_legacy:
if self._runs:
shared = cast(Optional[_LegacyScriptRun], self._runs[0])
else:
shared = None
run: _ScriptRunBase = _LegacyScriptRun(
self._hass, self, variables, context, self._log_exceptions, shared
)
if self._script_mode != SCRIPT_MODE_QUEUE:
cls = _ScriptRun
else:
if self._script_mode != SCRIPT_MODE_QUEUE:
cls = _ScriptRun
else:
cls = _QueuedScriptRun
self._queue_len += 1
run = cls(self._hass, self, variables, context, self._log_exceptions)
cls = _QueuedScriptRun
run = cls(self._hass, self, variables, context, self._log_exceptions)
self._runs.append(run)
try:
if self.is_legacy:
await run.async_run()
else:
await asyncio.shield(run.async_run())
await asyncio.shield(run.async_run())
except asyncio.CancelledError:
await run.async_stop()
self._changed()
raise
async def async_stop(self, update_state: bool = True) -> None:
"""Stop running script."""
if not self.is_running:
return
await asyncio.shield(asyncio.gather(*(run.async_stop() for run in self._runs)))
async def _async_stop(self, update_state):
await asyncio.wait([run.async_stop() for run in self._runs])
if update_state:
self._changed()
async def async_stop(self, update_state: bool = True) -> None:
"""Stop running script."""
if self.is_running:
await asyncio.shield(self._async_stop(update_state))
def _log(self, msg, *args, level=logging.INFO):
if self.name:
msg = f"%s: {msg}"

View File

@ -1,6 +1,4 @@
"""The tests for the automation component."""
from datetime import timedelta
import pytest
from homeassistant.components import logbook
@ -23,12 +21,7 @@ from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
from tests.async_mock import Mock, patch
from tests.common import (
assert_setup_component,
async_fire_time_changed,
async_mock_service,
mock_restore_cache,
)
from tests.common import assert_setup_component, async_mock_service, mock_restore_cache
from tests.components.automation import common
from tests.components.logbook.test_init import MockLazyEventPartialState
@ -87,57 +80,6 @@ async def test_service_specify_data(hass, calls):
assert state.attributes.get("last_triggered") == time
async def test_action_delay(hass, calls):
"""Test action delay."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"alias": "hello",
"trigger": {"platform": "event", "event_type": "test_event"},
"action": [
{
"service": "test.automation",
"data_template": {
"some": "{{ trigger.platform }} - "
"{{ trigger.event.event_type }}"
},
},
{"delay": {"minutes": "10"}},
{
"service": "test.automation",
"data_template": {
"some": "{{ trigger.platform }} - "
"{{ trigger.event.event_type }}"
},
},
],
}
},
)
time = dt_util.utcnow()
with patch("homeassistant.components.automation.utcnow", return_value=time):
hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].data["some"] == "event - test_event"
future = dt_util.utcnow() + timedelta(minutes=10)
async_fire_time_changed(hass, future)
await hass.async_block_till_done()
assert len(calls) == 2
assert calls[1].data["some"] == "event - test_event"
state = hass.states.get("automation.hello")
assert state is not None
assert state.attributes.get("last_triggered") == time
async def test_service_specify_entity_id(hass, calls):
"""Test service data."""
assert await async_setup_component(
@ -1070,42 +1012,3 @@ async def test_logbook_humanify_automation_triggered_event(hass):
assert event2["domain"] == "automation"
assert event2["message"] == "has been triggered"
assert event2["entity_id"] == "automation.bye"
invalid_configs = [
{
"mode": "parallel",
"queue_size": 5,
"trigger": {"platform": "event", "event_type": "test_event"},
"action": [],
},
{
"mode": "legacy",
"trigger": {"platform": "event", "event_type": "test_event"},
"action": [{"repeat": {"count": 5, "sequence": []}}],
},
]
@pytest.mark.parametrize("value", invalid_configs)
async def test_invalid_configs(hass, value):
"""Test invalid configurations."""
with assert_setup_component(0, automation.DOMAIN):
assert await async_setup_component(
hass, automation.DOMAIN, {automation.DOMAIN: value}
)
async def test_config_legacy(hass, caplog):
"""Test config defaulting to legacy mode."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"trigger": {"platform": "event", "event_type": "test_event"},
"action": [],
}
},
)
assert "To continue using previous behavior, which is now deprecated" in caplog.text

View File

@ -961,10 +961,8 @@ async def test_wait_template_with_trigger(hass, calls):
await hass.async_block_till_done()
hass.states.async_set("test.entity", "12")
await hass.async_block_till_done()
hass.states.async_set("test.entity", "8")
await hass.async_block_till_done()
await hass.async_block_till_done()
assert len(calls) == 1
assert "numeric_state - test.entity - 12" == calls[0].data["some"]

View File

@ -704,7 +704,6 @@ async def test_wait_template_with_trigger(hass, calls):
await hass.async_block_till_done()
hass.states.async_set("test.entity", "world")
await hass.async_block_till_done()
hass.states.async_set("test.entity", "hello")
await hass.async_block_till_done()
assert len(calls) == 1

View File

@ -5,7 +5,7 @@ from unittest import mock
import pytest
import homeassistant.components.automation as automation
from homeassistant.core import Context
from homeassistant.core import Context, callback
from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
@ -435,6 +435,7 @@ async def test_wait_template_with_trigger(hass, calls):
"value_template": "{{ states.test.entity.state == 'world' }}",
},
"action": [
{"event": "test_event"},
{"wait_template": "{{ is_state(trigger.entity_id, 'hello') }}"},
{
"service": "test.automation",
@ -458,10 +459,14 @@ async def test_wait_template_with_trigger(hass, calls):
await hass.async_block_till_done()
@callback
def event_handler(event):
hass.states.async_set("test.entity", "hello")
hass.bus.async_listen_once("test_event", event_handler)
hass.states.async_set("test.entity", "world")
await hass.async_block_till_done()
hass.states.async_set("test.entity", "hello")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].data["some"] == "template - test.entity - hello - world - None"

View File

@ -17,6 +17,7 @@ from homeassistant.const import (
)
from homeassistant.core import Context, callback, split_entity_id
from homeassistant.exceptions import ServiceNotFound
from homeassistant.helpers.event import async_track_state_change
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.loader import bind_hass
from homeassistant.setup import async_setup_component, setup_component
@ -80,75 +81,6 @@ class TestScriptComponent(unittest.TestCase):
"""Stop down everything that was started."""
self.hass.stop()
def test_turn_on_service(self):
"""Verify that the turn_on service."""
event = "test_event"
events = []
@callback
def record_event(event):
"""Add recorded event to set."""
events.append(event)
self.hass.bus.listen(event, record_event)
assert setup_component(
self.hass,
"script",
{
"script": {
"test": {"sequence": [{"delay": {"seconds": 5}}, {"event": event}]}
}
},
)
turn_on(self.hass, ENTITY_ID)
self.hass.block_till_done()
assert script.is_on(self.hass, ENTITY_ID)
assert 0 == len(events)
# Calling turn_on a second time should not advance the script
turn_on(self.hass, ENTITY_ID)
self.hass.block_till_done()
assert 0 == len(events)
turn_off(self.hass, ENTITY_ID)
self.hass.block_till_done()
assert not script.is_on(self.hass, ENTITY_ID)
assert 0 == len(events)
def test_toggle_service(self):
"""Test the toggling of a service."""
event = "test_event"
events = []
@callback
def record_event(event):
"""Add recorded event to set."""
events.append(event)
self.hass.bus.listen(event, record_event)
assert setup_component(
self.hass,
"script",
{
"script": {
"test": {"sequence": [{"delay": {"seconds": 5}}, {"event": event}]}
}
},
)
toggle(self.hass, ENTITY_ID)
self.hass.block_till_done()
assert script.is_on(self.hass, ENTITY_ID)
assert 0 == len(events)
toggle(self.hass, ENTITY_ID)
self.hass.block_till_done()
assert not script.is_on(self.hass, ENTITY_ID)
assert 0 == len(events)
def test_passing_variables(self):
"""Test different ways of passing in variables."""
calls = []
@ -195,17 +127,58 @@ class TestScriptComponent(unittest.TestCase):
assert calls[1].data["hello"] == "universe"
@pytest.mark.parametrize("toggle", [False, True])
async def test_turn_on_off_toggle(hass, toggle):
"""Verify turn_on, turn_off & toggle services."""
event = "test_event"
event_mock = Mock()
hass.bus.async_listen(event, event_mock)
was_on = False
@callback
def state_listener(entity_id, old_state, new_state):
nonlocal was_on
was_on = True
async_track_state_change(hass, ENTITY_ID, state_listener, to_state="on")
if toggle:
turn_off_step = {"service": "script.toggle", "entity_id": ENTITY_ID}
else:
turn_off_step = {"service": "script.turn_off", "entity_id": ENTITY_ID}
assert await async_setup_component(
hass,
"script",
{
"script": {
"test": {
"sequence": [{"event": event}, turn_off_step, {"event": event}]
}
}
},
)
assert not script.is_on(hass, ENTITY_ID)
if toggle:
await hass.services.async_call(
DOMAIN, SERVICE_TOGGLE, {ATTR_ENTITY_ID: ENTITY_ID}
)
else:
await hass.services.async_call(DOMAIN, split_entity_id(ENTITY_ID)[1])
await hass.async_block_till_done()
assert not script.is_on(hass, ENTITY_ID)
assert was_on
assert 1 == event_mock.call_count
invalid_configs = [
{"test": {}},
{"test hello world": {"sequence": [{"event": "bla"}]}},
{"test": {"sequence": {"event": "test_event", "service": "homeassistant.turn_on"}}},
{"test": {"sequence": [], "mode": "parallel", "queue_size": 5}},
{
"test": {
"mode": "legacy",
"sequence": [{"repeat": {"count": 5, "sequence": []}}],
}
},
]
@ -222,8 +195,29 @@ async def test_setup_with_invalid_configs(hass, value):
@pytest.mark.parametrize("running", ["no", "same", "different"])
async def test_reload_service(hass, running):
"""Verify the reload service."""
event = "test_event"
event_flag = asyncio.Event()
@callback
def event_handler(event):
event_flag.set()
hass.bus.async_listen_once(event, event_handler)
hass.states.async_set("test.script", "off")
assert await async_setup_component(
hass, "script", {"script": {"test": {"sequence": [{"delay": {"seconds": 5}}]}}}
hass,
"script",
{
"script": {
"test": {
"sequence": [
{"event": event},
{"wait_template": "{{ is_state('test.script', 'on') }}"},
]
}
}
},
)
assert hass.states.get(ENTITY_ID) is not None
@ -232,7 +226,7 @@ async def test_reload_service(hass, running):
if running != "no":
_, object_id = split_entity_id(ENTITY_ID)
await hass.services.async_call(DOMAIN, object_id)
await hass.async_block_till_done()
await asyncio.wait_for(event_flag.wait(), 1)
assert script.is_on(hass, ENTITY_ID)
@ -486,14 +480,6 @@ async def test_config_basic(hass):
assert test_script.attributes["icon"] == "mdi:party"
async def test_config_legacy(hass, caplog):
"""Test config defaulting to legacy mode."""
assert await async_setup_component(
hass, "script", {"script": {"test_script": {"sequence": []}}}
)
assert "To continue using previous behavior, which is now deprecated" in caplog.text
async def test_logbook_humanify_script_started_event(hass):
"""Test humanifying script started event."""
hass.config.components.add("recorder")

View File

@ -27,8 +27,6 @@ from tests.common import (
ENTITY_ID = "script.test"
_BASIC_SCRIPT_MODES = ("legacy", "parallel")
@pytest.fixture
def mock_timeout(hass, monkeypatch):
@ -86,18 +84,16 @@ def async_watch_for_action(script_obj, message):
return flag
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_firing_event_basic(hass, script_mode):
async def test_firing_event_basic(hass):
"""Test the firing of events."""
event = "test_event"
context = Context()
events = async_capture_events(hass, event)
sequence = cv.SCRIPT_SCHEMA({"event": event, "event_data": {"hello": "world"}})
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
assert script_obj.is_legacy == (script_mode == "legacy")
assert script_obj.can_cancel == (script_mode != "legacy")
assert script_obj.can_cancel
await script_obj.async_run(context=context)
await hass.async_block_till_done()
@ -105,11 +101,10 @@ async def test_firing_event_basic(hass, script_mode):
assert len(events) == 1
assert events[0].context is context
assert events[0].data.get("hello") == "world"
assert script_obj.can_cancel == (script_mode != "legacy")
assert script_obj.can_cancel
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_firing_event_template(hass, script_mode):
async def test_firing_event_template(hass):
"""Test the firing of events."""
event = "test_event"
context = Context()
@ -128,9 +123,9 @@ async def test_firing_event_template(hass, script_mode):
},
}
)
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
assert script_obj.can_cancel == (script_mode != "legacy")
assert script_obj.can_cancel
await script_obj.async_run({"is_world": "yes"}, context=context)
await hass.async_block_till_done()
@ -143,16 +138,15 @@ async def test_firing_event_template(hass, script_mode):
}
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_calling_service_basic(hass, script_mode):
async def test_calling_service_basic(hass):
"""Test the calling of a service."""
context = Context()
calls = async_mock_service(hass, "test", "script")
sequence = cv.SCRIPT_SCHEMA({"service": "test.script", "data": {"hello": "world"}})
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
assert script_obj.can_cancel == (script_mode != "legacy")
assert script_obj.can_cancel
await script_obj.async_run(context=context)
await hass.async_block_till_done()
@ -162,8 +156,7 @@ async def test_calling_service_basic(hass, script_mode):
assert calls[0].data.get("hello") == "world"
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_calling_service_template(hass, script_mode):
async def test_calling_service_template(hass):
"""Test the calling of a service."""
context = Context()
calls = async_mock_service(hass, "test", "script")
@ -187,9 +180,9 @@ async def test_calling_service_template(hass, script_mode):
},
}
)
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
assert script_obj.can_cancel == (script_mode != "legacy")
assert script_obj.can_cancel
await script_obj.async_run({"is_world": "yes"}, context=context)
await hass.async_block_till_done()
@ -199,8 +192,7 @@ async def test_calling_service_template(hass, script_mode):
assert calls[0].data.get("hello") == "world"
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_multiple_runs_no_wait(hass, script_mode):
async def test_multiple_runs_no_wait(hass):
"""Test multiple runs with no wait in script."""
logger = logging.getLogger("TEST")
calls = []
@ -243,7 +235,7 @@ async def test_multiple_runs_no_wait(hass, script_mode):
},
]
)
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence, script_mode="parallel", max_runs=2)
# Start script twice in such a way that second run will be started while first run
# is in the middle of the first service call.
@ -267,16 +259,15 @@ async def test_multiple_runs_no_wait(hass, script_mode):
assert len(calls) == 4
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_activating_scene(hass, script_mode):
async def test_activating_scene(hass):
"""Test the activation of a scene."""
context = Context()
calls = async_mock_service(hass, scene.DOMAIN, SERVICE_TURN_ON)
sequence = cv.SCRIPT_SCHEMA({"scene": "scene.hello"})
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
assert script_obj.can_cancel == (script_mode != "legacy")
assert script_obj.can_cancel
await script_obj.async_run(context=context)
await hass.async_block_till_done()
@ -287,8 +278,7 @@ async def test_activating_scene(hass, script_mode):
@pytest.mark.parametrize("count", [1, 3])
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_stop_no_wait(hass, caplog, script_mode, count):
async def test_stop_no_wait(hass, caplog, count):
"""Test stopping script."""
service_started_sem = asyncio.Semaphore(0)
finish_service_event = asyncio.Event()
@ -303,7 +293,7 @@ async def test_stop_no_wait(hass, caplog, script_mode, count):
hass.services.async_register("test", "script", async_simulate_long_service)
sequence = cv.SCRIPT_SCHEMA([{"service": "test.script"}, {"event": event}])
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence, script_mode="parallel", max_runs=count)
# Get script started specified number of times and wait until the test.script
# service has started for each run.
@ -328,15 +318,14 @@ async def test_stop_no_wait(hass, caplog, script_mode, count):
assert script_was_runing
assert were_no_events
assert not script_obj.is_running
assert len(events) == (count if script_mode == "legacy" else 0)
assert len(events) == 0
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_delay_basic(hass, mock_timeout, script_mode):
async def test_delay_basic(hass, mock_timeout):
"""Test the delay."""
delay_alias = "delay step"
sequence = cv.SCRIPT_SCHEMA({"delay": {"seconds": 5}, "alias": delay_alias})
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
delay_started_flag = async_watch_for_action(script_obj, delay_alias)
assert script_obj.can_cancel
@ -358,8 +347,7 @@ async def test_delay_basic(hass, mock_timeout, script_mode):
assert script_obj.last_action is None
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_multiple_runs_delay(hass, mock_timeout, script_mode):
async def test_multiple_runs_delay(hass, mock_timeout):
"""Test multiple runs with delay in script."""
event = "test_event"
events = async_capture_events(hass, event)
@ -371,7 +359,7 @@ async def test_multiple_runs_delay(hass, mock_timeout, script_mode):
{"event": event, "event_data": {"value": 2}},
]
)
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence, script_mode="parallel", max_runs=2)
delay_started_flag = async_watch_for_action(script_obj, "delay")
try:
@ -386,31 +374,24 @@ async def test_multiple_runs_delay(hass, mock_timeout, script_mode):
raise
else:
# Start second run of script while first run is in a delay.
if script_mode == "legacy":
await script_obj.async_run()
else:
script_obj.sequence[1]["alias"] = "delay run 2"
delay_started_flag = async_watch_for_action(script_obj, "delay run 2")
hass.async_create_task(script_obj.async_run())
await asyncio.wait_for(delay_started_flag.wait(), 1)
script_obj.sequence[1]["alias"] = "delay run 2"
delay_started_flag = async_watch_for_action(script_obj, "delay run 2")
hass.async_create_task(script_obj.async_run())
await asyncio.wait_for(delay_started_flag.wait(), 1)
async_fire_time_changed(hass, dt_util.utcnow() + delay)
await hass.async_block_till_done()
assert not script_obj.is_running
if script_mode == "legacy":
assert len(events) == 2
else:
assert len(events) == 4
assert events[-3].data["value"] == 1
assert events[-2].data["value"] == 2
assert len(events) == 4
assert events[-3].data["value"] == 1
assert events[-2].data["value"] == 2
assert events[-1].data["value"] == 2
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_delay_template_ok(hass, mock_timeout, script_mode):
async def test_delay_template_ok(hass, mock_timeout):
"""Test the delay as a template."""
sequence = cv.SCRIPT_SCHEMA({"delay": "00:00:{{ 5 }}"})
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
delay_started_flag = async_watch_for_action(script_obj, "delay")
assert script_obj.can_cancel
@ -430,8 +411,7 @@ async def test_delay_template_ok(hass, mock_timeout, script_mode):
assert not script_obj.is_running
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_delay_template_invalid(hass, caplog, script_mode):
async def test_delay_template_invalid(hass, caplog):
"""Test the delay as a template that fails."""
event = "test_event"
events = async_capture_events(hass, event)
@ -443,7 +423,7 @@ async def test_delay_template_invalid(hass, caplog, script_mode):
{"event": event},
]
)
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
start_idx = len(caplog.records)
await script_obj.async_run()
@ -458,11 +438,10 @@ async def test_delay_template_invalid(hass, caplog, script_mode):
assert len(events) == 1
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_delay_template_complex_ok(hass, mock_timeout, script_mode):
async def test_delay_template_complex_ok(hass, mock_timeout):
"""Test the delay with a working complex template."""
sequence = cv.SCRIPT_SCHEMA({"delay": {"seconds": "{{ 5 }}"}})
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
delay_started_flag = async_watch_for_action(script_obj, "delay")
assert script_obj.can_cancel
@ -481,8 +460,7 @@ async def test_delay_template_complex_ok(hass, mock_timeout, script_mode):
assert not script_obj.is_running
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_delay_template_complex_invalid(hass, caplog, script_mode):
async def test_delay_template_complex_invalid(hass, caplog):
"""Test the delay with a complex template that fails."""
event = "test_event"
events = async_capture_events(hass, event)
@ -494,7 +472,7 @@ async def test_delay_template_complex_invalid(hass, caplog, script_mode):
{"event": event},
]
)
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
start_idx = len(caplog.records)
await script_obj.async_run()
@ -509,13 +487,12 @@ async def test_delay_template_complex_invalid(hass, caplog, script_mode):
assert len(events) == 1
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_cancel_delay(hass, script_mode):
async def test_cancel_delay(hass):
"""Test the cancelling while the delay is present."""
event = "test_event"
events = async_capture_events(hass, event)
sequence = cv.SCRIPT_SCHEMA([{"delay": {"seconds": 5}}, {"event": event}])
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
delay_started_flag = async_watch_for_action(script_obj, "delay")
try:
@ -541,8 +518,7 @@ async def test_cancel_delay(hass, script_mode):
assert len(events) == 0
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_wait_template_basic(hass, script_mode):
async def test_wait_template_basic(hass):
"""Test the wait template."""
wait_alias = "wait step"
sequence = cv.SCRIPT_SCHEMA(
@ -551,7 +527,7 @@ async def test_wait_template_basic(hass, script_mode):
"alias": wait_alias,
}
)
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
wait_started_flag = async_watch_for_action(script_obj, wait_alias)
assert script_obj.can_cancel
@ -574,8 +550,7 @@ async def test_wait_template_basic(hass, script_mode):
assert script_obj.last_action is None
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_multiple_runs_wait_template(hass, script_mode):
async def test_multiple_runs_wait_template(hass):
"""Test multiple runs with wait_template in script."""
event = "test_event"
events = async_capture_events(hass, event)
@ -586,7 +561,7 @@ async def test_multiple_runs_wait_template(hass, script_mode):
{"event": event, "event_data": {"value": 2}},
]
)
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence, script_mode="parallel", max_runs=2)
wait_started_flag = async_watch_for_action(script_obj, "wait")
try:
@ -602,25 +577,18 @@ async def test_multiple_runs_wait_template(hass, script_mode):
raise
else:
# Start second run of script while first run is in wait_template.
if script_mode == "legacy":
await script_obj.async_run()
else:
hass.async_create_task(script_obj.async_run())
hass.async_create_task(script_obj.async_run())
hass.states.async_set("switch.test", "off")
await hass.async_block_till_done()
assert not script_obj.is_running
if script_mode == "legacy":
assert len(events) == 2
else:
assert len(events) == 4
assert events[-3].data["value"] == 1
assert events[-2].data["value"] == 2
assert len(events) == 4
assert events[-3].data["value"] == 1
assert events[-2].data["value"] == 2
assert events[-1].data["value"] == 2
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_cancel_wait_template(hass, script_mode):
async def test_cancel_wait_template(hass):
"""Test the cancelling while wait_template is present."""
event = "test_event"
events = async_capture_events(hass, event)
@ -630,7 +598,7 @@ async def test_cancel_wait_template(hass, script_mode):
{"event": event},
]
)
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
wait_started_flag = async_watch_for_action(script_obj, "wait")
try:
@ -657,8 +625,7 @@ async def test_cancel_wait_template(hass, script_mode):
assert len(events) == 0
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_wait_template_not_schedule(hass, script_mode):
async def test_wait_template_not_schedule(hass):
"""Test the wait template with correct condition."""
event = "test_event"
events = async_capture_events(hass, event)
@ -669,7 +636,7 @@ async def test_wait_template_not_schedule(hass, script_mode):
{"event": event},
]
)
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
hass.states.async_set("switch.test", "on")
await script_obj.async_run()
@ -679,13 +646,10 @@ async def test_wait_template_not_schedule(hass, script_mode):
assert len(events) == 2
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
@pytest.mark.parametrize(
"continue_on_timeout,n_events", [(False, 0), (True, 1), (None, 1)]
)
async def test_wait_template_timeout(
hass, mock_timeout, continue_on_timeout, n_events, script_mode
):
async def test_wait_template_timeout(hass, mock_timeout, continue_on_timeout, n_events):
"""Test the wait template, halt on timeout."""
event = "test_event"
events = async_capture_events(hass, event)
@ -696,7 +660,7 @@ async def test_wait_template_timeout(
if continue_on_timeout is not None:
sequence[0]["continue_on_timeout"] = continue_on_timeout
sequence = cv.SCRIPT_SCHEMA(sequence)
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
wait_started_flag = async_watch_for_action(script_obj, "wait")
try:
@ -717,11 +681,10 @@ async def test_wait_template_timeout(
assert len(events) == n_events
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_wait_template_variables(hass, script_mode):
async def test_wait_template_variables(hass):
"""Test the wait template with variables."""
sequence = cv.SCRIPT_SCHEMA({"wait_template": "{{ is_state(data, 'off') }}"})
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
wait_started_flag = async_watch_for_action(script_obj, "wait")
assert script_obj.can_cancel
@ -742,8 +705,7 @@ async def test_wait_template_variables(hass, script_mode):
assert not script_obj.is_running
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_condition_basic(hass, script_mode):
async def test_condition_basic(hass):
"""Test if we can use conditions in a script."""
event = "test_event"
events = async_capture_events(hass, event)
@ -757,9 +719,9 @@ async def test_condition_basic(hass, script_mode):
{"event": event},
]
)
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
assert script_obj.can_cancel == (script_mode != "legacy")
assert script_obj.can_cancel
hass.states.async_set("test.entity", "hello")
await script_obj.async_run()
@ -775,9 +737,8 @@ async def test_condition_basic(hass, script_mode):
assert len(events) == 3
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
@patch("homeassistant.helpers.script.condition.async_from_config")
async def test_condition_created_once(async_from_config, hass, script_mode):
async def test_condition_created_once(async_from_config, hass):
"""Test that the conditions do not get created multiple times."""
sequence = cv.SCRIPT_SCHEMA(
{
@ -785,7 +746,7 @@ async def test_condition_created_once(async_from_config, hass, script_mode):
"value_template": '{{ states.test.entity.state == "hello" }}',
}
)
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence, script_mode="parallel", max_runs=2)
async_from_config.reset_mock()
@ -798,8 +759,7 @@ async def test_condition_created_once(async_from_config, hass, script_mode):
assert len(script_obj._config_cache) == 1
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_condition_all_cached(hass, script_mode):
async def test_condition_all_cached(hass):
"""Test that multiple conditions get cached."""
sequence = cv.SCRIPT_SCHEMA(
[
@ -813,7 +773,7 @@ async def test_condition_all_cached(hass, script_mode):
},
]
)
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
hass.states.async_set("test.entity", "hello")
await script_obj.async_run()
@ -843,7 +803,7 @@ async def test_repeat_count(hass):
}
}
)
script_obj = script.Script(hass, sequence, script_mode="ignore")
script_obj = script.Script(hass, sequence)
await script_obj.async_run()
await hass.async_block_till_done()
@ -887,7 +847,7 @@ async def test_repeat_conditional(hass, condition):
"condition": "template",
"value_template": "{{ is_state('sensor.test', 'done') }}",
}
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA(sequence), script_mode="ignore")
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA(sequence))
wait_started = async_watch_for_action(script_obj, "wait")
hass.states.async_set("sensor.test", "1")
@ -917,12 +877,11 @@ async def test_repeat_conditional(hass, condition):
assert event.data.get("index") == str(index + 1)
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_last_triggered(hass, script_mode):
async def test_last_triggered(hass):
"""Test the last_triggered."""
event = "test_event"
sequence = cv.SCRIPT_SCHEMA({"event": event})
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
assert script_obj.last_triggered is None
@ -934,13 +893,12 @@ async def test_last_triggered(hass, script_mode):
assert script_obj.last_triggered == time
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_propagate_error_service_not_found(hass, script_mode):
async def test_propagate_error_service_not_found(hass):
"""Test that a script aborts when a service is not found."""
event = "test_event"
events = async_capture_events(hass, event)
sequence = cv.SCRIPT_SCHEMA([{"service": "test.script"}, {"event": event}])
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
with pytest.raises(exceptions.ServiceNotFound):
await script_obj.async_run()
@ -949,8 +907,7 @@ async def test_propagate_error_service_not_found(hass, script_mode):
assert not script_obj.is_running
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_propagate_error_invalid_service_data(hass, script_mode):
async def test_propagate_error_invalid_service_data(hass):
"""Test that a script aborts when we send invalid service data."""
event = "test_event"
events = async_capture_events(hass, event)
@ -958,7 +915,7 @@ async def test_propagate_error_invalid_service_data(hass, script_mode):
sequence = cv.SCRIPT_SCHEMA(
[{"service": "test.script", "data": {"text": 1}}, {"event": event}]
)
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
with pytest.raises(vol.Invalid):
await script_obj.async_run()
@ -968,8 +925,7 @@ async def test_propagate_error_invalid_service_data(hass, script_mode):
assert not script_obj.is_running
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_propagate_error_service_exception(hass, script_mode):
async def test_propagate_error_service_exception(hass):
"""Test that a script aborts when a service throws an exception."""
event = "test_event"
events = async_capture_events(hass, event)
@ -982,7 +938,7 @@ async def test_propagate_error_service_exception(hass, script_mode):
hass.services.async_register("test", "script", record_call)
sequence = cv.SCRIPT_SCHEMA([{"service": "test.script"}, {"event": event}])
script_obj = script.Script(hass, sequence, script_mode=script_mode)
script_obj = script.Script(hass, sequence)
with pytest.raises(ValueError):
await script_obj.async_run()
@ -1053,15 +1009,8 @@ def does_not_raise():
yield
@pytest.mark.parametrize(
"script_mode,expectation,messages",
[
("ignore", does_not_raise(), ["Skipping"]),
("error", pytest.raises(exceptions.HomeAssistantError), []),
],
)
async def test_script_mode_1(hass, caplog, script_mode, expectation, messages):
"""Test overlapping runs with script_mode='ignore'."""
async def test_script_mode_single(hass, caplog):
"""Test overlapping runs with max_runs = 1."""
event = "test_event"
events = async_capture_events(hass, event)
sequence = cv.SCRIPT_SCHEMA(
@ -1071,8 +1020,7 @@ async def test_script_mode_1(hass, caplog, script_mode, expectation, messages):
{"event": event, "event_data": {"value": 2}},
]
)
logger = logging.getLogger("TEST")
script_obj = script.Script(hass, sequence, script_mode=script_mode, logger=logger)
script_obj = script.Script(hass, sequence)
wait_started_flag = async_watch_for_action(script_obj, "wait")
try:
@ -1086,19 +1034,10 @@ async def test_script_mode_1(hass, caplog, script_mode, expectation, messages):
# Start second run of script while first run is suspended in wait_template.
with expectation:
await script_obj.async_run()
await script_obj.async_run()
assert "Already running" in caplog.text
assert script_obj.is_running
assert all(
any(
rec.levelname == "INFO"
and rec.name == "TEST"
and message in rec.message
for rec in caplog.records
)
for message in messages
)
except (AssertionError, asyncio.TimeoutError):
await script_obj.async_stop()
raise
@ -1116,7 +1055,7 @@ async def test_script_mode_1(hass, caplog, script_mode, expectation, messages):
[("restart", ["Restarting"], [2]), ("parallel", [], [2, 2])],
)
async def test_script_mode_2(hass, caplog, script_mode, messages, last_events):
"""Test overlapping runs with script_mode='restart'."""
"""Test overlapping runs with max_runs > 1."""
event = "test_event"
events = async_capture_events(hass, event)
sequence = cv.SCRIPT_SCHEMA(
@ -1127,7 +1066,10 @@ async def test_script_mode_2(hass, caplog, script_mode, messages, last_events):
]
)
logger = logging.getLogger("TEST")
script_obj = script.Script(hass, sequence, script_mode=script_mode, logger=logger)
max_runs = 1 if script_mode == "restart" else 2
script_obj = script.Script(
hass, sequence, script_mode=script_mode, max_runs=max_runs, logger=logger
)
wait_started_flag = async_watch_for_action(script_obj, "wait")
try:
@ -1140,7 +1082,6 @@ async def test_script_mode_2(hass, caplog, script_mode, messages, last_events):
assert events[0].data["value"] == 1
# Start second run of script while first run is suspended in wait_template.
# This should stop first run then start a new run.
wait_started_flag.clear()
hass.async_create_task(script_obj.async_run())
@ -1172,7 +1113,7 @@ async def test_script_mode_2(hass, caplog, script_mode, messages, last_events):
async def test_script_mode_queue(hass):
"""Test overlapping runs with script_mode='queue'."""
"""Test overlapping runs with script_mode = 'queue' & max_runs > 1."""
event = "test_event"
events = async_capture_events(hass, event)
sequence = cv.SCRIPT_SCHEMA(
@ -1184,7 +1125,9 @@ async def test_script_mode_queue(hass):
]
)
logger = logging.getLogger("TEST")
script_obj = script.Script(hass, sequence, script_mode="queue", logger=logger)
script_obj = script.Script(
hass, sequence, script_mode="queue", max_runs=2, logger=logger
)
wait_started_flag = async_watch_for_action(script_obj, "wait")
try: