mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
Fix variable scopes in scripts (#138883)
Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
parent
bd80a78848
commit
b964bc58be
@ -12,7 +12,6 @@ from datetime import datetime, timedelta
|
||||
from functools import partial
|
||||
import itertools
|
||||
import logging
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Literal, TypedDict, cast, overload
|
||||
|
||||
import async_interrupt
|
||||
@ -90,7 +89,7 @@ from . import condition, config_validation as cv, service, template
|
||||
from .condition import ConditionCheckerType, trace_condition_function
|
||||
from .dispatcher import async_dispatcher_connect, async_dispatcher_send_internal
|
||||
from .event import async_call_later, async_track_template
|
||||
from .script_variables import ScriptVariables
|
||||
from .script_variables import ScriptRunVariables, ScriptVariables
|
||||
from .template import Template
|
||||
from .trace import (
|
||||
TraceElement,
|
||||
@ -177,7 +176,7 @@ def _set_result_unless_done(future: asyncio.Future[None]) -> None:
|
||||
future.set_result(None)
|
||||
|
||||
|
||||
def action_trace_append(variables: dict[str, Any], path: str) -> TraceElement:
|
||||
def action_trace_append(variables: TemplateVarsType, path: str) -> TraceElement:
|
||||
"""Append a TraceElement to trace[path]."""
|
||||
trace_element = TraceElement(variables, path)
|
||||
trace_append_element(trace_element, ACTION_TRACE_NODE_MAX_LEN)
|
||||
@ -189,7 +188,7 @@ async def trace_action(
|
||||
hass: HomeAssistant,
|
||||
script_run: _ScriptRun,
|
||||
stop: asyncio.Future[None],
|
||||
variables: dict[str, Any],
|
||||
variables: TemplateVarsType,
|
||||
) -> AsyncGenerator[TraceElement]:
|
||||
"""Trace action execution."""
|
||||
path = trace_path_get()
|
||||
@ -411,7 +410,7 @@ class _ScriptRun:
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
script: Script,
|
||||
variables: dict[str, Any],
|
||||
variables: ScriptRunVariables,
|
||||
context: Context | None,
|
||||
log_exceptions: bool,
|
||||
) -> None:
|
||||
@ -485,14 +484,16 @@ class _ScriptRun:
|
||||
script_stack.pop()
|
||||
self._finish()
|
||||
|
||||
return ScriptRunResult(self._conversation_response, response, self._variables)
|
||||
return ScriptRunResult(
|
||||
self._conversation_response, response, self._variables.local_scope
|
||||
)
|
||||
|
||||
async def _async_step(self, log_exceptions: bool) -> None:
|
||||
continue_on_error = self._action.get(CONF_CONTINUE_ON_ERROR, False)
|
||||
|
||||
with trace_path(str(self._step)):
|
||||
async with trace_action(
|
||||
self._hass, self, self._stop, self._variables
|
||||
self._hass, self, self._stop, self._variables.non_parallel_scope
|
||||
) as trace_element:
|
||||
if self._stop.done():
|
||||
return
|
||||
@ -526,7 +527,7 @@ class _ScriptRun:
|
||||
ex, continue_on_error, self._log_exceptions or log_exceptions
|
||||
)
|
||||
finally:
|
||||
trace_element.update_variables(self._variables)
|
||||
trace_element.update_variables(self._variables.non_parallel_scope)
|
||||
|
||||
def _finish(self) -> None:
|
||||
self._script._runs.remove(self) # noqa: SLF001
|
||||
@ -624,11 +625,16 @@ class _ScriptRun:
|
||||
except ScriptStoppedError as ex:
|
||||
raise asyncio.CancelledError from ex
|
||||
|
||||
async def _async_run_script(self, script: Script) -> None:
|
||||
async def _async_run_script(
|
||||
self, script: Script, *, parallel: bool = False
|
||||
) -> None:
|
||||
"""Execute a script."""
|
||||
result = await self._async_run_long_action(
|
||||
self._hass.async_create_task_internal(
|
||||
script.async_run(self._variables, self._context), eager_start=True
|
||||
script.async_run(
|
||||
self._variables.enter_scope(parallel=parallel), self._context
|
||||
),
|
||||
eager_start=True,
|
||||
)
|
||||
)
|
||||
if result and result.conversation_response is not UNDEFINED:
|
||||
@ -647,7 +653,7 @@ class _ScriptRun:
|
||||
"""Run a script with a trace path."""
|
||||
trace_path_stack_cv.set(copy(trace_path_stack_cv.get()))
|
||||
with trace_path([str(idx), "sequence"]):
|
||||
await self._async_run_script(script)
|
||||
await self._async_run_script(script, parallel=True)
|
||||
|
||||
results = await asyncio.gather(
|
||||
*(async_run_with_trace(idx, script) for idx, script in enumerate(scripts)),
|
||||
@ -760,14 +766,11 @@ class _ScriptRun:
|
||||
with trace_path("else"):
|
||||
await self._async_run_script(if_data["if_else"])
|
||||
|
||||
@async_trace_path("repeat")
|
||||
async def _async_step_repeat(self) -> None: # noqa: C901
|
||||
"""Repeat a sequence."""
|
||||
async def _async_do_step_repeat(self) -> None: # noqa: C901
|
||||
"""Repeat a sequence helper."""
|
||||
description = self._action.get(CONF_ALIAS, "sequence")
|
||||
repeat = self._action[CONF_REPEAT]
|
||||
|
||||
saved_repeat_vars = self._variables.get("repeat")
|
||||
|
||||
def set_repeat_var(
|
||||
iteration: int, count: int | None = None, item: Any = None
|
||||
) -> None:
|
||||
@ -776,7 +779,7 @@ class _ScriptRun:
|
||||
repeat_vars["last"] = iteration == count
|
||||
if item is not None:
|
||||
repeat_vars["item"] = item
|
||||
self._variables["repeat"] = repeat_vars
|
||||
self._variables.define_local("repeat", repeat_vars)
|
||||
|
||||
script = self._script._get_repeat_script(self._step) # noqa: SLF001
|
||||
warned_too_many_loops = False
|
||||
@ -927,10 +930,14 @@ class _ScriptRun:
|
||||
# while all the cpu time is consumed.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
if saved_repeat_vars:
|
||||
self._variables["repeat"] = saved_repeat_vars
|
||||
else:
|
||||
self._variables.pop("repeat", None) # Not set if count = 0
|
||||
@async_trace_path("repeat")
|
||||
async def _async_step_repeat(self) -> None:
|
||||
"""Repeat a sequence."""
|
||||
self._variables = self._variables.enter_scope()
|
||||
try:
|
||||
await self._async_do_step_repeat()
|
||||
finally:
|
||||
self._variables = self._variables.exit_scope()
|
||||
|
||||
### Stop actions ###
|
||||
|
||||
@ -959,11 +966,12 @@ class _ScriptRun:
|
||||
## Variable actions ##
|
||||
|
||||
async def _async_step_variables(self) -> None:
|
||||
"""Set a variable value."""
|
||||
self._step_log("setting variables")
|
||||
self._variables = self._action[CONF_VARIABLES].async_render(
|
||||
self._hass, self._variables, render_as_defaults=False
|
||||
)
|
||||
"""Define a local variable."""
|
||||
self._step_log("defining local variables")
|
||||
for key, value in (
|
||||
self._action[CONF_VARIABLES].async_simple_render(self._variables).items()
|
||||
):
|
||||
self._variables.define_local(key, value)
|
||||
|
||||
## External actions ##
|
||||
|
||||
@ -1016,7 +1024,7 @@ class _ScriptRun:
|
||||
"""Perform the device automation specified in the action."""
|
||||
self._step_log("device automation")
|
||||
await device_action.async_call_action_from_config(
|
||||
self._hass, self._action, self._variables, self._context
|
||||
self._hass, self._action, dict(self._variables), self._context
|
||||
)
|
||||
|
||||
async def _async_step_event(self) -> None:
|
||||
@ -1189,12 +1197,15 @@ class _ScriptRun:
|
||||
|
||||
self._step_log("wait for trigger", timeout)
|
||||
|
||||
variables = {**self._variables}
|
||||
self._variables["wait"] = {
|
||||
"remaining": timeout,
|
||||
"completed": False,
|
||||
"trigger": None,
|
||||
}
|
||||
variables = dict(self._variables)
|
||||
self._variables.assign_parallel_protected(
|
||||
"wait",
|
||||
{
|
||||
"remaining": timeout,
|
||||
"completed": False,
|
||||
"trigger": None,
|
||||
},
|
||||
)
|
||||
trace_set_result(wait=self._variables["wait"])
|
||||
|
||||
if timeout == 0:
|
||||
@ -1240,7 +1251,9 @@ class _ScriptRun:
|
||||
timeout = self._get_timeout_seconds_from_action()
|
||||
self._step_log("wait template", timeout)
|
||||
|
||||
self._variables["wait"] = {"remaining": timeout, "completed": False}
|
||||
self._variables.assign_parallel_protected(
|
||||
"wait", {"remaining": timeout, "completed": False}
|
||||
)
|
||||
trace_set_result(wait=self._variables["wait"])
|
||||
|
||||
wait_template = self._action[CONF_WAIT_TEMPLATE]
|
||||
@ -1369,7 +1382,7 @@ async def _async_stop_scripts_at_shutdown(hass: HomeAssistant, event: Event) ->
|
||||
)
|
||||
|
||||
|
||||
type _VarsType = dict[str, Any] | Mapping[str, Any] | MappingProxyType[str, Any]
|
||||
type _VarsType = dict[str, Any] | Mapping[str, Any] | ScriptRunVariables
|
||||
|
||||
|
||||
def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None:
|
||||
@ -1407,7 +1420,7 @@ class ScriptRunResult:
|
||||
|
||||
conversation_response: str | None | UndefinedType
|
||||
service_response: ServiceResponse
|
||||
variables: dict[str, Any]
|
||||
variables: Mapping[str, Any]
|
||||
|
||||
|
||||
class Script:
|
||||
@ -1422,7 +1435,6 @@ class Script:
|
||||
*,
|
||||
# Used in "Running <running_description>" log message
|
||||
change_listener: Callable[[], Any] | None = None,
|
||||
copy_variables: bool = False,
|
||||
log_exceptions: bool = True,
|
||||
logger: logging.Logger | None = None,
|
||||
max_exceeded: str = DEFAULT_MAX_EXCEEDED,
|
||||
@ -1476,8 +1488,6 @@ class Script:
|
||||
self._parallel_scripts: dict[int, list[Script]] = {}
|
||||
self._sequence_scripts: dict[int, Script] = {}
|
||||
self.variables = variables
|
||||
self._variables_dynamic = template.is_complex(variables)
|
||||
self._copy_variables_on_run = copy_variables
|
||||
|
||||
@property
|
||||
def change_listener(self) -> Callable[..., Any] | None:
|
||||
@ -1755,25 +1765,19 @@ class Script:
|
||||
if self.top_level:
|
||||
if self.variables:
|
||||
try:
|
||||
variables = self.variables.async_render(
|
||||
run_variables = self.variables.async_render(
|
||||
self._hass,
|
||||
run_variables,
|
||||
)
|
||||
except exceptions.TemplateError as err:
|
||||
self._log("Error rendering variables: %s", err, level=logging.ERROR)
|
||||
raise
|
||||
elif run_variables:
|
||||
variables = dict(run_variables)
|
||||
else:
|
||||
variables = {}
|
||||
|
||||
variables = ScriptRunVariables.create_top_level(run_variables)
|
||||
variables["context"] = context
|
||||
elif self._copy_variables_on_run:
|
||||
# This is not the top level script, variables have been turned to a dict
|
||||
variables = cast(dict[str, Any], copy(run_variables))
|
||||
else:
|
||||
# This is not the top level script, variables have been turned to a dict
|
||||
variables = cast(dict[str, Any], run_variables)
|
||||
# This is not the top level script, run_variables is an instance of ScriptRunVariables
|
||||
variables = cast(ScriptRunVariables, run_variables)
|
||||
|
||||
# Prevent non-allowed recursive calls which will cause deadlocks when we try to
|
||||
# stop (restart) or wait for (queued) our own script run.
|
||||
@ -1999,7 +2003,6 @@ class Script:
|
||||
max_runs=self.max_runs,
|
||||
logger=self._logger,
|
||||
top_level=False,
|
||||
copy_variables=True,
|
||||
)
|
||||
parallel_script.change_listener = partial(
|
||||
self._chain_change_listener, parallel_script
|
||||
|
@ -2,8 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import ChainMap, UserDict
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, cast
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
|
||||
@ -24,30 +26,23 @@ class ScriptVariables:
|
||||
hass: HomeAssistant,
|
||||
run_variables: Mapping[str, Any] | None,
|
||||
*,
|
||||
render_as_defaults: bool = True,
|
||||
limited: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Render script variables.
|
||||
|
||||
The run variables are used to compute the static variables.
|
||||
|
||||
If `render_as_defaults` is True, the run variables will not be overridden.
|
||||
|
||||
The run variables are included in the result.
|
||||
The run variables are used to compute the rendered variable values.
|
||||
The run variables will not be overridden.
|
||||
The rendering happens one at a time, with previous results influencing the next.
|
||||
"""
|
||||
if self._has_template is None:
|
||||
self._has_template = template.is_complex(self.variables)
|
||||
|
||||
if not self._has_template:
|
||||
if render_as_defaults:
|
||||
rendered_variables = dict(self.variables)
|
||||
rendered_variables = dict(self.variables)
|
||||
|
||||
if run_variables is not None:
|
||||
rendered_variables.update(run_variables)
|
||||
else:
|
||||
rendered_variables = (
|
||||
{} if run_variables is None else dict(run_variables)
|
||||
)
|
||||
rendered_variables.update(self.variables)
|
||||
if run_variables is not None:
|
||||
rendered_variables.update(run_variables)
|
||||
|
||||
return rendered_variables
|
||||
|
||||
@ -56,7 +51,7 @@ class ScriptVariables:
|
||||
for key, value in self.variables.items():
|
||||
# We can skip if we're going to override this key with
|
||||
# run variables anyway
|
||||
if render_as_defaults and key in rendered_variables:
|
||||
if key in rendered_variables:
|
||||
continue
|
||||
|
||||
rendered_variables[key] = template.render_complex(
|
||||
@ -65,6 +60,197 @@ class ScriptVariables:
|
||||
|
||||
return rendered_variables
|
||||
|
||||
@callback
|
||||
def async_simple_render(self, run_variables: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"""Render script variables.
|
||||
|
||||
Simply renders the variables, the run variables are not included in the result.
|
||||
The run variables are used to compute the rendered variable values.
|
||||
The rendering happens one at a time, with previous results influencing the next.
|
||||
"""
|
||||
if self._has_template is None:
|
||||
self._has_template = template.is_complex(self.variables)
|
||||
|
||||
if not self._has_template:
|
||||
return self.variables
|
||||
|
||||
run_variables = dict(run_variables)
|
||||
rendered_variables = {}
|
||||
|
||||
for key, value in self.variables.items():
|
||||
rendered_variable = template.render_complex(value, run_variables)
|
||||
rendered_variables[key] = rendered_variable
|
||||
run_variables[key] = rendered_variable
|
||||
|
||||
return rendered_variables
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return dict version of this class."""
|
||||
return self.variables
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ParallelData:
|
||||
"""Data used in each parallel sequence."""
|
||||
|
||||
# `protected` is for variables that need special protection in parallel sequences.
|
||||
# What this means is that such a variable defined in one parallel sequence will not be
|
||||
# clobbered by the variable with the same name assigned in another parallel sequence.
|
||||
# It also means that such a variable will not be visible in the outer scope.
|
||||
# Currently the only such variable is `wait`.
|
||||
protected: dict[str, Any] = field(default_factory=dict)
|
||||
# `outer_scope_writes` is for variables that are written to the outer scope from
|
||||
# a parallel sequence. This is used for generating correct traces of changed variables
|
||||
# for each of the parallel sequences, isolating them from one another.
|
||||
outer_scope_writes: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ScriptRunVariables(UserDict[str, Any]):
|
||||
"""Class to hold script run variables.
|
||||
|
||||
The purpose of this class is to provide proper variable scoping semantics for scripts.
|
||||
Each instance institutes a new local scope, in which variables can be defined.
|
||||
Each instance has a reference to the previous instance, except for the top-level instance.
|
||||
The instances therefore form a chain, in which variable lookup and assignment is performed.
|
||||
The variables defined lower in the chain naturally override those defined higher up.
|
||||
"""
|
||||
|
||||
# _previous is the previous ScriptRunVariables in the chain
|
||||
_previous: ScriptRunVariables | None = None
|
||||
# _parent is the previous non-empty ScriptRunVariables in the chain
|
||||
_parent: ScriptRunVariables | None = None
|
||||
|
||||
# _local_data is the store for local variables
|
||||
_local_data: dict[str, Any] | None = None
|
||||
# _parallel_data is used for each parallel sequence
|
||||
_parallel_data: _ParallelData | None = None
|
||||
|
||||
# _non_parallel_scope includes all scopes all the way to the most recent parallel split
|
||||
_non_parallel_scope: ChainMap[str, Any]
|
||||
# _full_scope includes all scopes (all the way to the top-level)
|
||||
_full_scope: ChainMap[str, Any]
|
||||
|
||||
@classmethod
|
||||
def create_top_level(
|
||||
cls,
|
||||
initial_data: Mapping[str, Any] | None = None,
|
||||
) -> ScriptRunVariables:
|
||||
"""Create a new top-level ScriptRunVariables."""
|
||||
local_data: dict[str, Any] = {}
|
||||
non_parallel_scope = full_scope = ChainMap(local_data)
|
||||
self = cls(
|
||||
_local_data=local_data,
|
||||
_non_parallel_scope=non_parallel_scope,
|
||||
_full_scope=full_scope,
|
||||
)
|
||||
if initial_data is not None:
|
||||
self.update(initial_data)
|
||||
return self
|
||||
|
||||
def enter_scope(self, *, parallel: bool = False) -> ScriptRunVariables:
|
||||
"""Return a new child scope.
|
||||
|
||||
:param parallel: Whether the new scope starts a parallel sequence.
|
||||
"""
|
||||
if self._local_data is not None or self._parallel_data is not None:
|
||||
parent = self
|
||||
else:
|
||||
parent = cast( # top level always has local data, so we can cast safely
|
||||
ScriptRunVariables, self._parent
|
||||
)
|
||||
|
||||
parallel_data: _ParallelData | None
|
||||
if not parallel:
|
||||
parallel_data = None
|
||||
non_parallel_scope = self._non_parallel_scope
|
||||
full_scope = self._full_scope
|
||||
else:
|
||||
parallel_data = _ParallelData()
|
||||
non_parallel_scope = ChainMap(
|
||||
parallel_data.protected, parallel_data.outer_scope_writes
|
||||
)
|
||||
full_scope = self._full_scope.new_child(parallel_data.protected)
|
||||
|
||||
return ScriptRunVariables(
|
||||
_previous=self,
|
||||
_parent=parent,
|
||||
_parallel_data=parallel_data,
|
||||
_non_parallel_scope=non_parallel_scope,
|
||||
_full_scope=full_scope,
|
||||
)
|
||||
|
||||
def exit_scope(self) -> ScriptRunVariables:
|
||||
"""Exit the current scope.
|
||||
|
||||
Does no clean-up, but simply returns the previous scope.
|
||||
"""
|
||||
if self._previous is None:
|
||||
raise ValueError("Cannot exit top-level scope")
|
||||
return self._previous
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
"""Delete a variable (disallowed)."""
|
||||
raise TypeError("Deleting items is not allowed in ScriptRunVariables.")
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
"""Assign value to a variable."""
|
||||
self._assign(key, value, parallel_protected=False)
|
||||
|
||||
def assign_parallel_protected(self, key: str, value: Any) -> None:
|
||||
"""Assign value to a variable which is to be protected in parallel sequences."""
|
||||
self._assign(key, value, parallel_protected=True)
|
||||
|
||||
def _assign(self, key: str, value: Any, *, parallel_protected: bool) -> None:
|
||||
"""Assign value to a variable.
|
||||
|
||||
Value is always assigned to the variable in the nearest scope, in which it is defined.
|
||||
If the variable is not defined at all, it is created in the top-level scope.
|
||||
|
||||
:param parallel_protected: Whether variable is to be protected in parallel sequences.
|
||||
"""
|
||||
if self._local_data is not None and key in self._local_data:
|
||||
self._local_data[key] = value
|
||||
return
|
||||
|
||||
if self._parent is None:
|
||||
assert self._local_data is not None # top level always has local data
|
||||
self._local_data[key] = value
|
||||
return
|
||||
|
||||
if self._parallel_data is not None:
|
||||
if parallel_protected:
|
||||
self._parallel_data.protected[key] = value
|
||||
return
|
||||
self._parallel_data.protected.pop(key, None)
|
||||
self._parallel_data.outer_scope_writes[key] = value
|
||||
|
||||
self._parent._assign(key, value, parallel_protected=parallel_protected) # noqa: SLF001
|
||||
|
||||
def define_local(self, key: str, value: Any) -> None:
|
||||
"""Define a local variable and assign value to it."""
|
||||
if self._local_data is None:
|
||||
self._local_data = {}
|
||||
self._non_parallel_scope = self._non_parallel_scope.new_child(
|
||||
self._local_data
|
||||
)
|
||||
self._full_scope = self._full_scope.new_child(self._local_data)
|
||||
self._local_data[key] = value
|
||||
|
||||
@property
|
||||
def data(self) -> Mapping[str, Any]: # type: ignore[override]
|
||||
"""Return variables in full scope.
|
||||
|
||||
Defined here for UserDict compatibility.
|
||||
"""
|
||||
return self._full_scope
|
||||
|
||||
@property
|
||||
def non_parallel_scope(self) -> Mapping[str, Any]:
|
||||
"""Return variables in non-parallel scope."""
|
||||
return self._non_parallel_scope
|
||||
|
||||
@property
|
||||
def local_scope(self) -> Mapping[str, Any]:
|
||||
"""Return variables in local scope."""
|
||||
return self._local_data if self._local_data is not None else {}
|
||||
|
@ -452,6 +452,68 @@ async def test_service_response_data_errors(
|
||||
await script_obj.async_run(context=context)
|
||||
|
||||
|
||||
async def test_calling_service_response_data_in_scopes(hass: HomeAssistant) -> None:
|
||||
"""Test response variable is still set after scopes end."""
|
||||
expected_var = {"data": "value-12345"}
|
||||
|
||||
def mock_service(call: ServiceCall) -> ServiceResponse:
|
||||
"""Mock service call."""
|
||||
if call.return_response:
|
||||
return expected_var
|
||||
return None
|
||||
|
||||
hass.services.async_register(
|
||||
"test", "script", mock_service, supports_response=SupportsResponse.OPTIONAL
|
||||
)
|
||||
|
||||
sequence = cv.SCRIPT_SCHEMA(
|
||||
{
|
||||
"parallel": [
|
||||
{
|
||||
"alias": "Sequential group",
|
||||
"sequence": [
|
||||
{
|
||||
"alias": "variables",
|
||||
"variables": {"state": "off"},
|
||||
},
|
||||
{
|
||||
"alias": "service step1",
|
||||
"action": "test.script",
|
||||
"response_variable": "my_response",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||
|
||||
result = await script_obj.async_run(context=Context())
|
||||
|
||||
assert result.variables["my_response"] == expected_var
|
||||
|
||||
expected_trace = {
|
||||
"0": [{"variables": {"my_response": expected_var}}],
|
||||
"0/parallel/0/sequence/0": [{"variables": {"state": "off"}}],
|
||||
"0/parallel/0/sequence/1": [
|
||||
{
|
||||
"result": {
|
||||
"params": {
|
||||
"domain": "test",
|
||||
"service": "script",
|
||||
"service_data": {},
|
||||
"target": {},
|
||||
},
|
||||
"running_script": False,
|
||||
},
|
||||
"variables": {"my_response": expected_var},
|
||||
}
|
||||
],
|
||||
}
|
||||
assert_action_trace(expected_trace)
|
||||
|
||||
|
||||
async def test_data_template_with_templated_key(hass: HomeAssistant) -> None:
|
||||
"""Test the calling of a service with a data_template with a templated key."""
|
||||
context = Context()
|
||||
@ -1706,6 +1768,90 @@ async def test_wait_variables_out(hass: HomeAssistant, mode, action_type) -> Non
|
||||
assert float(remaining) == 0.0
|
||||
|
||||
|
||||
async def test_wait_in_sequence(hass: HomeAssistant) -> None:
|
||||
"""Test wait variable is still set after sequence ends."""
|
||||
sequence = cv.SCRIPT_SCHEMA(
|
||||
[
|
||||
{
|
||||
"alias": "Sequential group",
|
||||
"sequence": [
|
||||
{
|
||||
"alias": "variables",
|
||||
"variables": {"state": "off"},
|
||||
},
|
||||
{
|
||||
"alias": "wait template",
|
||||
"wait_template": "{{ state == 'off' }}",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||
|
||||
result = await script_obj.async_run(context=Context())
|
||||
|
||||
expected_var = {"completed": True, "remaining": None}
|
||||
|
||||
assert result.variables["wait"] == expected_var
|
||||
|
||||
expected_trace = {
|
||||
"0": [{"variables": {"wait": expected_var}}],
|
||||
"0/sequence/0": [{"variables": {"state": "off"}}],
|
||||
"0/sequence/1": [
|
||||
{
|
||||
"result": {"wait": expected_var},
|
||||
"variables": {"wait": expected_var},
|
||||
}
|
||||
],
|
||||
}
|
||||
assert_action_trace(expected_trace)
|
||||
|
||||
|
||||
async def test_wait_in_parallel(hass: HomeAssistant) -> None:
|
||||
"""Test wait variable is not set after parallel ends."""
|
||||
sequence = cv.SCRIPT_SCHEMA(
|
||||
{
|
||||
"parallel": [
|
||||
{
|
||||
"alias": "Sequential group",
|
||||
"sequence": [
|
||||
{
|
||||
"alias": "variables",
|
||||
"variables": {"state": "off"},
|
||||
},
|
||||
{
|
||||
"alias": "wait template",
|
||||
"wait_template": "{{ state == 'off' }}",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||
|
||||
result = await script_obj.async_run(context=Context())
|
||||
|
||||
expected_var = {"completed": True, "remaining": None}
|
||||
|
||||
assert "wait" not in result.variables
|
||||
|
||||
expected_trace = {
|
||||
"0": [{}],
|
||||
"0/parallel/0/sequence/0": [{"variables": {"state": "off"}}],
|
||||
"0/parallel/0/sequence/1": [
|
||||
{
|
||||
"result": {"wait": expected_var},
|
||||
"variables": {"wait": expected_var},
|
||||
}
|
||||
],
|
||||
}
|
||||
assert_action_trace(expected_trace)
|
||||
|
||||
|
||||
async def test_wait_for_trigger_bad(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
|
@ -5,12 +5,13 @@ import pytest
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import TemplateError
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.script_variables import ScriptRunVariables, ScriptVariables
|
||||
|
||||
|
||||
async def test_static_vars() -> None:
|
||||
"""Test static vars."""
|
||||
orig = {"hello": "world"}
|
||||
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
|
||||
var = ScriptVariables(orig)
|
||||
rendered = var.async_render(None, None)
|
||||
assert rendered is not orig
|
||||
assert rendered == orig
|
||||
@ -20,31 +21,28 @@ async def test_static_vars_run_args() -> None:
|
||||
"""Test static vars."""
|
||||
orig = {"hello": "world"}
|
||||
orig_copy = dict(orig)
|
||||
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
|
||||
var = ScriptVariables(orig)
|
||||
rendered = var.async_render(None, {"hello": "override", "run": "var"})
|
||||
assert rendered == {"hello": "override", "run": "var"}
|
||||
# Make sure we don't change original vars
|
||||
assert orig == orig_copy
|
||||
|
||||
|
||||
async def test_static_vars_no_default() -> None:
|
||||
async def test_static_vars_simple() -> None:
|
||||
"""Test static vars."""
|
||||
orig = {"hello": "world"}
|
||||
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
|
||||
rendered = var.async_render(None, None, render_as_defaults=False)
|
||||
assert rendered is not orig
|
||||
assert rendered == orig
|
||||
var = ScriptVariables(orig)
|
||||
rendered = var.async_simple_render({})
|
||||
assert rendered is orig
|
||||
|
||||
|
||||
async def test_static_vars_run_args_no_default() -> None:
|
||||
async def test_static_vars_run_args_simple() -> None:
|
||||
"""Test static vars."""
|
||||
orig = {"hello": "world"}
|
||||
orig_copy = dict(orig)
|
||||
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
|
||||
rendered = var.async_render(
|
||||
None, {"hello": "override", "run": "var"}, render_as_defaults=False
|
||||
)
|
||||
assert rendered == {"hello": "world", "run": "var"}
|
||||
var = ScriptVariables(orig)
|
||||
rendered = var.async_simple_render({"hello": "override", "run": "var"})
|
||||
assert rendered is orig
|
||||
# Make sure we don't change original vars
|
||||
assert orig == orig_copy
|
||||
|
||||
@ -78,14 +76,14 @@ async def test_template_vars_run_args(hass: HomeAssistant) -> None:
|
||||
}
|
||||
|
||||
|
||||
async def test_template_vars_no_default(hass: HomeAssistant) -> None:
|
||||
async def test_template_vars_simple(hass: HomeAssistant) -> None:
|
||||
"""Test template vars."""
|
||||
var = cv.SCRIPT_VARIABLES_SCHEMA({"hello": "{{ 1 + 1 }}"})
|
||||
rendered = var.async_render(hass, None, render_as_defaults=False)
|
||||
rendered = var.async_simple_render({})
|
||||
assert rendered == {"hello": 2}
|
||||
|
||||
|
||||
async def test_template_vars_run_args_no_default(hass: HomeAssistant) -> None:
|
||||
async def test_template_vars_run_args_simple(hass: HomeAssistant) -> None:
|
||||
"""Test template vars."""
|
||||
var = cv.SCRIPT_VARIABLES_SCHEMA(
|
||||
{
|
||||
@ -93,16 +91,13 @@ async def test_template_vars_run_args_no_default(hass: HomeAssistant) -> None:
|
||||
"something_2": "{{ run_var_ex + 1 }}",
|
||||
}
|
||||
)
|
||||
rendered = var.async_render(
|
||||
hass,
|
||||
rendered = var.async_simple_render(
|
||||
{
|
||||
"run_var_ex": 5,
|
||||
"something_2": 1,
|
||||
},
|
||||
render_as_defaults=False,
|
||||
}
|
||||
)
|
||||
assert rendered == {
|
||||
"run_var_ex": 5,
|
||||
"something": 6,
|
||||
"something_2": 6,
|
||||
}
|
||||
@ -113,3 +108,90 @@ async def test_template_vars_error(hass: HomeAssistant) -> None:
|
||||
var = cv.SCRIPT_VARIABLES_SCHEMA({"hello": "{{ canont.work }}"})
|
||||
with pytest.raises(TemplateError):
|
||||
var.async_render(hass, None)
|
||||
|
||||
|
||||
async def test_script_vars_exit_top_level() -> None:
|
||||
"""Test exiting top level script run variables."""
|
||||
script_vars = ScriptRunVariables.create_top_level()
|
||||
with pytest.raises(ValueError):
|
||||
script_vars.exit_scope()
|
||||
|
||||
|
||||
async def test_script_vars_delete_var() -> None:
|
||||
"""Test deleting from script run variables."""
|
||||
script_vars = ScriptRunVariables.create_top_level({"x": 1, "y": 2})
|
||||
with pytest.raises(TypeError):
|
||||
del script_vars["x"]
|
||||
with pytest.raises(TypeError):
|
||||
script_vars.pop("y")
|
||||
assert script_vars._full_scope == {"x": 1, "y": 2}
|
||||
|
||||
|
||||
async def test_script_vars_scopes() -> None:
|
||||
"""Test script run variables scopes."""
|
||||
script_vars = ScriptRunVariables.create_top_level()
|
||||
script_vars["x"] = 1
|
||||
script_vars["y"] = 1
|
||||
assert script_vars["x"] == 1
|
||||
assert script_vars["y"] == 1
|
||||
|
||||
script_vars_2 = script_vars.enter_scope()
|
||||
script_vars_2.define_local("x", 2)
|
||||
assert script_vars_2["x"] == 2
|
||||
assert script_vars_2["y"] == 1
|
||||
|
||||
script_vars_3 = script_vars_2.enter_scope()
|
||||
script_vars_3["x"] = 3
|
||||
script_vars_3["y"] = 3
|
||||
assert script_vars_3["x"] == 3
|
||||
assert script_vars_3["y"] == 3
|
||||
|
||||
script_vars_4 = script_vars_3.enter_scope()
|
||||
assert script_vars_4["x"] == 3
|
||||
assert script_vars_4["y"] == 3
|
||||
|
||||
assert script_vars_4.exit_scope() is script_vars_3
|
||||
|
||||
assert script_vars_3._full_scope == {"x": 3, "y": 3}
|
||||
assert script_vars_3.local_scope == {}
|
||||
|
||||
assert script_vars_3.exit_scope() is script_vars_2
|
||||
|
||||
assert script_vars_2._full_scope == {"x": 3, "y": 3}
|
||||
assert script_vars_2.local_scope == {"x": 3}
|
||||
|
||||
assert script_vars_2.exit_scope() is script_vars
|
||||
|
||||
assert script_vars._full_scope == {"x": 1, "y": 3}
|
||||
assert script_vars.local_scope == {"x": 1, "y": 3}
|
||||
|
||||
|
||||
async def test_script_vars_parallel() -> None:
|
||||
"""Test script run variables parallel support."""
|
||||
script_vars = ScriptRunVariables.create_top_level({"x": 1, "y": 1, "z": 1})
|
||||
|
||||
script_vars_2a = script_vars.enter_scope(parallel=True)
|
||||
script_vars_3a = script_vars_2a.enter_scope()
|
||||
|
||||
script_vars_2b = script_vars.enter_scope(parallel=True)
|
||||
script_vars_3b = script_vars_2b.enter_scope()
|
||||
|
||||
script_vars_3a["x"] = "a"
|
||||
script_vars_3a.assign_parallel_protected("y", "a")
|
||||
|
||||
script_vars_3b["x"] = "b"
|
||||
script_vars_3b.assign_parallel_protected("y", "b")
|
||||
|
||||
assert script_vars_3a._full_scope == {"x": "b", "y": "a", "z": 1}
|
||||
assert script_vars_3a.non_parallel_scope == {"x": "a", "y": "a"}
|
||||
|
||||
assert script_vars_3b._full_scope == {"x": "b", "y": "b", "z": 1}
|
||||
assert script_vars_3b.non_parallel_scope == {"x": "b", "y": "b"}
|
||||
|
||||
assert script_vars_3a.exit_scope() is script_vars_2a
|
||||
assert script_vars_2a.exit_scope() is script_vars
|
||||
assert script_vars_3b.exit_scope() is script_vars_2b
|
||||
assert script_vars_2b.exit_scope() is script_vars
|
||||
|
||||
assert script_vars._full_scope == {"x": "b", "y": 1, "z": 1}
|
||||
assert script_vars.local_scope == {"x": "b", "y": 1, "z": 1}
|
||||
|
Loading…
x
Reference in New Issue
Block a user