mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 20:27:08 +00:00
Remove legacy script mode and simplify remaining modes (#37729)
This commit is contained in:
parent
8a8289b1a4
commit
63e55bff52
@ -15,7 +15,6 @@ from homeassistant.const import (
|
||||
CONF_ID,
|
||||
CONF_MODE,
|
||||
CONF_PLATFORM,
|
||||
CONF_QUEUE_SIZE,
|
||||
CONF_ZONE,
|
||||
EVENT_HOMEASSISTANT_STARTED,
|
||||
SERVICE_RELOAD,
|
||||
@ -31,7 +30,12 @@ import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.entity import ToggleEntity
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.helpers.script import SCRIPT_BASE_SCHEMA, Script, validate_queue_size
|
||||
from homeassistant.helpers.script import (
|
||||
CONF_MAX,
|
||||
SCRIPT_MODE_PARALLEL,
|
||||
Script,
|
||||
make_script_schema,
|
||||
)
|
||||
from homeassistant.helpers.service import async_register_admin_service
|
||||
from homeassistant.helpers.typing import TemplateVarsType
|
||||
from homeassistant.loader import bind_hass
|
||||
@ -99,7 +103,7 @@ _CONDITION_SCHEMA = vol.All(cv.ensure_list, [cv.CONDITION_SCHEMA])
|
||||
|
||||
PLATFORM_SCHEMA = vol.All(
|
||||
cv.deprecated(CONF_HIDE_ENTITY, invalidation_version="0.110"),
|
||||
SCRIPT_BASE_SCHEMA.extend(
|
||||
make_script_schema(
|
||||
{
|
||||
# str on purpose
|
||||
CONF_ID: str,
|
||||
@ -110,9 +114,9 @@ PLATFORM_SCHEMA = vol.All(
|
||||
vol.Required(CONF_TRIGGER): _TRIGGER_SCHEMA,
|
||||
vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA,
|
||||
vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA,
|
||||
}
|
||||
},
|
||||
SCRIPT_MODE_PARALLEL,
|
||||
),
|
||||
validate_queue_size,
|
||||
)
|
||||
|
||||
|
||||
@ -507,7 +511,7 @@ async def _async_process_config(hass, config, component):
|
||||
config_block[CONF_ACTION],
|
||||
name,
|
||||
script_mode=config_block[CONF_MODE],
|
||||
queue_size=config_block.get(CONF_QUEUE_SIZE, 0),
|
||||
max_runs=config_block[CONF_MAX],
|
||||
logger=_LOGGER,
|
||||
)
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
"""Config validation helper for the automation integration."""
|
||||
import asyncio
|
||||
import importlib
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@ -9,21 +8,14 @@ from homeassistant.components.device_automation.exceptions import (
|
||||
InvalidDeviceAutomationConfig,
|
||||
)
|
||||
from homeassistant.config import async_log_exception, config_without_domain
|
||||
from homeassistant.const import CONF_ALIAS, CONF_ID, CONF_MODE, CONF_PLATFORM
|
||||
from homeassistant.const import CONF_PLATFORM
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import condition, config_per_platform
|
||||
from homeassistant.helpers.script import (
|
||||
SCRIPT_MODE_LEGACY,
|
||||
async_validate_action_config,
|
||||
validate_legacy_mode_actions,
|
||||
warn_deprecated_legacy,
|
||||
)
|
||||
from homeassistant.helpers.script import async_validate_action_config
|
||||
from homeassistant.loader import IntegrationNotFound
|
||||
|
||||
from . import CONF_ACTION, CONF_CONDITION, CONF_TRIGGER, DOMAIN, PLATFORM_SCHEMA
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
# mypy: no-check-untyped-defs, no-warn-return-any
|
||||
|
||||
@ -56,9 +48,6 @@ async def async_validate_config_item(hass, config, full_config=None):
|
||||
*[async_validate_action_config(hass, action) for action in config[CONF_ACTION]]
|
||||
)
|
||||
|
||||
if config.get(CONF_MODE, SCRIPT_MODE_LEGACY) == SCRIPT_MODE_LEGACY:
|
||||
validate_legacy_mode_actions(config[CONF_ACTION])
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@ -78,35 +67,6 @@ async def _try_async_validate_config_item(hass, config, full_config=None):
|
||||
return config
|
||||
|
||||
|
||||
def _deprecated_legacy_mode(config):
|
||||
legacy_names = []
|
||||
legacy_unnamed_found = False
|
||||
|
||||
for cfg in config[DOMAIN]:
|
||||
mode = cfg.get(CONF_MODE)
|
||||
if mode is None:
|
||||
cfg[CONF_MODE] = SCRIPT_MODE_LEGACY
|
||||
name = cfg.get(CONF_ID) or cfg.get(CONF_ALIAS)
|
||||
if name:
|
||||
legacy_names.append(name)
|
||||
else:
|
||||
legacy_unnamed_found = True
|
||||
|
||||
if legacy_names or legacy_unnamed_found:
|
||||
msgs = []
|
||||
if legacy_unnamed_found:
|
||||
msgs.append("unnamed automations")
|
||||
if legacy_names:
|
||||
if len(legacy_names) == 1:
|
||||
base_msg = "this automation"
|
||||
else:
|
||||
base_msg = "these automations"
|
||||
msgs.append(f"{base_msg}: {', '.join(legacy_names)}")
|
||||
warn_deprecated_legacy(_LOGGER, " and ".join(msgs))
|
||||
|
||||
return config
|
||||
|
||||
|
||||
async def async_validate_config(hass, config):
|
||||
"""Validate config."""
|
||||
automations = list(
|
||||
@ -126,6 +86,4 @@ async def async_validate_config(hass, config):
|
||||
config = config_without_domain(config, DOMAIN)
|
||||
config[DOMAIN] = automations
|
||||
|
||||
_deprecated_legacy_mode(config)
|
||||
|
||||
return config
|
||||
|
@ -11,7 +11,6 @@ from homeassistant.const import (
|
||||
CONF_ALIAS,
|
||||
CONF_ICON,
|
||||
CONF_MODE,
|
||||
CONF_QUEUE_SIZE,
|
||||
CONF_SEQUENCE,
|
||||
SERVICE_RELOAD,
|
||||
SERVICE_TOGGLE,
|
||||
@ -25,12 +24,10 @@ from homeassistant.helpers.config_validation import make_entity_service_schema
|
||||
from homeassistant.helpers.entity import ToggleEntity
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.script import (
|
||||
SCRIPT_BASE_SCHEMA,
|
||||
SCRIPT_MODE_LEGACY,
|
||||
CONF_MAX,
|
||||
SCRIPT_MODE_SINGLE,
|
||||
Script,
|
||||
validate_legacy_mode_actions,
|
||||
validate_queue_size,
|
||||
warn_deprecated_legacy,
|
||||
make_script_schema,
|
||||
)
|
||||
from homeassistant.helpers.service import async_set_service_schema
|
||||
from homeassistant.loader import bind_hass
|
||||
@ -52,51 +49,24 @@ ENTITY_ID_FORMAT = DOMAIN + ".{}"
|
||||
EVENT_SCRIPT_STARTED = "script_started"
|
||||
|
||||
|
||||
def _deprecated_legacy_mode(config):
|
||||
legacy_scripts = []
|
||||
for object_id, cfg in config.items():
|
||||
mode = cfg.get(CONF_MODE)
|
||||
if mode is None:
|
||||
legacy_scripts.append(object_id)
|
||||
cfg[CONF_MODE] = SCRIPT_MODE_LEGACY
|
||||
if legacy_scripts:
|
||||
warn_deprecated_legacy(_LOGGER, f"script(s): {', '.join(legacy_scripts)}")
|
||||
return config
|
||||
|
||||
|
||||
def _not_supported_in_legacy_mode(config):
|
||||
if config.get(CONF_MODE, SCRIPT_MODE_LEGACY) == SCRIPT_MODE_LEGACY:
|
||||
validate_legacy_mode_actions(config[CONF_SEQUENCE])
|
||||
|
||||
return config
|
||||
|
||||
|
||||
SCRIPT_ENTRY_SCHEMA = vol.All(
|
||||
SCRIPT_BASE_SCHEMA.extend(
|
||||
{
|
||||
vol.Optional(CONF_ALIAS): cv.string,
|
||||
vol.Optional(CONF_ICON): cv.icon,
|
||||
vol.Required(CONF_SEQUENCE): cv.SCRIPT_SCHEMA,
|
||||
vol.Optional(CONF_DESCRIPTION, default=""): cv.string,
|
||||
vol.Optional(CONF_FIELDS, default={}): {
|
||||
cv.string: {
|
||||
vol.Optional(CONF_DESCRIPTION): cv.string,
|
||||
vol.Optional(CONF_EXAMPLE): cv.string,
|
||||
}
|
||||
},
|
||||
}
|
||||
),
|
||||
validate_queue_size,
|
||||
_not_supported_in_legacy_mode,
|
||||
SCRIPT_ENTRY_SCHEMA = make_script_schema(
|
||||
{
|
||||
vol.Optional(CONF_ALIAS): cv.string,
|
||||
vol.Optional(CONF_ICON): cv.icon,
|
||||
vol.Required(CONF_SEQUENCE): cv.SCRIPT_SCHEMA,
|
||||
vol.Optional(CONF_DESCRIPTION, default=""): cv.string,
|
||||
vol.Optional(CONF_FIELDS, default={}): {
|
||||
cv.string: {
|
||||
vol.Optional(CONF_DESCRIPTION): cv.string,
|
||||
vol.Optional(CONF_EXAMPLE): cv.string,
|
||||
}
|
||||
},
|
||||
},
|
||||
SCRIPT_MODE_SINGLE,
|
||||
)
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema(
|
||||
{
|
||||
DOMAIN: vol.All(
|
||||
cv.schema_with_slug_keys(SCRIPT_ENTRY_SCHEMA), _deprecated_legacy_mode
|
||||
)
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
{DOMAIN: cv.schema_with_slug_keys(SCRIPT_ENTRY_SCHEMA)}, extra=vol.ALLOW_EXTRA
|
||||
)
|
||||
|
||||
SCRIPT_SERVICE_SCHEMA = vol.Schema(dict)
|
||||
@ -192,29 +162,26 @@ async def async_setup(hass, config):
|
||||
"""Call a service to turn script on."""
|
||||
variables = service.data.get(ATTR_VARIABLES)
|
||||
for script_entity in await component.async_extract_from_service(service):
|
||||
if script_entity.script.is_legacy:
|
||||
await hass.services.async_call(
|
||||
DOMAIN, script_entity.object_id, variables, context=service.context
|
||||
)
|
||||
else:
|
||||
await script_entity.async_turn_on(
|
||||
variables=variables, context=service.context, wait=False
|
||||
)
|
||||
await script_entity.async_turn_on(
|
||||
variables=variables, context=service.context, wait=False
|
||||
)
|
||||
|
||||
async def turn_off_service(service):
|
||||
"""Cancel a script."""
|
||||
# Stopping a script is ok to be done in parallel
|
||||
scripts = await component.async_extract_from_service(service)
|
||||
script_entities = await component.async_extract_from_service(service)
|
||||
|
||||
if not scripts:
|
||||
if not script_entities:
|
||||
return
|
||||
|
||||
await asyncio.wait([script.async_turn_off() for script in scripts])
|
||||
await asyncio.wait(
|
||||
[script_entity.async_turn_off() for script_entity in script_entities]
|
||||
)
|
||||
|
||||
async def toggle_service(service):
|
||||
"""Toggle a script."""
|
||||
for script_entity in await component.async_extract_from_service(service):
|
||||
await script_entity.async_toggle(context=service.context)
|
||||
await script_entity.async_toggle(context=service.context, wait=False)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_RELOAD, reload_service, schema=RELOAD_SERVICE_SCHEMA
|
||||
@ -239,27 +206,14 @@ async def _async_process_config(hass, config, component):
|
||||
"""Execute a service call to script.<script name>."""
|
||||
entity_id = ENTITY_ID_FORMAT.format(service.service)
|
||||
script_entity = component.get_entity(entity_id)
|
||||
if script_entity.script.is_legacy and script_entity.is_on:
|
||||
_LOGGER.warning("Script %s already running", entity_id)
|
||||
return
|
||||
await script_entity.async_turn_on(
|
||||
variables=service.data, context=service.context
|
||||
)
|
||||
|
||||
script_entities = []
|
||||
|
||||
for object_id, cfg in config.get(DOMAIN, {}).items():
|
||||
script_entities.append(
|
||||
ScriptEntity(
|
||||
hass,
|
||||
object_id,
|
||||
cfg.get(CONF_ALIAS, object_id),
|
||||
cfg.get(CONF_ICON),
|
||||
cfg[CONF_SEQUENCE],
|
||||
cfg[CONF_MODE],
|
||||
cfg.get(CONF_QUEUE_SIZE, 0),
|
||||
)
|
||||
)
|
||||
script_entities = [
|
||||
ScriptEntity(hass, object_id, cfg)
|
||||
for object_id, cfg in config.get(DOMAIN, {}).items()
|
||||
]
|
||||
|
||||
await component.async_add_entities(script_entities)
|
||||
|
||||
@ -289,18 +243,18 @@ class ScriptEntity(ToggleEntity):
|
||||
|
||||
icon = None
|
||||
|
||||
def __init__(self, hass, object_id, name, icon, sequence, mode, queue_size):
|
||||
def __init__(self, hass, object_id, cfg):
|
||||
"""Initialize the script."""
|
||||
self.object_id = object_id
|
||||
self.icon = icon
|
||||
self.icon = cfg.get(CONF_ICON)
|
||||
self.entity_id = ENTITY_ID_FORMAT.format(object_id)
|
||||
self.script = Script(
|
||||
hass,
|
||||
sequence,
|
||||
name,
|
||||
cfg[CONF_SEQUENCE],
|
||||
cfg.get(CONF_ALIAS, object_id),
|
||||
self.async_change_listener,
|
||||
mode,
|
||||
queue_size,
|
||||
cfg[CONF_MODE],
|
||||
cfg[CONF_MAX],
|
||||
logging.getLogger(f"{__name__}.{object_id}"),
|
||||
)
|
||||
self._changed = asyncio.Event()
|
||||
@ -354,13 +308,10 @@ class ScriptEntity(ToggleEntity):
|
||||
|
||||
# Caller does not want to wait for called script to finish so let script run in
|
||||
# separate Task. However, wait for first state change so we can guarantee that
|
||||
# it is written to the State Machine before we return. Only do this for
|
||||
# non-legacy scripts, since legacy scripts don't necessarily change state
|
||||
# immediately.
|
||||
# it is written to the State Machine before we return.
|
||||
self._changed.clear()
|
||||
self.hass.async_create_task(coro)
|
||||
if not self.script.is_legacy:
|
||||
await self._changed.wait()
|
||||
await self._changed.wait()
|
||||
|
||||
async def async_turn_off(self, **kwargs):
|
||||
"""Turn script off."""
|
||||
|
@ -135,7 +135,6 @@ CONF_PREFIX = "prefix"
|
||||
CONF_PROFILE_NAME = "profile_name"
|
||||
CONF_PROTOCOL = "protocol"
|
||||
CONF_PROXY_SSL = "proxy_ssl"
|
||||
CONF_QUEUE_SIZE = "queue_size"
|
||||
CONF_QUOTE = "quote"
|
||||
CONF_RADIUS = "radius"
|
||||
CONF_RECIPIENT = "recipient"
|
||||
|
@ -1,12 +1,10 @@
|
||||
"""Helpers to execute scripts."""
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
|
||||
|
||||
from async_timeout import timeout
|
||||
import voluptuous as vol
|
||||
@ -27,7 +25,6 @@ from homeassistant.const import (
|
||||
CONF_EVENT_DATA,
|
||||
CONF_EVENT_DATA_TEMPLATE,
|
||||
CONF_MODE,
|
||||
CONF_QUEUE_SIZE,
|
||||
CONF_REPEAT,
|
||||
CONF_SCENE,
|
||||
CONF_SEQUENCE,
|
||||
@ -35,25 +32,15 @@ from homeassistant.const import (
|
||||
CONF_UNTIL,
|
||||
CONF_WAIT_TEMPLATE,
|
||||
CONF_WHILE,
|
||||
SERVICE_TURN_OFF,
|
||||
SERVICE_TURN_ON,
|
||||
)
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
SERVICE_CALL_LIMIT,
|
||||
Context,
|
||||
HomeAssistant,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.core import SERVICE_CALL_LIMIT, Context, HomeAssistant, callback
|
||||
from homeassistant.helpers import (
|
||||
condition,
|
||||
config_validation as cv,
|
||||
template as template,
|
||||
)
|
||||
from homeassistant.helpers.event import (
|
||||
async_track_point_in_utc_time,
|
||||
async_track_template,
|
||||
)
|
||||
from homeassistant.helpers.event import async_track_template
|
||||
from homeassistant.helpers.service import (
|
||||
CONF_SERVICE_DATA,
|
||||
async_prepare_call_from_config,
|
||||
@ -64,74 +51,41 @@ from homeassistant.util.dt import utcnow
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
SCRIPT_MODE_ERROR = "error"
|
||||
SCRIPT_MODE_IGNORE = "ignore"
|
||||
SCRIPT_MODE_LEGACY = "legacy"
|
||||
SCRIPT_MODE_PARALLEL = "parallel"
|
||||
SCRIPT_MODE_QUEUE = "queue"
|
||||
SCRIPT_MODE_RESTART = "restart"
|
||||
SCRIPT_MODE_SINGLE = "single"
|
||||
SCRIPT_MODE_CHOICES = [
|
||||
SCRIPT_MODE_ERROR,
|
||||
SCRIPT_MODE_IGNORE,
|
||||
SCRIPT_MODE_LEGACY,
|
||||
SCRIPT_MODE_PARALLEL,
|
||||
SCRIPT_MODE_QUEUE,
|
||||
SCRIPT_MODE_RESTART,
|
||||
SCRIPT_MODE_SINGLE,
|
||||
]
|
||||
DEFAULT_SCRIPT_MODE = SCRIPT_MODE_LEGACY
|
||||
DEFAULT_SCRIPT_MODE = SCRIPT_MODE_SINGLE
|
||||
|
||||
DEFAULT_QUEUE_SIZE = 10
|
||||
CONF_MAX = "max"
|
||||
DEFAULT_MAX = 10
|
||||
|
||||
_LOG_EXCEPTION = logging.ERROR + 1
|
||||
_TIMEOUT_MSG = "Timeout reached, abort script."
|
||||
|
||||
SCRIPT_BASE_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Optional(CONF_MODE): vol.In(SCRIPT_MODE_CHOICES),
|
||||
vol.Optional(CONF_QUEUE_SIZE): vol.All(vol.Coerce(int), vol.Range(min=2)),
|
||||
}
|
||||
)
|
||||
|
||||
_UNSUPPORTED_IN_LEGACY = {
|
||||
cv.SCRIPT_ACTION_REPEAT: CONF_REPEAT,
|
||||
}
|
||||
|
||||
|
||||
def warn_deprecated_legacy(logger, msg):
|
||||
"""Warn about deprecated legacy mode."""
|
||||
logger.warning(
|
||||
"Script behavior has changed. "
|
||||
"To continue using previous behavior, which is now deprecated, "
|
||||
"add '%s: %s' to %s.",
|
||||
CONF_MODE,
|
||||
SCRIPT_MODE_LEGACY,
|
||||
msg,
|
||||
def make_script_schema(schema, default_script_mode, extra=vol.PREVENT_EXTRA):
|
||||
"""Make a schema for a component that uses the script helper."""
|
||||
return vol.Schema(
|
||||
{
|
||||
**schema,
|
||||
vol.Optional(CONF_MODE, default=default_script_mode): vol.In(
|
||||
SCRIPT_MODE_CHOICES
|
||||
),
|
||||
vol.Optional(CONF_MAX, default=DEFAULT_MAX): vol.All(
|
||||
vol.Coerce(int), vol.Range(min=2)
|
||||
),
|
||||
},
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
def validate_legacy_mode_actions(sequence):
|
||||
"""Check for actions not supported in legacy mode."""
|
||||
for action in sequence:
|
||||
script_action = cv.determine_script_action(action)
|
||||
if script_action in _UNSUPPORTED_IN_LEGACY:
|
||||
raise vol.Invalid(
|
||||
f"{_UNSUPPORTED_IN_LEGACY[script_action]} action not supported in "
|
||||
f"{SCRIPT_MODE_LEGACY} mode"
|
||||
)
|
||||
|
||||
|
||||
def validate_queue_size(config):
|
||||
"""Validate queue_size option."""
|
||||
mode = config.get(CONF_MODE, DEFAULT_SCRIPT_MODE)
|
||||
queue_size = config.get(CONF_QUEUE_SIZE)
|
||||
if mode == SCRIPT_MODE_QUEUE:
|
||||
if queue_size is None:
|
||||
config[CONF_QUEUE_SIZE] = DEFAULT_QUEUE_SIZE
|
||||
elif queue_size is not None:
|
||||
raise vol.Invalid(f"{CONF_QUEUE_SIZE} not valid with {mode} {CONF_MODE}")
|
||||
return config
|
||||
|
||||
|
||||
async def async_validate_action_config(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> ConfigType:
|
||||
@ -159,20 +113,8 @@ class _StopScript(Exception):
|
||||
"""Throw if script needs to stop."""
|
||||
|
||||
|
||||
class _SuspendScript(Exception):
|
||||
"""Throw if script needs to suspend."""
|
||||
|
||||
|
||||
class AlreadyRunning(exceptions.HomeAssistantError):
|
||||
"""Throw if script already running and user wants error."""
|
||||
|
||||
|
||||
class QueueFull(exceptions.HomeAssistantError):
|
||||
"""Throw if script already running, user wants new run queued, but queue is full."""
|
||||
|
||||
|
||||
class _ScriptRunBase(ABC):
|
||||
"""Common data & methods for managing Script sequence run."""
|
||||
class _ScriptRun:
|
||||
"""Manage Script sequence run."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -189,17 +131,36 @@ class _ScriptRunBase(ABC):
|
||||
self._log_exceptions = log_exceptions
|
||||
self._step = -1
|
||||
self._action: Optional[Dict[str, Any]] = None
|
||||
self._stop = asyncio.Event()
|
||||
self._stopped = asyncio.Event()
|
||||
|
||||
def _changed(self):
|
||||
self._script._changed() # pylint: disable=protected-access
|
||||
if not self._stop.is_set():
|
||||
self._script._changed() # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def _config_cache(self):
|
||||
return self._script._config_cache # pylint: disable=protected-access
|
||||
|
||||
@abstractmethod
|
||||
def _log(self, msg, *args, level=logging.INFO):
|
||||
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
|
||||
|
||||
async def async_run(self) -> None:
|
||||
"""Run script."""
|
||||
try:
|
||||
if self._stop.is_set():
|
||||
return
|
||||
self._script.last_triggered = utcnow()
|
||||
self._changed()
|
||||
self._log("Running script")
|
||||
for self._step, self._action in enumerate(self._script.sequence):
|
||||
if self._stop.is_set():
|
||||
break
|
||||
await self._async_step(log_exceptions=False)
|
||||
except _StopScript:
|
||||
pass
|
||||
finally:
|
||||
self._finish()
|
||||
|
||||
async def _async_step(self, log_exceptions):
|
||||
try:
|
||||
@ -207,15 +168,23 @@ class _ScriptRunBase(ABC):
|
||||
self, f"_async_{cv.determine_script_action(self._action)}_step"
|
||||
)()
|
||||
except Exception as ex:
|
||||
if not isinstance(
|
||||
ex, (_SuspendScript, _StopScript, asyncio.CancelledError)
|
||||
) and (self._log_exceptions or log_exceptions):
|
||||
if not isinstance(ex, (_StopScript, asyncio.CancelledError)) and (
|
||||
self._log_exceptions or log_exceptions
|
||||
):
|
||||
self._log_exception(ex)
|
||||
raise
|
||||
|
||||
@abstractmethod
|
||||
def _finish(self):
|
||||
self._script._runs.remove(self) # pylint: disable=protected-access
|
||||
if not self._script.is_running:
|
||||
self._script.last_action = None
|
||||
self._changed()
|
||||
self._stopped.set()
|
||||
|
||||
async def async_stop(self) -> None:
|
||||
"""Stop script run."""
|
||||
self._stop.set()
|
||||
await self._stopped.wait()
|
||||
|
||||
def _log_exception(self, exception):
|
||||
action_type = cv.determine_script_action(self._action)
|
||||
@ -235,12 +204,6 @@ class _ScriptRunBase(ABC):
|
||||
elif isinstance(exception, exceptions.ServiceNotFound):
|
||||
error_desc = "Service not found"
|
||||
|
||||
elif isinstance(exception, AlreadyRunning):
|
||||
error_desc = "Already running"
|
||||
|
||||
elif isinstance(exception, QueueFull):
|
||||
error_desc = "Run queue is full"
|
||||
|
||||
else:
|
||||
error_desc = "Unexpected error"
|
||||
level = _LOG_EXCEPTION
|
||||
@ -254,11 +217,8 @@ class _ScriptRunBase(ABC):
|
||||
level=level,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def _async_delay_step(self):
|
||||
"""Handle delay."""
|
||||
|
||||
def _prep_delay_step(self):
|
||||
try:
|
||||
delay = vol.All(cv.time_period, cv.positive_timedelta)(
|
||||
template.render_complex(self._action[CONF_DELAY], self._variables)
|
||||
@ -275,35 +235,128 @@ class _ScriptRunBase(ABC):
|
||||
self._script.last_action = self._action.get(CONF_ALIAS, f"delay {delay}")
|
||||
self._log("Executing step %s", self._script.last_action)
|
||||
|
||||
return delay
|
||||
delay = delay.total_seconds()
|
||||
self._changed()
|
||||
try:
|
||||
async with timeout(delay):
|
||||
await self._stop.wait()
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def _async_wait_template_step(self):
|
||||
"""Handle a wait template."""
|
||||
|
||||
def _prep_wait_template_step(self, async_script_wait):
|
||||
wait_template = self._action[CONF_WAIT_TEMPLATE]
|
||||
wait_template.hass = self._hass
|
||||
|
||||
self._script.last_action = self._action.get(CONF_ALIAS, "wait template")
|
||||
self._log("Executing step %s", self._script.last_action)
|
||||
|
||||
wait_template = self._action[CONF_WAIT_TEMPLATE]
|
||||
wait_template.hass = self._hass
|
||||
|
||||
# check if condition already okay
|
||||
if condition.async_template(self._hass, wait_template, self._variables):
|
||||
return None
|
||||
return
|
||||
|
||||
return async_track_template(
|
||||
@callback
|
||||
def async_script_wait(entity_id, from_s, to_s):
|
||||
"""Handle script after template condition is true."""
|
||||
done.set()
|
||||
|
||||
unsub = async_track_template(
|
||||
self._hass, wait_template, async_script_wait, self._variables
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
self._changed()
|
||||
try:
|
||||
delay = self._action[CONF_TIMEOUT].total_seconds()
|
||||
except KeyError:
|
||||
delay = None
|
||||
done = asyncio.Event()
|
||||
tasks = [
|
||||
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
|
||||
]
|
||||
try:
|
||||
async with timeout(delay):
|
||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
except asyncio.TimeoutError:
|
||||
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
|
||||
self._log(_TIMEOUT_MSG)
|
||||
raise _StopScript
|
||||
finally:
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
unsub()
|
||||
|
||||
async def _async_run_long_action(self, long_task):
|
||||
"""Run a long task while monitoring for stop request."""
|
||||
|
||||
async def async_cancel_long_task():
|
||||
# Stop long task and wait for it to finish.
|
||||
long_task.cancel()
|
||||
try:
|
||||
await long_task
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
|
||||
# Wait for long task while monitoring for a stop request.
|
||||
stop_task = self._hass.async_create_task(self._stop.wait())
|
||||
try:
|
||||
await asyncio.wait(
|
||||
{long_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
# If our task is cancelled, then cancel long task, too. Note that if long task
|
||||
# is cancelled otherwise the CancelledError exception will not be raised to
|
||||
# here due to the call to asyncio.wait(). Rather we'll check for that below.
|
||||
except asyncio.CancelledError:
|
||||
await async_cancel_long_task()
|
||||
raise
|
||||
finally:
|
||||
stop_task.cancel()
|
||||
|
||||
if long_task.cancelled():
|
||||
raise asyncio.CancelledError
|
||||
if long_task.done():
|
||||
# Propagate any exceptions that occurred.
|
||||
long_task.result()
|
||||
else:
|
||||
# Stopped before long task completed, so cancel it.
|
||||
await async_cancel_long_task()
|
||||
|
||||
async def _async_call_service_step(self):
|
||||
"""Call the service specified in the action."""
|
||||
|
||||
def _prep_call_service_step(self):
|
||||
self._script.last_action = self._action.get(CONF_ALIAS, "call service")
|
||||
self._log("Executing step %s", self._script.last_action)
|
||||
return async_prepare_call_from_config(self._hass, self._action, self._variables)
|
||||
|
||||
domain, service, service_data = async_prepare_call_from_config(
|
||||
self._hass, self._action, self._variables
|
||||
)
|
||||
|
||||
running_script = (
|
||||
domain == "automation"
|
||||
and service == "trigger"
|
||||
or domain in ("python_script", "script")
|
||||
)
|
||||
# If this might start a script then disable the call timeout.
|
||||
# Otherwise use the normal service call limit.
|
||||
if running_script:
|
||||
limit = None
|
||||
else:
|
||||
limit = SERVICE_CALL_LIMIT
|
||||
|
||||
service_task = self._hass.async_create_task(
|
||||
self._hass.services.async_call(
|
||||
domain,
|
||||
service,
|
||||
service_data,
|
||||
blocking=True,
|
||||
context=self._context,
|
||||
limit=limit,
|
||||
)
|
||||
)
|
||||
if limit is not None:
|
||||
# There is a call limit, so just wait for it to finish.
|
||||
await service_task
|
||||
return
|
||||
|
||||
await self._async_run_long_action(service_task)
|
||||
|
||||
async def _async_device_step(self):
|
||||
"""Perform the device automation specified in the action."""
|
||||
@ -370,175 +423,6 @@ class _ScriptRunBase(ABC):
|
||||
if not check:
|
||||
raise _StopScript
|
||||
|
||||
@abstractmethod
|
||||
async def _async_repeat_step(self):
|
||||
"""Repeat a sequence."""
|
||||
|
||||
def _log(self, msg, *args, level=logging.INFO):
|
||||
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
|
||||
|
||||
|
||||
class _ScriptRun(_ScriptRunBase):
|
||||
"""Manage Script sequence run."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
script: "Script",
|
||||
variables: Optional[Sequence],
|
||||
context: Optional[Context],
|
||||
log_exceptions: bool,
|
||||
) -> None:
|
||||
super().__init__(hass, script, variables, context, log_exceptions)
|
||||
self._stop = asyncio.Event()
|
||||
self._stopped = asyncio.Event()
|
||||
|
||||
def _changed(self):
|
||||
if not self._stop.is_set():
|
||||
super()._changed()
|
||||
|
||||
async def async_run(self) -> None:
|
||||
"""Run script."""
|
||||
try:
|
||||
if self._stop.is_set():
|
||||
return
|
||||
self._script.last_triggered = utcnow()
|
||||
self._changed()
|
||||
self._log("Running script")
|
||||
for self._step, self._action in enumerate(self._script.sequence):
|
||||
if self._stop.is_set():
|
||||
break
|
||||
await self._async_step(log_exceptions=False)
|
||||
except _StopScript:
|
||||
pass
|
||||
finally:
|
||||
self._finish()
|
||||
|
||||
def _finish(self):
|
||||
self._script._runs.remove(self) # pylint: disable=protected-access
|
||||
if not self._script.is_running:
|
||||
self._script.last_action = None
|
||||
self._changed()
|
||||
self._stopped.set()
|
||||
|
||||
async def async_stop(self) -> None:
|
||||
"""Stop script run."""
|
||||
self._stop.set()
|
||||
await self._stopped.wait()
|
||||
|
||||
async def _async_delay_step(self):
|
||||
"""Handle delay."""
|
||||
delay = self._prep_delay_step().total_seconds()
|
||||
self._changed()
|
||||
try:
|
||||
async with timeout(delay):
|
||||
await self._stop.wait()
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
async def _async_wait_template_step(self):
|
||||
"""Handle a wait template."""
|
||||
|
||||
@callback
|
||||
def async_script_wait(entity_id, from_s, to_s):
|
||||
"""Handle script after template condition is true."""
|
||||
done.set()
|
||||
|
||||
unsub = self._prep_wait_template_step(async_script_wait)
|
||||
if not unsub:
|
||||
return
|
||||
|
||||
self._changed()
|
||||
try:
|
||||
delay = self._action[CONF_TIMEOUT].total_seconds()
|
||||
except KeyError:
|
||||
delay = None
|
||||
done = asyncio.Event()
|
||||
tasks = [
|
||||
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
|
||||
]
|
||||
try:
|
||||
async with timeout(delay):
|
||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
except asyncio.TimeoutError:
|
||||
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
|
||||
self._log(_TIMEOUT_MSG)
|
||||
raise _StopScript
|
||||
finally:
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
unsub()
|
||||
|
||||
async def _async_run_long_action(self, long_task):
|
||||
"""Run a long task while monitoring for stop request."""
|
||||
|
||||
async def async_cancel_long_task():
|
||||
# Stop long task and wait for it to finish.
|
||||
long_task.cancel()
|
||||
try:
|
||||
await long_task
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
|
||||
# Wait for long task while monitoring for a stop request.
|
||||
stop_task = self._hass.async_create_task(self._stop.wait())
|
||||
try:
|
||||
await asyncio.wait(
|
||||
{long_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
# If our task is cancelled, then cancel long task, too. Note that if long task
|
||||
# is cancelled otherwise the CancelledError exception will not be raised to
|
||||
# here due to the call to asyncio.wait(). Rather we'll check for that below.
|
||||
except asyncio.CancelledError:
|
||||
await async_cancel_long_task()
|
||||
raise
|
||||
finally:
|
||||
stop_task.cancel()
|
||||
|
||||
if long_task.cancelled():
|
||||
raise asyncio.CancelledError
|
||||
if long_task.done():
|
||||
# Propagate any exceptions that occurred.
|
||||
long_task.result()
|
||||
else:
|
||||
# Stopped before long task completed, so cancel it.
|
||||
await async_cancel_long_task()
|
||||
|
||||
async def _async_call_service_step(self):
|
||||
"""Call the service specified in the action."""
|
||||
domain, service, service_data = self._prep_call_service_step()
|
||||
|
||||
running_script = (
|
||||
domain == "automation"
|
||||
and service == "trigger"
|
||||
or domain == "python_script"
|
||||
or domain == "script"
|
||||
and service != SERVICE_TURN_OFF
|
||||
)
|
||||
# If this might start a script then disable the call timeout.
|
||||
# Otherwise use the normal service call limit.
|
||||
if running_script:
|
||||
limit = None
|
||||
else:
|
||||
limit = SERVICE_CALL_LIMIT
|
||||
|
||||
service_task = self._hass.async_create_task(
|
||||
self._hass.services.async_call(
|
||||
domain,
|
||||
service,
|
||||
service_data,
|
||||
blocking=True,
|
||||
context=self._context,
|
||||
limit=limit,
|
||||
)
|
||||
)
|
||||
if limit is not None:
|
||||
# There is a call limit, so just wait for it to finish.
|
||||
await service_task
|
||||
return
|
||||
|
||||
await self._async_run_long_action(service_task)
|
||||
|
||||
async def _async_repeat_step(self):
|
||||
"""Repeat a sequence."""
|
||||
|
||||
@ -638,165 +522,12 @@ class _QueuedScriptRun(_ScriptRun):
|
||||
|
||||
def _finish(self):
|
||||
# pylint: disable=protected-access
|
||||
self._script._queue_len -= 1
|
||||
if self.lock_acquired:
|
||||
self._script._queue_lck.release()
|
||||
self.lock_acquired = False
|
||||
super()._finish()
|
||||
|
||||
|
||||
class _LegacyScriptRun(_ScriptRunBase):
|
||||
"""Manage legacy Script sequence run."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
script: "Script",
|
||||
variables: Optional[Sequence],
|
||||
context: Optional[Context],
|
||||
log_exceptions: bool,
|
||||
shared: Optional["_LegacyScriptRun"],
|
||||
) -> None:
|
||||
super().__init__(hass, script, variables, context, log_exceptions)
|
||||
if shared:
|
||||
self._shared = shared
|
||||
else:
|
||||
# To implement legacy behavior we need to share the following "run state"
|
||||
# amongst all runs, so it will only exist in the first instantiation of
|
||||
# concurrent runs, and the rest will use it, too.
|
||||
self._current = -1
|
||||
self._async_listeners: List[CALLBACK_TYPE] = []
|
||||
self._shared = self
|
||||
|
||||
@property
|
||||
def _cur(self):
|
||||
return self._shared._current # pylint: disable=protected-access
|
||||
|
||||
@_cur.setter
|
||||
def _cur(self, value):
|
||||
self._shared._current = value # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def _async_listener(self):
|
||||
return self._shared._async_listeners # pylint: disable=protected-access
|
||||
|
||||
async def async_run(self) -> None:
|
||||
"""Run script."""
|
||||
await self._async_run()
|
||||
|
||||
async def _async_run(self, propagate_exceptions=True):
|
||||
if self._cur == -1:
|
||||
self._script.last_triggered = utcnow()
|
||||
self._log("Running script")
|
||||
self._cur = 0
|
||||
|
||||
# Unregister callback if we were in a delay or wait but turn on is
|
||||
# called again. In that case we just continue execution.
|
||||
self._async_remove_listener()
|
||||
|
||||
suspended = False
|
||||
try:
|
||||
for self._step, self._action in itertools.islice(
|
||||
enumerate(self._script.sequence), self._cur, None
|
||||
):
|
||||
await self._async_step(log_exceptions=not propagate_exceptions)
|
||||
except _StopScript:
|
||||
pass
|
||||
except _SuspendScript:
|
||||
# Store next step to take and notify change listeners
|
||||
self._cur = self._step + 1
|
||||
suspended = True
|
||||
return
|
||||
except Exception: # pylint: disable=broad-except
|
||||
if propagate_exceptions:
|
||||
raise
|
||||
finally:
|
||||
_cur_was = self._cur
|
||||
if not suspended:
|
||||
self._script.last_action = None
|
||||
await self.async_stop()
|
||||
if _cur_was != -1:
|
||||
self._changed()
|
||||
|
||||
async def async_stop(self) -> None:
|
||||
"""Stop script run."""
|
||||
if self._cur == -1:
|
||||
return
|
||||
|
||||
self._cur = -1
|
||||
self._async_remove_listener()
|
||||
self._script._runs.clear() # pylint: disable=protected-access
|
||||
|
||||
async def _async_delay_step(self):
|
||||
"""Handle delay."""
|
||||
delay = self._prep_delay_step()
|
||||
|
||||
@callback
|
||||
def async_script_delay(now):
|
||||
"""Handle delay."""
|
||||
with suppress(ValueError):
|
||||
self._async_listener.remove(unsub)
|
||||
self._hass.async_create_task(self._async_run(False))
|
||||
|
||||
unsub = async_track_point_in_utc_time(
|
||||
self._hass, async_script_delay, utcnow() + delay
|
||||
)
|
||||
self._async_listener.append(unsub)
|
||||
raise _SuspendScript
|
||||
|
||||
async def _async_wait_template_step(self):
|
||||
"""Handle a wait template."""
|
||||
|
||||
@callback
|
||||
def async_script_wait(entity_id, from_s, to_s):
|
||||
"""Handle script after template condition is true."""
|
||||
self._async_remove_listener()
|
||||
self._hass.async_create_task(self._async_run(False))
|
||||
|
||||
@callback
|
||||
def async_script_timeout(now):
|
||||
"""Call after timeout has expired."""
|
||||
with suppress(ValueError):
|
||||
self._async_listener.remove(unsub_timeout)
|
||||
|
||||
# Check if we want to continue to execute
|
||||
# the script after the timeout
|
||||
if self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
|
||||
self._hass.async_create_task(self._async_run(False))
|
||||
else:
|
||||
self._log(_TIMEOUT_MSG)
|
||||
self._hass.async_create_task(self.async_stop())
|
||||
|
||||
unsub_wait = self._prep_wait_template_step(async_script_wait)
|
||||
if not unsub_wait:
|
||||
return
|
||||
self._async_listener.append(unsub_wait)
|
||||
|
||||
if CONF_TIMEOUT in self._action:
|
||||
unsub_timeout = async_track_point_in_utc_time(
|
||||
self._hass, async_script_timeout, utcnow() + self._action[CONF_TIMEOUT]
|
||||
)
|
||||
self._async_listener.append(unsub_timeout)
|
||||
|
||||
raise _SuspendScript
|
||||
|
||||
async def _async_call_service_step(self):
|
||||
"""Call the service specified in the action."""
|
||||
await self._hass.services.async_call(
|
||||
*self._prep_call_service_step(), blocking=True, context=self._context
|
||||
)
|
||||
|
||||
async def _async_repeat_step(self):
|
||||
"""Repeat a sequence."""
|
||||
# Not supported in legacy mode.
|
||||
|
||||
def _async_remove_listener(self):
|
||||
"""Remove listeners, if any."""
|
||||
for unsub in self._async_listener:
|
||||
unsub()
|
||||
self._async_listener.clear()
|
||||
|
||||
|
||||
class Script:
|
||||
"""Representation of a script."""
|
||||
|
||||
@ -807,7 +538,7 @@ class Script:
|
||||
name: Optional[str] = None,
|
||||
change_listener: Optional[Callable[..., Any]] = None,
|
||||
script_mode: str = DEFAULT_SCRIPT_MODE,
|
||||
queue_size: int = DEFAULT_QUEUE_SIZE,
|
||||
max_runs: int = DEFAULT_MAX,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
log_exceptions: bool = True,
|
||||
) -> None:
|
||||
@ -829,10 +560,7 @@ class Script:
|
||||
|
||||
self.last_action = None
|
||||
self.last_triggered: Optional[datetime] = None
|
||||
self.can_cancel = not self.is_legacy or any(
|
||||
CONF_DELAY in action or CONF_WAIT_TEMPLATE in action
|
||||
for action in self.sequence
|
||||
)
|
||||
self.can_cancel = True
|
||||
|
||||
self._repeat_script = {}
|
||||
for step, action in enumerate(sequence):
|
||||
@ -850,10 +578,9 @@ class Script:
|
||||
)
|
||||
self._repeat_script[step] = sub_script
|
||||
|
||||
self._runs: List[_ScriptRunBase] = []
|
||||
self._runs: List[_ScriptRun] = []
|
||||
self._max_runs = max_runs
|
||||
if script_mode == SCRIPT_MODE_QUEUE:
|
||||
self._queue_size = queue_size
|
||||
self._queue_len = 0
|
||||
self._queue_lck = asyncio.Lock()
|
||||
self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {}
|
||||
self._referenced_entities: Optional[Set[str]] = None
|
||||
@ -873,11 +600,6 @@ class Script:
|
||||
"""Return true if script is on."""
|
||||
return len(self._runs) > 0
|
||||
|
||||
@property
|
||||
def is_legacy(self) -> bool:
|
||||
"""Return if using legacy mode."""
|
||||
return self._script_mode == SCRIPT_MODE_LEGACY
|
||||
|
||||
@property
|
||||
def referenced_devices(self):
|
||||
"""Return a set of referenced devices."""
|
||||
@ -945,60 +667,40 @@ class Script:
|
||||
) -> None:
|
||||
"""Run script."""
|
||||
if self.is_running:
|
||||
if self._script_mode == SCRIPT_MODE_IGNORE:
|
||||
self._log("Skipping script")
|
||||
if self._script_mode == SCRIPT_MODE_SINGLE:
|
||||
self._log("Already running", level=logging.WARNING)
|
||||
return
|
||||
if self._script_mode == SCRIPT_MODE_RESTART:
|
||||
self._log("Restarting")
|
||||
await self.async_stop(update_state=False)
|
||||
elif len(self._runs) == self._max_runs:
|
||||
self._log("Maximum number of runs exceeded", level=logging.WARNING)
|
||||
return
|
||||
|
||||
if self._script_mode == SCRIPT_MODE_ERROR:
|
||||
raise AlreadyRunning
|
||||
|
||||
if self._script_mode == SCRIPT_MODE_RESTART:
|
||||
self._log("Restarting script")
|
||||
await self.async_stop(update_state=False)
|
||||
elif self._script_mode == SCRIPT_MODE_QUEUE:
|
||||
self._log(
|
||||
"Queueing script behind %i run%s",
|
||||
self._queue_len,
|
||||
"s" if self._queue_len > 1 else "",
|
||||
)
|
||||
if self._queue_len >= self._queue_size:
|
||||
raise QueueFull
|
||||
|
||||
if self.is_legacy:
|
||||
if self._runs:
|
||||
shared = cast(Optional[_LegacyScriptRun], self._runs[0])
|
||||
else:
|
||||
shared = None
|
||||
run: _ScriptRunBase = _LegacyScriptRun(
|
||||
self._hass, self, variables, context, self._log_exceptions, shared
|
||||
)
|
||||
if self._script_mode != SCRIPT_MODE_QUEUE:
|
||||
cls = _ScriptRun
|
||||
else:
|
||||
if self._script_mode != SCRIPT_MODE_QUEUE:
|
||||
cls = _ScriptRun
|
||||
else:
|
||||
cls = _QueuedScriptRun
|
||||
self._queue_len += 1
|
||||
run = cls(self._hass, self, variables, context, self._log_exceptions)
|
||||
cls = _QueuedScriptRun
|
||||
run = cls(self._hass, self, variables, context, self._log_exceptions)
|
||||
self._runs.append(run)
|
||||
|
||||
try:
|
||||
if self.is_legacy:
|
||||
await run.async_run()
|
||||
else:
|
||||
await asyncio.shield(run.async_run())
|
||||
await asyncio.shield(run.async_run())
|
||||
except asyncio.CancelledError:
|
||||
await run.async_stop()
|
||||
self._changed()
|
||||
raise
|
||||
|
||||
async def async_stop(self, update_state: bool = True) -> None:
|
||||
"""Stop running script."""
|
||||
if not self.is_running:
|
||||
return
|
||||
await asyncio.shield(asyncio.gather(*(run.async_stop() for run in self._runs)))
|
||||
async def _async_stop(self, update_state):
|
||||
await asyncio.wait([run.async_stop() for run in self._runs])
|
||||
if update_state:
|
||||
self._changed()
|
||||
|
||||
async def async_stop(self, update_state: bool = True) -> None:
|
||||
"""Stop running script."""
|
||||
if self.is_running:
|
||||
await asyncio.shield(self._async_stop(update_state))
|
||||
|
||||
def _log(self, msg, *args, level=logging.INFO):
|
||||
if self.name:
|
||||
msg = f"%s: {msg}"
|
||||
|
@ -1,6 +1,4 @@
|
||||
"""The tests for the automation component."""
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import logbook
|
||||
@ -23,12 +21,7 @@ from homeassistant.setup import async_setup_component
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
from tests.async_mock import Mock, patch
|
||||
from tests.common import (
|
||||
assert_setup_component,
|
||||
async_fire_time_changed,
|
||||
async_mock_service,
|
||||
mock_restore_cache,
|
||||
)
|
||||
from tests.common import assert_setup_component, async_mock_service, mock_restore_cache
|
||||
from tests.components.automation import common
|
||||
from tests.components.logbook.test_init import MockLazyEventPartialState
|
||||
|
||||
@ -87,57 +80,6 @@ async def test_service_specify_data(hass, calls):
|
||||
assert state.attributes.get("last_triggered") == time
|
||||
|
||||
|
||||
async def test_action_delay(hass, calls):
|
||||
"""Test action delay."""
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
automation.DOMAIN,
|
||||
{
|
||||
automation.DOMAIN: {
|
||||
"alias": "hello",
|
||||
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||
"action": [
|
||||
{
|
||||
"service": "test.automation",
|
||||
"data_template": {
|
||||
"some": "{{ trigger.platform }} - "
|
||||
"{{ trigger.event.event_type }}"
|
||||
},
|
||||
},
|
||||
{"delay": {"minutes": "10"}},
|
||||
{
|
||||
"service": "test.automation",
|
||||
"data_template": {
|
||||
"some": "{{ trigger.platform }} - "
|
||||
"{{ trigger.event.event_type }}"
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
time = dt_util.utcnow()
|
||||
|
||||
with patch("homeassistant.components.automation.utcnow", return_value=time):
|
||||
hass.bus.async_fire("test_event")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0].data["some"] == "event - test_event"
|
||||
|
||||
future = dt_util.utcnow() + timedelta(minutes=10)
|
||||
async_fire_time_changed(hass, future)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(calls) == 2
|
||||
assert calls[1].data["some"] == "event - test_event"
|
||||
|
||||
state = hass.states.get("automation.hello")
|
||||
assert state is not None
|
||||
assert state.attributes.get("last_triggered") == time
|
||||
|
||||
|
||||
async def test_service_specify_entity_id(hass, calls):
|
||||
"""Test service data."""
|
||||
assert await async_setup_component(
|
||||
@ -1070,42 +1012,3 @@ async def test_logbook_humanify_automation_triggered_event(hass):
|
||||
assert event2["domain"] == "automation"
|
||||
assert event2["message"] == "has been triggered"
|
||||
assert event2["entity_id"] == "automation.bye"
|
||||
|
||||
|
||||
invalid_configs = [
|
||||
{
|
||||
"mode": "parallel",
|
||||
"queue_size": 5,
|
||||
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||
"action": [],
|
||||
},
|
||||
{
|
||||
"mode": "legacy",
|
||||
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||
"action": [{"repeat": {"count": 5, "sequence": []}}],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value", invalid_configs)
|
||||
async def test_invalid_configs(hass, value):
|
||||
"""Test invalid configurations."""
|
||||
with assert_setup_component(0, automation.DOMAIN):
|
||||
assert await async_setup_component(
|
||||
hass, automation.DOMAIN, {automation.DOMAIN: value}
|
||||
)
|
||||
|
||||
|
||||
async def test_config_legacy(hass, caplog):
|
||||
"""Test config defaulting to legacy mode."""
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
automation.DOMAIN,
|
||||
{
|
||||
automation.DOMAIN: {
|
||||
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||
"action": [],
|
||||
}
|
||||
},
|
||||
)
|
||||
assert "To continue using previous behavior, which is now deprecated" in caplog.text
|
||||
|
@ -961,10 +961,8 @@ async def test_wait_template_with_trigger(hass, calls):
|
||||
await hass.async_block_till_done()
|
||||
|
||||
hass.states.async_set("test.entity", "12")
|
||||
await hass.async_block_till_done()
|
||||
hass.states.async_set("test.entity", "8")
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
assert "numeric_state - test.entity - 12" == calls[0].data["some"]
|
||||
|
||||
|
@ -704,7 +704,6 @@ async def test_wait_template_with_trigger(hass, calls):
|
||||
await hass.async_block_till_done()
|
||||
|
||||
hass.states.async_set("test.entity", "world")
|
||||
await hass.async_block_till_done()
|
||||
hass.states.async_set("test.entity", "hello")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
|
@ -5,7 +5,7 @@ from unittest import mock
|
||||
import pytest
|
||||
|
||||
import homeassistant.components.automation as automation
|
||||
from homeassistant.core import Context
|
||||
from homeassistant.core import Context, callback
|
||||
from homeassistant.setup import async_setup_component
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
@ -435,6 +435,7 @@ async def test_wait_template_with_trigger(hass, calls):
|
||||
"value_template": "{{ states.test.entity.state == 'world' }}",
|
||||
},
|
||||
"action": [
|
||||
{"event": "test_event"},
|
||||
{"wait_template": "{{ is_state(trigger.entity_id, 'hello') }}"},
|
||||
{
|
||||
"service": "test.automation",
|
||||
@ -458,10 +459,14 @@ async def test_wait_template_with_trigger(hass, calls):
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
@callback
|
||||
def event_handler(event):
|
||||
hass.states.async_set("test.entity", "hello")
|
||||
|
||||
hass.bus.async_listen_once("test_event", event_handler)
|
||||
|
||||
hass.states.async_set("test.entity", "world")
|
||||
await hass.async_block_till_done()
|
||||
hass.states.async_set("test.entity", "hello")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
assert calls[0].data["some"] == "template - test.entity - hello - world - None"
|
||||
|
||||
|
@ -17,6 +17,7 @@ from homeassistant.const import (
|
||||
)
|
||||
from homeassistant.core import Context, callback, split_entity_id
|
||||
from homeassistant.exceptions import ServiceNotFound
|
||||
from homeassistant.helpers.event import async_track_state_change
|
||||
from homeassistant.helpers.service import async_get_all_descriptions
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.setup import async_setup_component, setup_component
|
||||
@ -80,75 +81,6 @@ class TestScriptComponent(unittest.TestCase):
|
||||
"""Stop down everything that was started."""
|
||||
self.hass.stop()
|
||||
|
||||
def test_turn_on_service(self):
|
||||
"""Verify that the turn_on service."""
|
||||
event = "test_event"
|
||||
events = []
|
||||
|
||||
@callback
|
||||
def record_event(event):
|
||||
"""Add recorded event to set."""
|
||||
events.append(event)
|
||||
|
||||
self.hass.bus.listen(event, record_event)
|
||||
|
||||
assert setup_component(
|
||||
self.hass,
|
||||
"script",
|
||||
{
|
||||
"script": {
|
||||
"test": {"sequence": [{"delay": {"seconds": 5}}, {"event": event}]}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
turn_on(self.hass, ENTITY_ID)
|
||||
self.hass.block_till_done()
|
||||
assert script.is_on(self.hass, ENTITY_ID)
|
||||
assert 0 == len(events)
|
||||
|
||||
# Calling turn_on a second time should not advance the script
|
||||
turn_on(self.hass, ENTITY_ID)
|
||||
self.hass.block_till_done()
|
||||
assert 0 == len(events)
|
||||
|
||||
turn_off(self.hass, ENTITY_ID)
|
||||
self.hass.block_till_done()
|
||||
assert not script.is_on(self.hass, ENTITY_ID)
|
||||
assert 0 == len(events)
|
||||
|
||||
def test_toggle_service(self):
|
||||
"""Test the toggling of a service."""
|
||||
event = "test_event"
|
||||
events = []
|
||||
|
||||
@callback
|
||||
def record_event(event):
|
||||
"""Add recorded event to set."""
|
||||
events.append(event)
|
||||
|
||||
self.hass.bus.listen(event, record_event)
|
||||
|
||||
assert setup_component(
|
||||
self.hass,
|
||||
"script",
|
||||
{
|
||||
"script": {
|
||||
"test": {"sequence": [{"delay": {"seconds": 5}}, {"event": event}]}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
toggle(self.hass, ENTITY_ID)
|
||||
self.hass.block_till_done()
|
||||
assert script.is_on(self.hass, ENTITY_ID)
|
||||
assert 0 == len(events)
|
||||
|
||||
toggle(self.hass, ENTITY_ID)
|
||||
self.hass.block_till_done()
|
||||
assert not script.is_on(self.hass, ENTITY_ID)
|
||||
assert 0 == len(events)
|
||||
|
||||
def test_passing_variables(self):
|
||||
"""Test different ways of passing in variables."""
|
||||
calls = []
|
||||
@ -195,17 +127,58 @@ class TestScriptComponent(unittest.TestCase):
|
||||
assert calls[1].data["hello"] == "universe"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("toggle", [False, True])
|
||||
async def test_turn_on_off_toggle(hass, toggle):
|
||||
"""Verify turn_on, turn_off & toggle services."""
|
||||
event = "test_event"
|
||||
event_mock = Mock()
|
||||
|
||||
hass.bus.async_listen(event, event_mock)
|
||||
|
||||
was_on = False
|
||||
|
||||
@callback
|
||||
def state_listener(entity_id, old_state, new_state):
|
||||
nonlocal was_on
|
||||
was_on = True
|
||||
|
||||
async_track_state_change(hass, ENTITY_ID, state_listener, to_state="on")
|
||||
|
||||
if toggle:
|
||||
turn_off_step = {"service": "script.toggle", "entity_id": ENTITY_ID}
|
||||
else:
|
||||
turn_off_step = {"service": "script.turn_off", "entity_id": ENTITY_ID}
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"script",
|
||||
{
|
||||
"script": {
|
||||
"test": {
|
||||
"sequence": [{"event": event}, turn_off_step, {"event": event}]
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert not script.is_on(hass, ENTITY_ID)
|
||||
|
||||
if toggle:
|
||||
await hass.services.async_call(
|
||||
DOMAIN, SERVICE_TOGGLE, {ATTR_ENTITY_ID: ENTITY_ID}
|
||||
)
|
||||
else:
|
||||
await hass.services.async_call(DOMAIN, split_entity_id(ENTITY_ID)[1])
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert not script.is_on(hass, ENTITY_ID)
|
||||
assert was_on
|
||||
assert 1 == event_mock.call_count
|
||||
|
||||
|
||||
invalid_configs = [
|
||||
{"test": {}},
|
||||
{"test hello world": {"sequence": [{"event": "bla"}]}},
|
||||
{"test": {"sequence": {"event": "test_event", "service": "homeassistant.turn_on"}}},
|
||||
{"test": {"sequence": [], "mode": "parallel", "queue_size": 5}},
|
||||
{
|
||||
"test": {
|
||||
"mode": "legacy",
|
||||
"sequence": [{"repeat": {"count": 5, "sequence": []}}],
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@ -222,8 +195,29 @@ async def test_setup_with_invalid_configs(hass, value):
|
||||
@pytest.mark.parametrize("running", ["no", "same", "different"])
|
||||
async def test_reload_service(hass, running):
|
||||
"""Verify the reload service."""
|
||||
event = "test_event"
|
||||
event_flag = asyncio.Event()
|
||||
|
||||
@callback
|
||||
def event_handler(event):
|
||||
event_flag.set()
|
||||
|
||||
hass.bus.async_listen_once(event, event_handler)
|
||||
hass.states.async_set("test.script", "off")
|
||||
|
||||
assert await async_setup_component(
|
||||
hass, "script", {"script": {"test": {"sequence": [{"delay": {"seconds": 5}}]}}}
|
||||
hass,
|
||||
"script",
|
||||
{
|
||||
"script": {
|
||||
"test": {
|
||||
"sequence": [
|
||||
{"event": event},
|
||||
{"wait_template": "{{ is_state('test.script', 'on') }}"},
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert hass.states.get(ENTITY_ID) is not None
|
||||
@ -232,7 +226,7 @@ async def test_reload_service(hass, running):
|
||||
if running != "no":
|
||||
_, object_id = split_entity_id(ENTITY_ID)
|
||||
await hass.services.async_call(DOMAIN, object_id)
|
||||
await hass.async_block_till_done()
|
||||
await asyncio.wait_for(event_flag.wait(), 1)
|
||||
|
||||
assert script.is_on(hass, ENTITY_ID)
|
||||
|
||||
@ -486,14 +480,6 @@ async def test_config_basic(hass):
|
||||
assert test_script.attributes["icon"] == "mdi:party"
|
||||
|
||||
|
||||
async def test_config_legacy(hass, caplog):
|
||||
"""Test config defaulting to legacy mode."""
|
||||
assert await async_setup_component(
|
||||
hass, "script", {"script": {"test_script": {"sequence": []}}}
|
||||
)
|
||||
assert "To continue using previous behavior, which is now deprecated" in caplog.text
|
||||
|
||||
|
||||
async def test_logbook_humanify_script_started_event(hass):
|
||||
"""Test humanifying script started event."""
|
||||
hass.config.components.add("recorder")
|
||||
|
@ -27,8 +27,6 @@ from tests.common import (
|
||||
|
||||
ENTITY_ID = "script.test"
|
||||
|
||||
_BASIC_SCRIPT_MODES = ("legacy", "parallel")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_timeout(hass, monkeypatch):
|
||||
@ -86,18 +84,16 @@ def async_watch_for_action(script_obj, message):
|
||||
return flag
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_firing_event_basic(hass, script_mode):
|
||||
async def test_firing_event_basic(hass):
|
||||
"""Test the firing of events."""
|
||||
event = "test_event"
|
||||
context = Context()
|
||||
events = async_capture_events(hass, event)
|
||||
|
||||
sequence = cv.SCRIPT_SCHEMA({"event": event, "event_data": {"hello": "world"}})
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
|
||||
assert script_obj.is_legacy == (script_mode == "legacy")
|
||||
assert script_obj.can_cancel == (script_mode != "legacy")
|
||||
assert script_obj.can_cancel
|
||||
|
||||
await script_obj.async_run(context=context)
|
||||
await hass.async_block_till_done()
|
||||
@ -105,11 +101,10 @@ async def test_firing_event_basic(hass, script_mode):
|
||||
assert len(events) == 1
|
||||
assert events[0].context is context
|
||||
assert events[0].data.get("hello") == "world"
|
||||
assert script_obj.can_cancel == (script_mode != "legacy")
|
||||
assert script_obj.can_cancel
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_firing_event_template(hass, script_mode):
|
||||
async def test_firing_event_template(hass):
|
||||
"""Test the firing of events."""
|
||||
event = "test_event"
|
||||
context = Context()
|
||||
@ -128,9 +123,9 @@ async def test_firing_event_template(hass, script_mode):
|
||||
},
|
||||
}
|
||||
)
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
|
||||
assert script_obj.can_cancel == (script_mode != "legacy")
|
||||
assert script_obj.can_cancel
|
||||
|
||||
await script_obj.async_run({"is_world": "yes"}, context=context)
|
||||
await hass.async_block_till_done()
|
||||
@ -143,16 +138,15 @@ async def test_firing_event_template(hass, script_mode):
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_calling_service_basic(hass, script_mode):
|
||||
async def test_calling_service_basic(hass):
|
||||
"""Test the calling of a service."""
|
||||
context = Context()
|
||||
calls = async_mock_service(hass, "test", "script")
|
||||
|
||||
sequence = cv.SCRIPT_SCHEMA({"service": "test.script", "data": {"hello": "world"}})
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
|
||||
assert script_obj.can_cancel == (script_mode != "legacy")
|
||||
assert script_obj.can_cancel
|
||||
|
||||
await script_obj.async_run(context=context)
|
||||
await hass.async_block_till_done()
|
||||
@ -162,8 +156,7 @@ async def test_calling_service_basic(hass, script_mode):
|
||||
assert calls[0].data.get("hello") == "world"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_calling_service_template(hass, script_mode):
|
||||
async def test_calling_service_template(hass):
|
||||
"""Test the calling of a service."""
|
||||
context = Context()
|
||||
calls = async_mock_service(hass, "test", "script")
|
||||
@ -187,9 +180,9 @@ async def test_calling_service_template(hass, script_mode):
|
||||
},
|
||||
}
|
||||
)
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
|
||||
assert script_obj.can_cancel == (script_mode != "legacy")
|
||||
assert script_obj.can_cancel
|
||||
|
||||
await script_obj.async_run({"is_world": "yes"}, context=context)
|
||||
await hass.async_block_till_done()
|
||||
@ -199,8 +192,7 @@ async def test_calling_service_template(hass, script_mode):
|
||||
assert calls[0].data.get("hello") == "world"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_multiple_runs_no_wait(hass, script_mode):
|
||||
async def test_multiple_runs_no_wait(hass):
|
||||
"""Test multiple runs with no wait in script."""
|
||||
logger = logging.getLogger("TEST")
|
||||
calls = []
|
||||
@ -243,7 +235,7 @@ async def test_multiple_runs_no_wait(hass, script_mode):
|
||||
},
|
||||
]
|
||||
)
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence, script_mode="parallel", max_runs=2)
|
||||
|
||||
# Start script twice in such a way that second run will be started while first run
|
||||
# is in the middle of the first service call.
|
||||
@ -267,16 +259,15 @@ async def test_multiple_runs_no_wait(hass, script_mode):
|
||||
assert len(calls) == 4
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_activating_scene(hass, script_mode):
|
||||
async def test_activating_scene(hass):
|
||||
"""Test the activation of a scene."""
|
||||
context = Context()
|
||||
calls = async_mock_service(hass, scene.DOMAIN, SERVICE_TURN_ON)
|
||||
|
||||
sequence = cv.SCRIPT_SCHEMA({"scene": "scene.hello"})
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
|
||||
assert script_obj.can_cancel == (script_mode != "legacy")
|
||||
assert script_obj.can_cancel
|
||||
|
||||
await script_obj.async_run(context=context)
|
||||
await hass.async_block_till_done()
|
||||
@ -287,8 +278,7 @@ async def test_activating_scene(hass, script_mode):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("count", [1, 3])
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_stop_no_wait(hass, caplog, script_mode, count):
|
||||
async def test_stop_no_wait(hass, caplog, count):
|
||||
"""Test stopping script."""
|
||||
service_started_sem = asyncio.Semaphore(0)
|
||||
finish_service_event = asyncio.Event()
|
||||
@ -303,7 +293,7 @@ async def test_stop_no_wait(hass, caplog, script_mode, count):
|
||||
hass.services.async_register("test", "script", async_simulate_long_service)
|
||||
|
||||
sequence = cv.SCRIPT_SCHEMA([{"service": "test.script"}, {"event": event}])
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence, script_mode="parallel", max_runs=count)
|
||||
|
||||
# Get script started specified number of times and wait until the test.script
|
||||
# service has started for each run.
|
||||
@ -328,15 +318,14 @@ async def test_stop_no_wait(hass, caplog, script_mode, count):
|
||||
assert script_was_runing
|
||||
assert were_no_events
|
||||
assert not script_obj.is_running
|
||||
assert len(events) == (count if script_mode == "legacy" else 0)
|
||||
assert len(events) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_delay_basic(hass, mock_timeout, script_mode):
|
||||
async def test_delay_basic(hass, mock_timeout):
|
||||
"""Test the delay."""
|
||||
delay_alias = "delay step"
|
||||
sequence = cv.SCRIPT_SCHEMA({"delay": {"seconds": 5}, "alias": delay_alias})
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
delay_started_flag = async_watch_for_action(script_obj, delay_alias)
|
||||
|
||||
assert script_obj.can_cancel
|
||||
@ -358,8 +347,7 @@ async def test_delay_basic(hass, mock_timeout, script_mode):
|
||||
assert script_obj.last_action is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_multiple_runs_delay(hass, mock_timeout, script_mode):
|
||||
async def test_multiple_runs_delay(hass, mock_timeout):
|
||||
"""Test multiple runs with delay in script."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
@ -371,7 +359,7 @@ async def test_multiple_runs_delay(hass, mock_timeout, script_mode):
|
||||
{"event": event, "event_data": {"value": 2}},
|
||||
]
|
||||
)
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence, script_mode="parallel", max_runs=2)
|
||||
delay_started_flag = async_watch_for_action(script_obj, "delay")
|
||||
|
||||
try:
|
||||
@ -386,31 +374,24 @@ async def test_multiple_runs_delay(hass, mock_timeout, script_mode):
|
||||
raise
|
||||
else:
|
||||
# Start second run of script while first run is in a delay.
|
||||
if script_mode == "legacy":
|
||||
await script_obj.async_run()
|
||||
else:
|
||||
script_obj.sequence[1]["alias"] = "delay run 2"
|
||||
delay_started_flag = async_watch_for_action(script_obj, "delay run 2")
|
||||
hass.async_create_task(script_obj.async_run())
|
||||
await asyncio.wait_for(delay_started_flag.wait(), 1)
|
||||
script_obj.sequence[1]["alias"] = "delay run 2"
|
||||
delay_started_flag = async_watch_for_action(script_obj, "delay run 2")
|
||||
hass.async_create_task(script_obj.async_run())
|
||||
await asyncio.wait_for(delay_started_flag.wait(), 1)
|
||||
async_fire_time_changed(hass, dt_util.utcnow() + delay)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert not script_obj.is_running
|
||||
if script_mode == "legacy":
|
||||
assert len(events) == 2
|
||||
else:
|
||||
assert len(events) == 4
|
||||
assert events[-3].data["value"] == 1
|
||||
assert events[-2].data["value"] == 2
|
||||
assert len(events) == 4
|
||||
assert events[-3].data["value"] == 1
|
||||
assert events[-2].data["value"] == 2
|
||||
assert events[-1].data["value"] == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_delay_template_ok(hass, mock_timeout, script_mode):
|
||||
async def test_delay_template_ok(hass, mock_timeout):
|
||||
"""Test the delay as a template."""
|
||||
sequence = cv.SCRIPT_SCHEMA({"delay": "00:00:{{ 5 }}"})
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
delay_started_flag = async_watch_for_action(script_obj, "delay")
|
||||
|
||||
assert script_obj.can_cancel
|
||||
@ -430,8 +411,7 @@ async def test_delay_template_ok(hass, mock_timeout, script_mode):
|
||||
assert not script_obj.is_running
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_delay_template_invalid(hass, caplog, script_mode):
|
||||
async def test_delay_template_invalid(hass, caplog):
|
||||
"""Test the delay as a template that fails."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
@ -443,7 +423,7 @@ async def test_delay_template_invalid(hass, caplog, script_mode):
|
||||
{"event": event},
|
||||
]
|
||||
)
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
start_idx = len(caplog.records)
|
||||
|
||||
await script_obj.async_run()
|
||||
@ -458,11 +438,10 @@ async def test_delay_template_invalid(hass, caplog, script_mode):
|
||||
assert len(events) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_delay_template_complex_ok(hass, mock_timeout, script_mode):
|
||||
async def test_delay_template_complex_ok(hass, mock_timeout):
|
||||
"""Test the delay with a working complex template."""
|
||||
sequence = cv.SCRIPT_SCHEMA({"delay": {"seconds": "{{ 5 }}"}})
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
delay_started_flag = async_watch_for_action(script_obj, "delay")
|
||||
|
||||
assert script_obj.can_cancel
|
||||
@ -481,8 +460,7 @@ async def test_delay_template_complex_ok(hass, mock_timeout, script_mode):
|
||||
assert not script_obj.is_running
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_delay_template_complex_invalid(hass, caplog, script_mode):
|
||||
async def test_delay_template_complex_invalid(hass, caplog):
|
||||
"""Test the delay with a complex template that fails."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
@ -494,7 +472,7 @@ async def test_delay_template_complex_invalid(hass, caplog, script_mode):
|
||||
{"event": event},
|
||||
]
|
||||
)
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
start_idx = len(caplog.records)
|
||||
|
||||
await script_obj.async_run()
|
||||
@ -509,13 +487,12 @@ async def test_delay_template_complex_invalid(hass, caplog, script_mode):
|
||||
assert len(events) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_cancel_delay(hass, script_mode):
|
||||
async def test_cancel_delay(hass):
|
||||
"""Test the cancelling while the delay is present."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
sequence = cv.SCRIPT_SCHEMA([{"delay": {"seconds": 5}}, {"event": event}])
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
delay_started_flag = async_watch_for_action(script_obj, "delay")
|
||||
|
||||
try:
|
||||
@ -541,8 +518,7 @@ async def test_cancel_delay(hass, script_mode):
|
||||
assert len(events) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_wait_template_basic(hass, script_mode):
|
||||
async def test_wait_template_basic(hass):
|
||||
"""Test the wait template."""
|
||||
wait_alias = "wait step"
|
||||
sequence = cv.SCRIPT_SCHEMA(
|
||||
@ -551,7 +527,7 @@ async def test_wait_template_basic(hass, script_mode):
|
||||
"alias": wait_alias,
|
||||
}
|
||||
)
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
wait_started_flag = async_watch_for_action(script_obj, wait_alias)
|
||||
|
||||
assert script_obj.can_cancel
|
||||
@ -574,8 +550,7 @@ async def test_wait_template_basic(hass, script_mode):
|
||||
assert script_obj.last_action is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_multiple_runs_wait_template(hass, script_mode):
|
||||
async def test_multiple_runs_wait_template(hass):
|
||||
"""Test multiple runs with wait_template in script."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
@ -586,7 +561,7 @@ async def test_multiple_runs_wait_template(hass, script_mode):
|
||||
{"event": event, "event_data": {"value": 2}},
|
||||
]
|
||||
)
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence, script_mode="parallel", max_runs=2)
|
||||
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
||||
|
||||
try:
|
||||
@ -602,25 +577,18 @@ async def test_multiple_runs_wait_template(hass, script_mode):
|
||||
raise
|
||||
else:
|
||||
# Start second run of script while first run is in wait_template.
|
||||
if script_mode == "legacy":
|
||||
await script_obj.async_run()
|
||||
else:
|
||||
hass.async_create_task(script_obj.async_run())
|
||||
hass.async_create_task(script_obj.async_run())
|
||||
hass.states.async_set("switch.test", "off")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert not script_obj.is_running
|
||||
if script_mode == "legacy":
|
||||
assert len(events) == 2
|
||||
else:
|
||||
assert len(events) == 4
|
||||
assert events[-3].data["value"] == 1
|
||||
assert events[-2].data["value"] == 2
|
||||
assert len(events) == 4
|
||||
assert events[-3].data["value"] == 1
|
||||
assert events[-2].data["value"] == 2
|
||||
assert events[-1].data["value"] == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_cancel_wait_template(hass, script_mode):
|
||||
async def test_cancel_wait_template(hass):
|
||||
"""Test the cancelling while wait_template is present."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
@ -630,7 +598,7 @@ async def test_cancel_wait_template(hass, script_mode):
|
||||
{"event": event},
|
||||
]
|
||||
)
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
||||
|
||||
try:
|
||||
@ -657,8 +625,7 @@ async def test_cancel_wait_template(hass, script_mode):
|
||||
assert len(events) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_wait_template_not_schedule(hass, script_mode):
|
||||
async def test_wait_template_not_schedule(hass):
|
||||
"""Test the wait template with correct condition."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
@ -669,7 +636,7 @@ async def test_wait_template_not_schedule(hass, script_mode):
|
||||
{"event": event},
|
||||
]
|
||||
)
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
|
||||
hass.states.async_set("switch.test", "on")
|
||||
await script_obj.async_run()
|
||||
@ -679,13 +646,10 @@ async def test_wait_template_not_schedule(hass, script_mode):
|
||||
assert len(events) == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
@pytest.mark.parametrize(
|
||||
"continue_on_timeout,n_events", [(False, 0), (True, 1), (None, 1)]
|
||||
)
|
||||
async def test_wait_template_timeout(
|
||||
hass, mock_timeout, continue_on_timeout, n_events, script_mode
|
||||
):
|
||||
async def test_wait_template_timeout(hass, mock_timeout, continue_on_timeout, n_events):
|
||||
"""Test the wait template, halt on timeout."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
@ -696,7 +660,7 @@ async def test_wait_template_timeout(
|
||||
if continue_on_timeout is not None:
|
||||
sequence[0]["continue_on_timeout"] = continue_on_timeout
|
||||
sequence = cv.SCRIPT_SCHEMA(sequence)
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
||||
|
||||
try:
|
||||
@ -717,11 +681,10 @@ async def test_wait_template_timeout(
|
||||
assert len(events) == n_events
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_wait_template_variables(hass, script_mode):
|
||||
async def test_wait_template_variables(hass):
|
||||
"""Test the wait template with variables."""
|
||||
sequence = cv.SCRIPT_SCHEMA({"wait_template": "{{ is_state(data, 'off') }}"})
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
||||
|
||||
assert script_obj.can_cancel
|
||||
@ -742,8 +705,7 @@ async def test_wait_template_variables(hass, script_mode):
|
||||
assert not script_obj.is_running
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_condition_basic(hass, script_mode):
|
||||
async def test_condition_basic(hass):
|
||||
"""Test if we can use conditions in a script."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
@ -757,9 +719,9 @@ async def test_condition_basic(hass, script_mode):
|
||||
{"event": event},
|
||||
]
|
||||
)
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
|
||||
assert script_obj.can_cancel == (script_mode != "legacy")
|
||||
assert script_obj.can_cancel
|
||||
|
||||
hass.states.async_set("test.entity", "hello")
|
||||
await script_obj.async_run()
|
||||
@ -775,9 +737,8 @@ async def test_condition_basic(hass, script_mode):
|
||||
assert len(events) == 3
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
@patch("homeassistant.helpers.script.condition.async_from_config")
|
||||
async def test_condition_created_once(async_from_config, hass, script_mode):
|
||||
async def test_condition_created_once(async_from_config, hass):
|
||||
"""Test that the conditions do not get created multiple times."""
|
||||
sequence = cv.SCRIPT_SCHEMA(
|
||||
{
|
||||
@ -785,7 +746,7 @@ async def test_condition_created_once(async_from_config, hass, script_mode):
|
||||
"value_template": '{{ states.test.entity.state == "hello" }}',
|
||||
}
|
||||
)
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence, script_mode="parallel", max_runs=2)
|
||||
|
||||
async_from_config.reset_mock()
|
||||
|
||||
@ -798,8 +759,7 @@ async def test_condition_created_once(async_from_config, hass, script_mode):
|
||||
assert len(script_obj._config_cache) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_condition_all_cached(hass, script_mode):
|
||||
async def test_condition_all_cached(hass):
|
||||
"""Test that multiple conditions get cached."""
|
||||
sequence = cv.SCRIPT_SCHEMA(
|
||||
[
|
||||
@ -813,7 +773,7 @@ async def test_condition_all_cached(hass, script_mode):
|
||||
},
|
||||
]
|
||||
)
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
|
||||
hass.states.async_set("test.entity", "hello")
|
||||
await script_obj.async_run()
|
||||
@ -843,7 +803,7 @@ async def test_repeat_count(hass):
|
||||
}
|
||||
}
|
||||
)
|
||||
script_obj = script.Script(hass, sequence, script_mode="ignore")
|
||||
script_obj = script.Script(hass, sequence)
|
||||
|
||||
await script_obj.async_run()
|
||||
await hass.async_block_till_done()
|
||||
@ -887,7 +847,7 @@ async def test_repeat_conditional(hass, condition):
|
||||
"condition": "template",
|
||||
"value_template": "{{ is_state('sensor.test', 'done') }}",
|
||||
}
|
||||
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA(sequence), script_mode="ignore")
|
||||
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA(sequence))
|
||||
|
||||
wait_started = async_watch_for_action(script_obj, "wait")
|
||||
hass.states.async_set("sensor.test", "1")
|
||||
@ -917,12 +877,11 @@ async def test_repeat_conditional(hass, condition):
|
||||
assert event.data.get("index") == str(index + 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_last_triggered(hass, script_mode):
|
||||
async def test_last_triggered(hass):
|
||||
"""Test the last_triggered."""
|
||||
event = "test_event"
|
||||
sequence = cv.SCRIPT_SCHEMA({"event": event})
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
|
||||
assert script_obj.last_triggered is None
|
||||
|
||||
@ -934,13 +893,12 @@ async def test_last_triggered(hass, script_mode):
|
||||
assert script_obj.last_triggered == time
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_propagate_error_service_not_found(hass, script_mode):
|
||||
async def test_propagate_error_service_not_found(hass):
|
||||
"""Test that a script aborts when a service is not found."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
sequence = cv.SCRIPT_SCHEMA([{"service": "test.script"}, {"event": event}])
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
|
||||
with pytest.raises(exceptions.ServiceNotFound):
|
||||
await script_obj.async_run()
|
||||
@ -949,8 +907,7 @@ async def test_propagate_error_service_not_found(hass, script_mode):
|
||||
assert not script_obj.is_running
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_propagate_error_invalid_service_data(hass, script_mode):
|
||||
async def test_propagate_error_invalid_service_data(hass):
|
||||
"""Test that a script aborts when we send invalid service data."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
@ -958,7 +915,7 @@ async def test_propagate_error_invalid_service_data(hass, script_mode):
|
||||
sequence = cv.SCRIPT_SCHEMA(
|
||||
[{"service": "test.script", "data": {"text": 1}}, {"event": event}]
|
||||
)
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
await script_obj.async_run()
|
||||
@ -968,8 +925,7 @@ async def test_propagate_error_invalid_service_data(hass, script_mode):
|
||||
assert not script_obj.is_running
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_propagate_error_service_exception(hass, script_mode):
|
||||
async def test_propagate_error_service_exception(hass):
|
||||
"""Test that a script aborts when a service throws an exception."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
@ -982,7 +938,7 @@ async def test_propagate_error_service_exception(hass, script_mode):
|
||||
hass.services.async_register("test", "script", record_call)
|
||||
|
||||
sequence = cv.SCRIPT_SCHEMA([{"service": "test.script"}, {"event": event}])
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await script_obj.async_run()
|
||||
@ -1053,15 +1009,8 @@ def does_not_raise():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"script_mode,expectation,messages",
|
||||
[
|
||||
("ignore", does_not_raise(), ["Skipping"]),
|
||||
("error", pytest.raises(exceptions.HomeAssistantError), []),
|
||||
],
|
||||
)
|
||||
async def test_script_mode_1(hass, caplog, script_mode, expectation, messages):
|
||||
"""Test overlapping runs with script_mode='ignore'."""
|
||||
async def test_script_mode_single(hass, caplog):
|
||||
"""Test overlapping runs with max_runs = 1."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
sequence = cv.SCRIPT_SCHEMA(
|
||||
@ -1071,8 +1020,7 @@ async def test_script_mode_1(hass, caplog, script_mode, expectation, messages):
|
||||
{"event": event, "event_data": {"value": 2}},
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger("TEST")
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode, logger=logger)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
||||
|
||||
try:
|
||||
@ -1086,19 +1034,10 @@ async def test_script_mode_1(hass, caplog, script_mode, expectation, messages):
|
||||
|
||||
# Start second run of script while first run is suspended in wait_template.
|
||||
|
||||
with expectation:
|
||||
await script_obj.async_run()
|
||||
await script_obj.async_run()
|
||||
|
||||
assert "Already running" in caplog.text
|
||||
assert script_obj.is_running
|
||||
assert all(
|
||||
any(
|
||||
rec.levelname == "INFO"
|
||||
and rec.name == "TEST"
|
||||
and message in rec.message
|
||||
for rec in caplog.records
|
||||
)
|
||||
for message in messages
|
||||
)
|
||||
except (AssertionError, asyncio.TimeoutError):
|
||||
await script_obj.async_stop()
|
||||
raise
|
||||
@ -1116,7 +1055,7 @@ async def test_script_mode_1(hass, caplog, script_mode, expectation, messages):
|
||||
[("restart", ["Restarting"], [2]), ("parallel", [], [2, 2])],
|
||||
)
|
||||
async def test_script_mode_2(hass, caplog, script_mode, messages, last_events):
|
||||
"""Test overlapping runs with script_mode='restart'."""
|
||||
"""Test overlapping runs with max_runs > 1."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
sequence = cv.SCRIPT_SCHEMA(
|
||||
@ -1127,7 +1066,10 @@ async def test_script_mode_2(hass, caplog, script_mode, messages, last_events):
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger("TEST")
|
||||
script_obj = script.Script(hass, sequence, script_mode=script_mode, logger=logger)
|
||||
max_runs = 1 if script_mode == "restart" else 2
|
||||
script_obj = script.Script(
|
||||
hass, sequence, script_mode=script_mode, max_runs=max_runs, logger=logger
|
||||
)
|
||||
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
||||
|
||||
try:
|
||||
@ -1140,7 +1082,6 @@ async def test_script_mode_2(hass, caplog, script_mode, messages, last_events):
|
||||
assert events[0].data["value"] == 1
|
||||
|
||||
# Start second run of script while first run is suspended in wait_template.
|
||||
# This should stop first run then start a new run.
|
||||
|
||||
wait_started_flag.clear()
|
||||
hass.async_create_task(script_obj.async_run())
|
||||
@ -1172,7 +1113,7 @@ async def test_script_mode_2(hass, caplog, script_mode, messages, last_events):
|
||||
|
||||
|
||||
async def test_script_mode_queue(hass):
|
||||
"""Test overlapping runs with script_mode='queue'."""
|
||||
"""Test overlapping runs with script_mode = 'queue' & max_runs > 1."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
sequence = cv.SCRIPT_SCHEMA(
|
||||
@ -1184,7 +1125,9 @@ async def test_script_mode_queue(hass):
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger("TEST")
|
||||
script_obj = script.Script(hass, sequence, script_mode="queue", logger=logger)
|
||||
script_obj = script.Script(
|
||||
hass, sequence, script_mode="queue", max_runs=2, logger=logger
|
||||
)
|
||||
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
||||
|
||||
try:
|
||||
|
Loading…
x
Reference in New Issue
Block a user