mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 09:47:52 +00:00
Fix variable scopes in scripts (#138883)
Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
parent
bd80a78848
commit
b964bc58be
@ -12,7 +12,6 @@ from datetime import datetime, timedelta
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from types import MappingProxyType
|
|
||||||
from typing import Any, Literal, TypedDict, cast, overload
|
from typing import Any, Literal, TypedDict, cast, overload
|
||||||
|
|
||||||
import async_interrupt
|
import async_interrupt
|
||||||
@ -90,7 +89,7 @@ from . import condition, config_validation as cv, service, template
|
|||||||
from .condition import ConditionCheckerType, trace_condition_function
|
from .condition import ConditionCheckerType, trace_condition_function
|
||||||
from .dispatcher import async_dispatcher_connect, async_dispatcher_send_internal
|
from .dispatcher import async_dispatcher_connect, async_dispatcher_send_internal
|
||||||
from .event import async_call_later, async_track_template
|
from .event import async_call_later, async_track_template
|
||||||
from .script_variables import ScriptVariables
|
from .script_variables import ScriptRunVariables, ScriptVariables
|
||||||
from .template import Template
|
from .template import Template
|
||||||
from .trace import (
|
from .trace import (
|
||||||
TraceElement,
|
TraceElement,
|
||||||
@ -177,7 +176,7 @@ def _set_result_unless_done(future: asyncio.Future[None]) -> None:
|
|||||||
future.set_result(None)
|
future.set_result(None)
|
||||||
|
|
||||||
|
|
||||||
def action_trace_append(variables: dict[str, Any], path: str) -> TraceElement:
|
def action_trace_append(variables: TemplateVarsType, path: str) -> TraceElement:
|
||||||
"""Append a TraceElement to trace[path]."""
|
"""Append a TraceElement to trace[path]."""
|
||||||
trace_element = TraceElement(variables, path)
|
trace_element = TraceElement(variables, path)
|
||||||
trace_append_element(trace_element, ACTION_TRACE_NODE_MAX_LEN)
|
trace_append_element(trace_element, ACTION_TRACE_NODE_MAX_LEN)
|
||||||
@ -189,7 +188,7 @@ async def trace_action(
|
|||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
script_run: _ScriptRun,
|
script_run: _ScriptRun,
|
||||||
stop: asyncio.Future[None],
|
stop: asyncio.Future[None],
|
||||||
variables: dict[str, Any],
|
variables: TemplateVarsType,
|
||||||
) -> AsyncGenerator[TraceElement]:
|
) -> AsyncGenerator[TraceElement]:
|
||||||
"""Trace action execution."""
|
"""Trace action execution."""
|
||||||
path = trace_path_get()
|
path = trace_path_get()
|
||||||
@ -411,7 +410,7 @@ class _ScriptRun:
|
|||||||
self,
|
self,
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
script: Script,
|
script: Script,
|
||||||
variables: dict[str, Any],
|
variables: ScriptRunVariables,
|
||||||
context: Context | None,
|
context: Context | None,
|
||||||
log_exceptions: bool,
|
log_exceptions: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -485,14 +484,16 @@ class _ScriptRun:
|
|||||||
script_stack.pop()
|
script_stack.pop()
|
||||||
self._finish()
|
self._finish()
|
||||||
|
|
||||||
return ScriptRunResult(self._conversation_response, response, self._variables)
|
return ScriptRunResult(
|
||||||
|
self._conversation_response, response, self._variables.local_scope
|
||||||
|
)
|
||||||
|
|
||||||
async def _async_step(self, log_exceptions: bool) -> None:
|
async def _async_step(self, log_exceptions: bool) -> None:
|
||||||
continue_on_error = self._action.get(CONF_CONTINUE_ON_ERROR, False)
|
continue_on_error = self._action.get(CONF_CONTINUE_ON_ERROR, False)
|
||||||
|
|
||||||
with trace_path(str(self._step)):
|
with trace_path(str(self._step)):
|
||||||
async with trace_action(
|
async with trace_action(
|
||||||
self._hass, self, self._stop, self._variables
|
self._hass, self, self._stop, self._variables.non_parallel_scope
|
||||||
) as trace_element:
|
) as trace_element:
|
||||||
if self._stop.done():
|
if self._stop.done():
|
||||||
return
|
return
|
||||||
@ -526,7 +527,7 @@ class _ScriptRun:
|
|||||||
ex, continue_on_error, self._log_exceptions or log_exceptions
|
ex, continue_on_error, self._log_exceptions or log_exceptions
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
trace_element.update_variables(self._variables)
|
trace_element.update_variables(self._variables.non_parallel_scope)
|
||||||
|
|
||||||
def _finish(self) -> None:
|
def _finish(self) -> None:
|
||||||
self._script._runs.remove(self) # noqa: SLF001
|
self._script._runs.remove(self) # noqa: SLF001
|
||||||
@ -624,11 +625,16 @@ class _ScriptRun:
|
|||||||
except ScriptStoppedError as ex:
|
except ScriptStoppedError as ex:
|
||||||
raise asyncio.CancelledError from ex
|
raise asyncio.CancelledError from ex
|
||||||
|
|
||||||
async def _async_run_script(self, script: Script) -> None:
|
async def _async_run_script(
|
||||||
|
self, script: Script, *, parallel: bool = False
|
||||||
|
) -> None:
|
||||||
"""Execute a script."""
|
"""Execute a script."""
|
||||||
result = await self._async_run_long_action(
|
result = await self._async_run_long_action(
|
||||||
self._hass.async_create_task_internal(
|
self._hass.async_create_task_internal(
|
||||||
script.async_run(self._variables, self._context), eager_start=True
|
script.async_run(
|
||||||
|
self._variables.enter_scope(parallel=parallel), self._context
|
||||||
|
),
|
||||||
|
eager_start=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if result and result.conversation_response is not UNDEFINED:
|
if result and result.conversation_response is not UNDEFINED:
|
||||||
@ -647,7 +653,7 @@ class _ScriptRun:
|
|||||||
"""Run a script with a trace path."""
|
"""Run a script with a trace path."""
|
||||||
trace_path_stack_cv.set(copy(trace_path_stack_cv.get()))
|
trace_path_stack_cv.set(copy(trace_path_stack_cv.get()))
|
||||||
with trace_path([str(idx), "sequence"]):
|
with trace_path([str(idx), "sequence"]):
|
||||||
await self._async_run_script(script)
|
await self._async_run_script(script, parallel=True)
|
||||||
|
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*(async_run_with_trace(idx, script) for idx, script in enumerate(scripts)),
|
*(async_run_with_trace(idx, script) for idx, script in enumerate(scripts)),
|
||||||
@ -760,14 +766,11 @@ class _ScriptRun:
|
|||||||
with trace_path("else"):
|
with trace_path("else"):
|
||||||
await self._async_run_script(if_data["if_else"])
|
await self._async_run_script(if_data["if_else"])
|
||||||
|
|
||||||
@async_trace_path("repeat")
|
async def _async_do_step_repeat(self) -> None: # noqa: C901
|
||||||
async def _async_step_repeat(self) -> None: # noqa: C901
|
"""Repeat a sequence helper."""
|
||||||
"""Repeat a sequence."""
|
|
||||||
description = self._action.get(CONF_ALIAS, "sequence")
|
description = self._action.get(CONF_ALIAS, "sequence")
|
||||||
repeat = self._action[CONF_REPEAT]
|
repeat = self._action[CONF_REPEAT]
|
||||||
|
|
||||||
saved_repeat_vars = self._variables.get("repeat")
|
|
||||||
|
|
||||||
def set_repeat_var(
|
def set_repeat_var(
|
||||||
iteration: int, count: int | None = None, item: Any = None
|
iteration: int, count: int | None = None, item: Any = None
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -776,7 +779,7 @@ class _ScriptRun:
|
|||||||
repeat_vars["last"] = iteration == count
|
repeat_vars["last"] = iteration == count
|
||||||
if item is not None:
|
if item is not None:
|
||||||
repeat_vars["item"] = item
|
repeat_vars["item"] = item
|
||||||
self._variables["repeat"] = repeat_vars
|
self._variables.define_local("repeat", repeat_vars)
|
||||||
|
|
||||||
script = self._script._get_repeat_script(self._step) # noqa: SLF001
|
script = self._script._get_repeat_script(self._step) # noqa: SLF001
|
||||||
warned_too_many_loops = False
|
warned_too_many_loops = False
|
||||||
@ -927,10 +930,14 @@ class _ScriptRun:
|
|||||||
# while all the cpu time is consumed.
|
# while all the cpu time is consumed.
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
if saved_repeat_vars:
|
@async_trace_path("repeat")
|
||||||
self._variables["repeat"] = saved_repeat_vars
|
async def _async_step_repeat(self) -> None:
|
||||||
else:
|
"""Repeat a sequence."""
|
||||||
self._variables.pop("repeat", None) # Not set if count = 0
|
self._variables = self._variables.enter_scope()
|
||||||
|
try:
|
||||||
|
await self._async_do_step_repeat()
|
||||||
|
finally:
|
||||||
|
self._variables = self._variables.exit_scope()
|
||||||
|
|
||||||
### Stop actions ###
|
### Stop actions ###
|
||||||
|
|
||||||
@ -959,11 +966,12 @@ class _ScriptRun:
|
|||||||
## Variable actions ##
|
## Variable actions ##
|
||||||
|
|
||||||
async def _async_step_variables(self) -> None:
|
async def _async_step_variables(self) -> None:
|
||||||
"""Set a variable value."""
|
"""Define a local variable."""
|
||||||
self._step_log("setting variables")
|
self._step_log("defining local variables")
|
||||||
self._variables = self._action[CONF_VARIABLES].async_render(
|
for key, value in (
|
||||||
self._hass, self._variables, render_as_defaults=False
|
self._action[CONF_VARIABLES].async_simple_render(self._variables).items()
|
||||||
)
|
):
|
||||||
|
self._variables.define_local(key, value)
|
||||||
|
|
||||||
## External actions ##
|
## External actions ##
|
||||||
|
|
||||||
@ -1016,7 +1024,7 @@ class _ScriptRun:
|
|||||||
"""Perform the device automation specified in the action."""
|
"""Perform the device automation specified in the action."""
|
||||||
self._step_log("device automation")
|
self._step_log("device automation")
|
||||||
await device_action.async_call_action_from_config(
|
await device_action.async_call_action_from_config(
|
||||||
self._hass, self._action, self._variables, self._context
|
self._hass, self._action, dict(self._variables), self._context
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _async_step_event(self) -> None:
|
async def _async_step_event(self) -> None:
|
||||||
@ -1189,12 +1197,15 @@ class _ScriptRun:
|
|||||||
|
|
||||||
self._step_log("wait for trigger", timeout)
|
self._step_log("wait for trigger", timeout)
|
||||||
|
|
||||||
variables = {**self._variables}
|
variables = dict(self._variables)
|
||||||
self._variables["wait"] = {
|
self._variables.assign_parallel_protected(
|
||||||
"remaining": timeout,
|
"wait",
|
||||||
"completed": False,
|
{
|
||||||
"trigger": None,
|
"remaining": timeout,
|
||||||
}
|
"completed": False,
|
||||||
|
"trigger": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
trace_set_result(wait=self._variables["wait"])
|
trace_set_result(wait=self._variables["wait"])
|
||||||
|
|
||||||
if timeout == 0:
|
if timeout == 0:
|
||||||
@ -1240,7 +1251,9 @@ class _ScriptRun:
|
|||||||
timeout = self._get_timeout_seconds_from_action()
|
timeout = self._get_timeout_seconds_from_action()
|
||||||
self._step_log("wait template", timeout)
|
self._step_log("wait template", timeout)
|
||||||
|
|
||||||
self._variables["wait"] = {"remaining": timeout, "completed": False}
|
self._variables.assign_parallel_protected(
|
||||||
|
"wait", {"remaining": timeout, "completed": False}
|
||||||
|
)
|
||||||
trace_set_result(wait=self._variables["wait"])
|
trace_set_result(wait=self._variables["wait"])
|
||||||
|
|
||||||
wait_template = self._action[CONF_WAIT_TEMPLATE]
|
wait_template = self._action[CONF_WAIT_TEMPLATE]
|
||||||
@ -1369,7 +1382,7 @@ async def _async_stop_scripts_at_shutdown(hass: HomeAssistant, event: Event) ->
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
type _VarsType = dict[str, Any] | Mapping[str, Any] | MappingProxyType[str, Any]
|
type _VarsType = dict[str, Any] | Mapping[str, Any] | ScriptRunVariables
|
||||||
|
|
||||||
|
|
||||||
def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None:
|
def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None:
|
||||||
@ -1407,7 +1420,7 @@ class ScriptRunResult:
|
|||||||
|
|
||||||
conversation_response: str | None | UndefinedType
|
conversation_response: str | None | UndefinedType
|
||||||
service_response: ServiceResponse
|
service_response: ServiceResponse
|
||||||
variables: dict[str, Any]
|
variables: Mapping[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class Script:
|
class Script:
|
||||||
@ -1422,7 +1435,6 @@ class Script:
|
|||||||
*,
|
*,
|
||||||
# Used in "Running <running_description>" log message
|
# Used in "Running <running_description>" log message
|
||||||
change_listener: Callable[[], Any] | None = None,
|
change_listener: Callable[[], Any] | None = None,
|
||||||
copy_variables: bool = False,
|
|
||||||
log_exceptions: bool = True,
|
log_exceptions: bool = True,
|
||||||
logger: logging.Logger | None = None,
|
logger: logging.Logger | None = None,
|
||||||
max_exceeded: str = DEFAULT_MAX_EXCEEDED,
|
max_exceeded: str = DEFAULT_MAX_EXCEEDED,
|
||||||
@ -1476,8 +1488,6 @@ class Script:
|
|||||||
self._parallel_scripts: dict[int, list[Script]] = {}
|
self._parallel_scripts: dict[int, list[Script]] = {}
|
||||||
self._sequence_scripts: dict[int, Script] = {}
|
self._sequence_scripts: dict[int, Script] = {}
|
||||||
self.variables = variables
|
self.variables = variables
|
||||||
self._variables_dynamic = template.is_complex(variables)
|
|
||||||
self._copy_variables_on_run = copy_variables
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def change_listener(self) -> Callable[..., Any] | None:
|
def change_listener(self) -> Callable[..., Any] | None:
|
||||||
@ -1755,25 +1765,19 @@ class Script:
|
|||||||
if self.top_level:
|
if self.top_level:
|
||||||
if self.variables:
|
if self.variables:
|
||||||
try:
|
try:
|
||||||
variables = self.variables.async_render(
|
run_variables = self.variables.async_render(
|
||||||
self._hass,
|
self._hass,
|
||||||
run_variables,
|
run_variables,
|
||||||
)
|
)
|
||||||
except exceptions.TemplateError as err:
|
except exceptions.TemplateError as err:
|
||||||
self._log("Error rendering variables: %s", err, level=logging.ERROR)
|
self._log("Error rendering variables: %s", err, level=logging.ERROR)
|
||||||
raise
|
raise
|
||||||
elif run_variables:
|
|
||||||
variables = dict(run_variables)
|
|
||||||
else:
|
|
||||||
variables = {}
|
|
||||||
|
|
||||||
|
variables = ScriptRunVariables.create_top_level(run_variables)
|
||||||
variables["context"] = context
|
variables["context"] = context
|
||||||
elif self._copy_variables_on_run:
|
|
||||||
# This is not the top level script, variables have been turned to a dict
|
|
||||||
variables = cast(dict[str, Any], copy(run_variables))
|
|
||||||
else:
|
else:
|
||||||
# This is not the top level script, variables have been turned to a dict
|
# This is not the top level script, run_variables is an instance of ScriptRunVariables
|
||||||
variables = cast(dict[str, Any], run_variables)
|
variables = cast(ScriptRunVariables, run_variables)
|
||||||
|
|
||||||
# Prevent non-allowed recursive calls which will cause deadlocks when we try to
|
# Prevent non-allowed recursive calls which will cause deadlocks when we try to
|
||||||
# stop (restart) or wait for (queued) our own script run.
|
# stop (restart) or wait for (queued) our own script run.
|
||||||
@ -1999,7 +2003,6 @@ class Script:
|
|||||||
max_runs=self.max_runs,
|
max_runs=self.max_runs,
|
||||||
logger=self._logger,
|
logger=self._logger,
|
||||||
top_level=False,
|
top_level=False,
|
||||||
copy_variables=True,
|
|
||||||
)
|
)
|
||||||
parallel_script.change_listener = partial(
|
parallel_script.change_listener = partial(
|
||||||
self._chain_change_listener, parallel_script
|
self._chain_change_listener, parallel_script
|
||||||
|
@ -2,8 +2,10 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import ChainMap, UserDict
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
|
||||||
@ -24,30 +26,23 @@ class ScriptVariables:
|
|||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
run_variables: Mapping[str, Any] | None,
|
run_variables: Mapping[str, Any] | None,
|
||||||
*,
|
*,
|
||||||
render_as_defaults: bool = True,
|
|
||||||
limited: bool = False,
|
limited: bool = False,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Render script variables.
|
"""Render script variables.
|
||||||
|
|
||||||
The run variables are used to compute the static variables.
|
The run variables are included in the result.
|
||||||
|
The run variables are used to compute the rendered variable values.
|
||||||
If `render_as_defaults` is True, the run variables will not be overridden.
|
The run variables will not be overridden.
|
||||||
|
The rendering happens one at a time, with previous results influencing the next.
|
||||||
"""
|
"""
|
||||||
if self._has_template is None:
|
if self._has_template is None:
|
||||||
self._has_template = template.is_complex(self.variables)
|
self._has_template = template.is_complex(self.variables)
|
||||||
|
|
||||||
if not self._has_template:
|
if not self._has_template:
|
||||||
if render_as_defaults:
|
rendered_variables = dict(self.variables)
|
||||||
rendered_variables = dict(self.variables)
|
|
||||||
|
|
||||||
if run_variables is not None:
|
if run_variables is not None:
|
||||||
rendered_variables.update(run_variables)
|
rendered_variables.update(run_variables)
|
||||||
else:
|
|
||||||
rendered_variables = (
|
|
||||||
{} if run_variables is None else dict(run_variables)
|
|
||||||
)
|
|
||||||
rendered_variables.update(self.variables)
|
|
||||||
|
|
||||||
return rendered_variables
|
return rendered_variables
|
||||||
|
|
||||||
@ -56,7 +51,7 @@ class ScriptVariables:
|
|||||||
for key, value in self.variables.items():
|
for key, value in self.variables.items():
|
||||||
# We can skip if we're going to override this key with
|
# We can skip if we're going to override this key with
|
||||||
# run variables anyway
|
# run variables anyway
|
||||||
if render_as_defaults and key in rendered_variables:
|
if key in rendered_variables:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
rendered_variables[key] = template.render_complex(
|
rendered_variables[key] = template.render_complex(
|
||||||
@ -65,6 +60,197 @@ class ScriptVariables:
|
|||||||
|
|
||||||
return rendered_variables
|
return rendered_variables
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_simple_render(self, run_variables: Mapping[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Render script variables.
|
||||||
|
|
||||||
|
Simply renders the variables, the run variables are not included in the result.
|
||||||
|
The run variables are used to compute the rendered variable values.
|
||||||
|
The rendering happens one at a time, with previous results influencing the next.
|
||||||
|
"""
|
||||||
|
if self._has_template is None:
|
||||||
|
self._has_template = template.is_complex(self.variables)
|
||||||
|
|
||||||
|
if not self._has_template:
|
||||||
|
return self.variables
|
||||||
|
|
||||||
|
run_variables = dict(run_variables)
|
||||||
|
rendered_variables = {}
|
||||||
|
|
||||||
|
for key, value in self.variables.items():
|
||||||
|
rendered_variable = template.render_complex(value, run_variables)
|
||||||
|
rendered_variables[key] = rendered_variable
|
||||||
|
run_variables[key] = rendered_variable
|
||||||
|
|
||||||
|
return rendered_variables
|
||||||
|
|
||||||
def as_dict(self) -> dict[str, Any]:
|
def as_dict(self) -> dict[str, Any]:
|
||||||
"""Return dict version of this class."""
|
"""Return dict version of this class."""
|
||||||
return self.variables
|
return self.variables
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _ParallelData:
|
||||||
|
"""Data used in each parallel sequence."""
|
||||||
|
|
||||||
|
# `protected` is for variables that need special protection in parallel sequences.
|
||||||
|
# What this means is that such a variable defined in one parallel sequence will not be
|
||||||
|
# clobbered by the variable with the same name assigned in another parallel sequence.
|
||||||
|
# It also means that such a variable will not be visible in the outer scope.
|
||||||
|
# Currently the only such variable is `wait`.
|
||||||
|
protected: dict[str, Any] = field(default_factory=dict)
|
||||||
|
# `outer_scope_writes` is for variables that are written to the outer scope from
|
||||||
|
# a parallel sequence. This is used for generating correct traces of changed variables
|
||||||
|
# for each of the parallel sequences, isolating them from one another.
|
||||||
|
outer_scope_writes: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(kw_only=True)
|
||||||
|
class ScriptRunVariables(UserDict[str, Any]):
|
||||||
|
"""Class to hold script run variables.
|
||||||
|
|
||||||
|
The purpose of this class is to provide proper variable scoping semantics for scripts.
|
||||||
|
Each instance institutes a new local scope, in which variables can be defined.
|
||||||
|
Each instance has a reference to the previous instance, except for the top-level instance.
|
||||||
|
The instances therefore form a chain, in which variable lookup and assignment is performed.
|
||||||
|
The variables defined lower in the chain naturally override those defined higher up.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# _previous is the previous ScriptRunVariables in the chain
|
||||||
|
_previous: ScriptRunVariables | None = None
|
||||||
|
# _parent is the previous non-empty ScriptRunVariables in the chain
|
||||||
|
_parent: ScriptRunVariables | None = None
|
||||||
|
|
||||||
|
# _local_data is the store for local variables
|
||||||
|
_local_data: dict[str, Any] | None = None
|
||||||
|
# _parallel_data is used for each parallel sequence
|
||||||
|
_parallel_data: _ParallelData | None = None
|
||||||
|
|
||||||
|
# _non_parallel_scope includes all scopes all the way to the most recent parallel split
|
||||||
|
_non_parallel_scope: ChainMap[str, Any]
|
||||||
|
# _full_scope includes all scopes (all the way to the top-level)
|
||||||
|
_full_scope: ChainMap[str, Any]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_top_level(
|
||||||
|
cls,
|
||||||
|
initial_data: Mapping[str, Any] | None = None,
|
||||||
|
) -> ScriptRunVariables:
|
||||||
|
"""Create a new top-level ScriptRunVariables."""
|
||||||
|
local_data: dict[str, Any] = {}
|
||||||
|
non_parallel_scope = full_scope = ChainMap(local_data)
|
||||||
|
self = cls(
|
||||||
|
_local_data=local_data,
|
||||||
|
_non_parallel_scope=non_parallel_scope,
|
||||||
|
_full_scope=full_scope,
|
||||||
|
)
|
||||||
|
if initial_data is not None:
|
||||||
|
self.update(initial_data)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def enter_scope(self, *, parallel: bool = False) -> ScriptRunVariables:
|
||||||
|
"""Return a new child scope.
|
||||||
|
|
||||||
|
:param parallel: Whether the new scope starts a parallel sequence.
|
||||||
|
"""
|
||||||
|
if self._local_data is not None or self._parallel_data is not None:
|
||||||
|
parent = self
|
||||||
|
else:
|
||||||
|
parent = cast( # top level always has local data, so we can cast safely
|
||||||
|
ScriptRunVariables, self._parent
|
||||||
|
)
|
||||||
|
|
||||||
|
parallel_data: _ParallelData | None
|
||||||
|
if not parallel:
|
||||||
|
parallel_data = None
|
||||||
|
non_parallel_scope = self._non_parallel_scope
|
||||||
|
full_scope = self._full_scope
|
||||||
|
else:
|
||||||
|
parallel_data = _ParallelData()
|
||||||
|
non_parallel_scope = ChainMap(
|
||||||
|
parallel_data.protected, parallel_data.outer_scope_writes
|
||||||
|
)
|
||||||
|
full_scope = self._full_scope.new_child(parallel_data.protected)
|
||||||
|
|
||||||
|
return ScriptRunVariables(
|
||||||
|
_previous=self,
|
||||||
|
_parent=parent,
|
||||||
|
_parallel_data=parallel_data,
|
||||||
|
_non_parallel_scope=non_parallel_scope,
|
||||||
|
_full_scope=full_scope,
|
||||||
|
)
|
||||||
|
|
||||||
|
def exit_scope(self) -> ScriptRunVariables:
|
||||||
|
"""Exit the current scope.
|
||||||
|
|
||||||
|
Does no clean-up, but simply returns the previous scope.
|
||||||
|
"""
|
||||||
|
if self._previous is None:
|
||||||
|
raise ValueError("Cannot exit top-level scope")
|
||||||
|
return self._previous
|
||||||
|
|
||||||
|
def __delitem__(self, key: str) -> None:
|
||||||
|
"""Delete a variable (disallowed)."""
|
||||||
|
raise TypeError("Deleting items is not allowed in ScriptRunVariables.")
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, value: Any) -> None:
|
||||||
|
"""Assign value to a variable."""
|
||||||
|
self._assign(key, value, parallel_protected=False)
|
||||||
|
|
||||||
|
def assign_parallel_protected(self, key: str, value: Any) -> None:
|
||||||
|
"""Assign value to a variable which is to be protected in parallel sequences."""
|
||||||
|
self._assign(key, value, parallel_protected=True)
|
||||||
|
|
||||||
|
def _assign(self, key: str, value: Any, *, parallel_protected: bool) -> None:
|
||||||
|
"""Assign value to a variable.
|
||||||
|
|
||||||
|
Value is always assigned to the variable in the nearest scope, in which it is defined.
|
||||||
|
If the variable is not defined at all, it is created in the top-level scope.
|
||||||
|
|
||||||
|
:param parallel_protected: Whether variable is to be protected in parallel sequences.
|
||||||
|
"""
|
||||||
|
if self._local_data is not None and key in self._local_data:
|
||||||
|
self._local_data[key] = value
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._parent is None:
|
||||||
|
assert self._local_data is not None # top level always has local data
|
||||||
|
self._local_data[key] = value
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._parallel_data is not None:
|
||||||
|
if parallel_protected:
|
||||||
|
self._parallel_data.protected[key] = value
|
||||||
|
return
|
||||||
|
self._parallel_data.protected.pop(key, None)
|
||||||
|
self._parallel_data.outer_scope_writes[key] = value
|
||||||
|
|
||||||
|
self._parent._assign(key, value, parallel_protected=parallel_protected) # noqa: SLF001
|
||||||
|
|
||||||
|
def define_local(self, key: str, value: Any) -> None:
|
||||||
|
"""Define a local variable and assign value to it."""
|
||||||
|
if self._local_data is None:
|
||||||
|
self._local_data = {}
|
||||||
|
self._non_parallel_scope = self._non_parallel_scope.new_child(
|
||||||
|
self._local_data
|
||||||
|
)
|
||||||
|
self._full_scope = self._full_scope.new_child(self._local_data)
|
||||||
|
self._local_data[key] = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self) -> Mapping[str, Any]: # type: ignore[override]
|
||||||
|
"""Return variables in full scope.
|
||||||
|
|
||||||
|
Defined here for UserDict compatibility.
|
||||||
|
"""
|
||||||
|
return self._full_scope
|
||||||
|
|
||||||
|
@property
|
||||||
|
def non_parallel_scope(self) -> Mapping[str, Any]:
|
||||||
|
"""Return variables in non-parallel scope."""
|
||||||
|
return self._non_parallel_scope
|
||||||
|
|
||||||
|
@property
|
||||||
|
def local_scope(self) -> Mapping[str, Any]:
|
||||||
|
"""Return variables in local scope."""
|
||||||
|
return self._local_data if self._local_data is not None else {}
|
||||||
|
@ -452,6 +452,68 @@ async def test_service_response_data_errors(
|
|||||||
await script_obj.async_run(context=context)
|
await script_obj.async_run(context=context)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_calling_service_response_data_in_scopes(hass: HomeAssistant) -> None:
|
||||||
|
"""Test response variable is still set after scopes end."""
|
||||||
|
expected_var = {"data": "value-12345"}
|
||||||
|
|
||||||
|
def mock_service(call: ServiceCall) -> ServiceResponse:
|
||||||
|
"""Mock service call."""
|
||||||
|
if call.return_response:
|
||||||
|
return expected_var
|
||||||
|
return None
|
||||||
|
|
||||||
|
hass.services.async_register(
|
||||||
|
"test", "script", mock_service, supports_response=SupportsResponse.OPTIONAL
|
||||||
|
)
|
||||||
|
|
||||||
|
sequence = cv.SCRIPT_SCHEMA(
|
||||||
|
{
|
||||||
|
"parallel": [
|
||||||
|
{
|
||||||
|
"alias": "Sequential group",
|
||||||
|
"sequence": [
|
||||||
|
{
|
||||||
|
"alias": "variables",
|
||||||
|
"variables": {"state": "off"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"alias": "service step1",
|
||||||
|
"action": "test.script",
|
||||||
|
"response_variable": "my_response",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||||
|
|
||||||
|
result = await script_obj.async_run(context=Context())
|
||||||
|
|
||||||
|
assert result.variables["my_response"] == expected_var
|
||||||
|
|
||||||
|
expected_trace = {
|
||||||
|
"0": [{"variables": {"my_response": expected_var}}],
|
||||||
|
"0/parallel/0/sequence/0": [{"variables": {"state": "off"}}],
|
||||||
|
"0/parallel/0/sequence/1": [
|
||||||
|
{
|
||||||
|
"result": {
|
||||||
|
"params": {
|
||||||
|
"domain": "test",
|
||||||
|
"service": "script",
|
||||||
|
"service_data": {},
|
||||||
|
"target": {},
|
||||||
|
},
|
||||||
|
"running_script": False,
|
||||||
|
},
|
||||||
|
"variables": {"my_response": expected_var},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
assert_action_trace(expected_trace)
|
||||||
|
|
||||||
|
|
||||||
async def test_data_template_with_templated_key(hass: HomeAssistant) -> None:
|
async def test_data_template_with_templated_key(hass: HomeAssistant) -> None:
|
||||||
"""Test the calling of a service with a data_template with a templated key."""
|
"""Test the calling of a service with a data_template with a templated key."""
|
||||||
context = Context()
|
context = Context()
|
||||||
@ -1706,6 +1768,90 @@ async def test_wait_variables_out(hass: HomeAssistant, mode, action_type) -> Non
|
|||||||
assert float(remaining) == 0.0
|
assert float(remaining) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_wait_in_sequence(hass: HomeAssistant) -> None:
|
||||||
|
"""Test wait variable is still set after sequence ends."""
|
||||||
|
sequence = cv.SCRIPT_SCHEMA(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"alias": "Sequential group",
|
||||||
|
"sequence": [
|
||||||
|
{
|
||||||
|
"alias": "variables",
|
||||||
|
"variables": {"state": "off"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"alias": "wait template",
|
||||||
|
"wait_template": "{{ state == 'off' }}",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||||
|
|
||||||
|
result = await script_obj.async_run(context=Context())
|
||||||
|
|
||||||
|
expected_var = {"completed": True, "remaining": None}
|
||||||
|
|
||||||
|
assert result.variables["wait"] == expected_var
|
||||||
|
|
||||||
|
expected_trace = {
|
||||||
|
"0": [{"variables": {"wait": expected_var}}],
|
||||||
|
"0/sequence/0": [{"variables": {"state": "off"}}],
|
||||||
|
"0/sequence/1": [
|
||||||
|
{
|
||||||
|
"result": {"wait": expected_var},
|
||||||
|
"variables": {"wait": expected_var},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
assert_action_trace(expected_trace)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_wait_in_parallel(hass: HomeAssistant) -> None:
|
||||||
|
"""Test wait variable is not set after parallel ends."""
|
||||||
|
sequence = cv.SCRIPT_SCHEMA(
|
||||||
|
{
|
||||||
|
"parallel": [
|
||||||
|
{
|
||||||
|
"alias": "Sequential group",
|
||||||
|
"sequence": [
|
||||||
|
{
|
||||||
|
"alias": "variables",
|
||||||
|
"variables": {"state": "off"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"alias": "wait template",
|
||||||
|
"wait_template": "{{ state == 'off' }}",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||||
|
|
||||||
|
result = await script_obj.async_run(context=Context())
|
||||||
|
|
||||||
|
expected_var = {"completed": True, "remaining": None}
|
||||||
|
|
||||||
|
assert "wait" not in result.variables
|
||||||
|
|
||||||
|
expected_trace = {
|
||||||
|
"0": [{}],
|
||||||
|
"0/parallel/0/sequence/0": [{"variables": {"state": "off"}}],
|
||||||
|
"0/parallel/0/sequence/1": [
|
||||||
|
{
|
||||||
|
"result": {"wait": expected_var},
|
||||||
|
"variables": {"wait": expected_var},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
assert_action_trace(expected_trace)
|
||||||
|
|
||||||
|
|
||||||
async def test_wait_for_trigger_bad(
|
async def test_wait_for_trigger_bad(
|
||||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -5,12 +5,13 @@ import pytest
|
|||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import TemplateError
|
from homeassistant.exceptions import TemplateError
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
|
from homeassistant.helpers.script_variables import ScriptRunVariables, ScriptVariables
|
||||||
|
|
||||||
|
|
||||||
async def test_static_vars() -> None:
|
async def test_static_vars() -> None:
|
||||||
"""Test static vars."""
|
"""Test static vars."""
|
||||||
orig = {"hello": "world"}
|
orig = {"hello": "world"}
|
||||||
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
|
var = ScriptVariables(orig)
|
||||||
rendered = var.async_render(None, None)
|
rendered = var.async_render(None, None)
|
||||||
assert rendered is not orig
|
assert rendered is not orig
|
||||||
assert rendered == orig
|
assert rendered == orig
|
||||||
@ -20,31 +21,28 @@ async def test_static_vars_run_args() -> None:
|
|||||||
"""Test static vars."""
|
"""Test static vars."""
|
||||||
orig = {"hello": "world"}
|
orig = {"hello": "world"}
|
||||||
orig_copy = dict(orig)
|
orig_copy = dict(orig)
|
||||||
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
|
var = ScriptVariables(orig)
|
||||||
rendered = var.async_render(None, {"hello": "override", "run": "var"})
|
rendered = var.async_render(None, {"hello": "override", "run": "var"})
|
||||||
assert rendered == {"hello": "override", "run": "var"}
|
assert rendered == {"hello": "override", "run": "var"}
|
||||||
# Make sure we don't change original vars
|
# Make sure we don't change original vars
|
||||||
assert orig == orig_copy
|
assert orig == orig_copy
|
||||||
|
|
||||||
|
|
||||||
async def test_static_vars_no_default() -> None:
|
async def test_static_vars_simple() -> None:
|
||||||
"""Test static vars."""
|
"""Test static vars."""
|
||||||
orig = {"hello": "world"}
|
orig = {"hello": "world"}
|
||||||
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
|
var = ScriptVariables(orig)
|
||||||
rendered = var.async_render(None, None, render_as_defaults=False)
|
rendered = var.async_simple_render({})
|
||||||
assert rendered is not orig
|
assert rendered is orig
|
||||||
assert rendered == orig
|
|
||||||
|
|
||||||
|
|
||||||
async def test_static_vars_run_args_no_default() -> None:
|
async def test_static_vars_run_args_simple() -> None:
|
||||||
"""Test static vars."""
|
"""Test static vars."""
|
||||||
orig = {"hello": "world"}
|
orig = {"hello": "world"}
|
||||||
orig_copy = dict(orig)
|
orig_copy = dict(orig)
|
||||||
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
|
var = ScriptVariables(orig)
|
||||||
rendered = var.async_render(
|
rendered = var.async_simple_render({"hello": "override", "run": "var"})
|
||||||
None, {"hello": "override", "run": "var"}, render_as_defaults=False
|
assert rendered is orig
|
||||||
)
|
|
||||||
assert rendered == {"hello": "world", "run": "var"}
|
|
||||||
# Make sure we don't change original vars
|
# Make sure we don't change original vars
|
||||||
assert orig == orig_copy
|
assert orig == orig_copy
|
||||||
|
|
||||||
@ -78,14 +76,14 @@ async def test_template_vars_run_args(hass: HomeAssistant) -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_template_vars_no_default(hass: HomeAssistant) -> None:
|
async def test_template_vars_simple(hass: HomeAssistant) -> None:
|
||||||
"""Test template vars."""
|
"""Test template vars."""
|
||||||
var = cv.SCRIPT_VARIABLES_SCHEMA({"hello": "{{ 1 + 1 }}"})
|
var = cv.SCRIPT_VARIABLES_SCHEMA({"hello": "{{ 1 + 1 }}"})
|
||||||
rendered = var.async_render(hass, None, render_as_defaults=False)
|
rendered = var.async_simple_render({})
|
||||||
assert rendered == {"hello": 2}
|
assert rendered == {"hello": 2}
|
||||||
|
|
||||||
|
|
||||||
async def test_template_vars_run_args_no_default(hass: HomeAssistant) -> None:
|
async def test_template_vars_run_args_simple(hass: HomeAssistant) -> None:
|
||||||
"""Test template vars."""
|
"""Test template vars."""
|
||||||
var = cv.SCRIPT_VARIABLES_SCHEMA(
|
var = cv.SCRIPT_VARIABLES_SCHEMA(
|
||||||
{
|
{
|
||||||
@ -93,16 +91,13 @@ async def test_template_vars_run_args_no_default(hass: HomeAssistant) -> None:
|
|||||||
"something_2": "{{ run_var_ex + 1 }}",
|
"something_2": "{{ run_var_ex + 1 }}",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
rendered = var.async_render(
|
rendered = var.async_simple_render(
|
||||||
hass,
|
|
||||||
{
|
{
|
||||||
"run_var_ex": 5,
|
"run_var_ex": 5,
|
||||||
"something_2": 1,
|
"something_2": 1,
|
||||||
},
|
}
|
||||||
render_as_defaults=False,
|
|
||||||
)
|
)
|
||||||
assert rendered == {
|
assert rendered == {
|
||||||
"run_var_ex": 5,
|
|
||||||
"something": 6,
|
"something": 6,
|
||||||
"something_2": 6,
|
"something_2": 6,
|
||||||
}
|
}
|
||||||
@ -113,3 +108,90 @@ async def test_template_vars_error(hass: HomeAssistant) -> None:
|
|||||||
var = cv.SCRIPT_VARIABLES_SCHEMA({"hello": "{{ canont.work }}"})
|
var = cv.SCRIPT_VARIABLES_SCHEMA({"hello": "{{ canont.work }}"})
|
||||||
with pytest.raises(TemplateError):
|
with pytest.raises(TemplateError):
|
||||||
var.async_render(hass, None)
|
var.async_render(hass, None)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_script_vars_exit_top_level() -> None:
|
||||||
|
"""Test exiting top level script run variables."""
|
||||||
|
script_vars = ScriptRunVariables.create_top_level()
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
script_vars.exit_scope()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_script_vars_delete_var() -> None:
|
||||||
|
"""Test deleting from script run variables."""
|
||||||
|
script_vars = ScriptRunVariables.create_top_level({"x": 1, "y": 2})
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
del script_vars["x"]
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
script_vars.pop("y")
|
||||||
|
assert script_vars._full_scope == {"x": 1, "y": 2}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_script_vars_scopes() -> None:
|
||||||
|
"""Test script run variables scopes."""
|
||||||
|
script_vars = ScriptRunVariables.create_top_level()
|
||||||
|
script_vars["x"] = 1
|
||||||
|
script_vars["y"] = 1
|
||||||
|
assert script_vars["x"] == 1
|
||||||
|
assert script_vars["y"] == 1
|
||||||
|
|
||||||
|
script_vars_2 = script_vars.enter_scope()
|
||||||
|
script_vars_2.define_local("x", 2)
|
||||||
|
assert script_vars_2["x"] == 2
|
||||||
|
assert script_vars_2["y"] == 1
|
||||||
|
|
||||||
|
script_vars_3 = script_vars_2.enter_scope()
|
||||||
|
script_vars_3["x"] = 3
|
||||||
|
script_vars_3["y"] = 3
|
||||||
|
assert script_vars_3["x"] == 3
|
||||||
|
assert script_vars_3["y"] == 3
|
||||||
|
|
||||||
|
script_vars_4 = script_vars_3.enter_scope()
|
||||||
|
assert script_vars_4["x"] == 3
|
||||||
|
assert script_vars_4["y"] == 3
|
||||||
|
|
||||||
|
assert script_vars_4.exit_scope() is script_vars_3
|
||||||
|
|
||||||
|
assert script_vars_3._full_scope == {"x": 3, "y": 3}
|
||||||
|
assert script_vars_3.local_scope == {}
|
||||||
|
|
||||||
|
assert script_vars_3.exit_scope() is script_vars_2
|
||||||
|
|
||||||
|
assert script_vars_2._full_scope == {"x": 3, "y": 3}
|
||||||
|
assert script_vars_2.local_scope == {"x": 3}
|
||||||
|
|
||||||
|
assert script_vars_2.exit_scope() is script_vars
|
||||||
|
|
||||||
|
assert script_vars._full_scope == {"x": 1, "y": 3}
|
||||||
|
assert script_vars.local_scope == {"x": 1, "y": 3}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_script_vars_parallel() -> None:
|
||||||
|
"""Test script run variables parallel support."""
|
||||||
|
script_vars = ScriptRunVariables.create_top_level({"x": 1, "y": 1, "z": 1})
|
||||||
|
|
||||||
|
script_vars_2a = script_vars.enter_scope(parallel=True)
|
||||||
|
script_vars_3a = script_vars_2a.enter_scope()
|
||||||
|
|
||||||
|
script_vars_2b = script_vars.enter_scope(parallel=True)
|
||||||
|
script_vars_3b = script_vars_2b.enter_scope()
|
||||||
|
|
||||||
|
script_vars_3a["x"] = "a"
|
||||||
|
script_vars_3a.assign_parallel_protected("y", "a")
|
||||||
|
|
||||||
|
script_vars_3b["x"] = "b"
|
||||||
|
script_vars_3b.assign_parallel_protected("y", "b")
|
||||||
|
|
||||||
|
assert script_vars_3a._full_scope == {"x": "b", "y": "a", "z": 1}
|
||||||
|
assert script_vars_3a.non_parallel_scope == {"x": "a", "y": "a"}
|
||||||
|
|
||||||
|
assert script_vars_3b._full_scope == {"x": "b", "y": "b", "z": 1}
|
||||||
|
assert script_vars_3b.non_parallel_scope == {"x": "b", "y": "b"}
|
||||||
|
|
||||||
|
assert script_vars_3a.exit_scope() is script_vars_2a
|
||||||
|
assert script_vars_2a.exit_scope() is script_vars
|
||||||
|
assert script_vars_3b.exit_scope() is script_vars_2b
|
||||||
|
assert script_vars_2b.exit_scope() is script_vars
|
||||||
|
|
||||||
|
assert script_vars._full_scope == {"x": "b", "y": 1, "z": 1}
|
||||||
|
assert script_vars.local_scope == {"x": "b", "y": 1, "z": 1}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user