Fix variable scopes in scripts (#138883)

Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
Artur Pragacz 2025-02-26 16:19:19 +01:00 committed by GitHub
parent bd80a78848
commit b964bc58be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 504 additions and 87 deletions

View File

@ -12,7 +12,6 @@ from datetime import datetime, timedelta
from functools import partial
import itertools
import logging
from types import MappingProxyType
from typing import Any, Literal, TypedDict, cast, overload
import async_interrupt
@ -90,7 +89,7 @@ from . import condition, config_validation as cv, service, template
from .condition import ConditionCheckerType, trace_condition_function
from .dispatcher import async_dispatcher_connect, async_dispatcher_send_internal
from .event import async_call_later, async_track_template
from .script_variables import ScriptVariables
from .script_variables import ScriptRunVariables, ScriptVariables
from .template import Template
from .trace import (
TraceElement,
@ -177,7 +176,7 @@ def _set_result_unless_done(future: asyncio.Future[None]) -> None:
future.set_result(None)
def action_trace_append(variables: dict[str, Any], path: str) -> TraceElement:
def action_trace_append(variables: TemplateVarsType, path: str) -> TraceElement:
"""Append a TraceElement to trace[path]."""
trace_element = TraceElement(variables, path)
trace_append_element(trace_element, ACTION_TRACE_NODE_MAX_LEN)
@ -189,7 +188,7 @@ async def trace_action(
hass: HomeAssistant,
script_run: _ScriptRun,
stop: asyncio.Future[None],
variables: dict[str, Any],
variables: TemplateVarsType,
) -> AsyncGenerator[TraceElement]:
"""Trace action execution."""
path = trace_path_get()
@ -411,7 +410,7 @@ class _ScriptRun:
self,
hass: HomeAssistant,
script: Script,
variables: dict[str, Any],
variables: ScriptRunVariables,
context: Context | None,
log_exceptions: bool,
) -> None:
@ -485,14 +484,16 @@ class _ScriptRun:
script_stack.pop()
self._finish()
return ScriptRunResult(self._conversation_response, response, self._variables)
return ScriptRunResult(
self._conversation_response, response, self._variables.local_scope
)
async def _async_step(self, log_exceptions: bool) -> None:
continue_on_error = self._action.get(CONF_CONTINUE_ON_ERROR, False)
with trace_path(str(self._step)):
async with trace_action(
self._hass, self, self._stop, self._variables
self._hass, self, self._stop, self._variables.non_parallel_scope
) as trace_element:
if self._stop.done():
return
@ -526,7 +527,7 @@ class _ScriptRun:
ex, continue_on_error, self._log_exceptions or log_exceptions
)
finally:
trace_element.update_variables(self._variables)
trace_element.update_variables(self._variables.non_parallel_scope)
def _finish(self) -> None:
self._script._runs.remove(self) # noqa: SLF001
@ -624,11 +625,16 @@ class _ScriptRun:
except ScriptStoppedError as ex:
raise asyncio.CancelledError from ex
async def _async_run_script(self, script: Script) -> None:
async def _async_run_script(
self, script: Script, *, parallel: bool = False
) -> None:
"""Execute a script."""
result = await self._async_run_long_action(
self._hass.async_create_task_internal(
script.async_run(self._variables, self._context), eager_start=True
script.async_run(
self._variables.enter_scope(parallel=parallel), self._context
),
eager_start=True,
)
)
if result and result.conversation_response is not UNDEFINED:
@ -647,7 +653,7 @@ class _ScriptRun:
"""Run a script with a trace path."""
trace_path_stack_cv.set(copy(trace_path_stack_cv.get()))
with trace_path([str(idx), "sequence"]):
await self._async_run_script(script)
await self._async_run_script(script, parallel=True)
results = await asyncio.gather(
*(async_run_with_trace(idx, script) for idx, script in enumerate(scripts)),
@ -760,14 +766,11 @@ class _ScriptRun:
with trace_path("else"):
await self._async_run_script(if_data["if_else"])
@async_trace_path("repeat")
async def _async_step_repeat(self) -> None: # noqa: C901
"""Repeat a sequence."""
async def _async_do_step_repeat(self) -> None: # noqa: C901
"""Repeat a sequence helper."""
description = self._action.get(CONF_ALIAS, "sequence")
repeat = self._action[CONF_REPEAT]
saved_repeat_vars = self._variables.get("repeat")
def set_repeat_var(
iteration: int, count: int | None = None, item: Any = None
) -> None:
@ -776,7 +779,7 @@ class _ScriptRun:
repeat_vars["last"] = iteration == count
if item is not None:
repeat_vars["item"] = item
self._variables["repeat"] = repeat_vars
self._variables.define_local("repeat", repeat_vars)
script = self._script._get_repeat_script(self._step) # noqa: SLF001
warned_too_many_loops = False
@ -927,10 +930,14 @@ class _ScriptRun:
# while all the cpu time is consumed.
await asyncio.sleep(0)
if saved_repeat_vars:
self._variables["repeat"] = saved_repeat_vars
else:
self._variables.pop("repeat", None) # Not set if count = 0
@async_trace_path("repeat")
async def _async_step_repeat(self) -> None:
"""Repeat a sequence."""
self._variables = self._variables.enter_scope()
try:
await self._async_do_step_repeat()
finally:
self._variables = self._variables.exit_scope()
### Stop actions ###
@ -959,11 +966,12 @@ class _ScriptRun:
## Variable actions ##
async def _async_step_variables(self) -> None:
"""Set a variable value."""
self._step_log("setting variables")
self._variables = self._action[CONF_VARIABLES].async_render(
self._hass, self._variables, render_as_defaults=False
)
"""Define a local variable."""
self._step_log("defining local variables")
for key, value in (
self._action[CONF_VARIABLES].async_simple_render(self._variables).items()
):
self._variables.define_local(key, value)
## External actions ##
@ -1016,7 +1024,7 @@ class _ScriptRun:
"""Perform the device automation specified in the action."""
self._step_log("device automation")
await device_action.async_call_action_from_config(
self._hass, self._action, self._variables, self._context
self._hass, self._action, dict(self._variables), self._context
)
async def _async_step_event(self) -> None:
@ -1189,12 +1197,15 @@ class _ScriptRun:
self._step_log("wait for trigger", timeout)
variables = {**self._variables}
self._variables["wait"] = {
"remaining": timeout,
"completed": False,
"trigger": None,
}
variables = dict(self._variables)
self._variables.assign_parallel_protected(
"wait",
{
"remaining": timeout,
"completed": False,
"trigger": None,
},
)
trace_set_result(wait=self._variables["wait"])
if timeout == 0:
@ -1240,7 +1251,9 @@ class _ScriptRun:
timeout = self._get_timeout_seconds_from_action()
self._step_log("wait template", timeout)
self._variables["wait"] = {"remaining": timeout, "completed": False}
self._variables.assign_parallel_protected(
"wait", {"remaining": timeout, "completed": False}
)
trace_set_result(wait=self._variables["wait"])
wait_template = self._action[CONF_WAIT_TEMPLATE]
@ -1369,7 +1382,7 @@ async def _async_stop_scripts_at_shutdown(hass: HomeAssistant, event: Event) ->
)
type _VarsType = dict[str, Any] | Mapping[str, Any] | MappingProxyType[str, Any]
type _VarsType = dict[str, Any] | Mapping[str, Any] | ScriptRunVariables
def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None:
@ -1407,7 +1420,7 @@ class ScriptRunResult:
conversation_response: str | None | UndefinedType
service_response: ServiceResponse
variables: dict[str, Any]
variables: Mapping[str, Any]
class Script:
@ -1422,7 +1435,6 @@ class Script:
*,
# Used in "Running <running_description>" log message
change_listener: Callable[[], Any] | None = None,
copy_variables: bool = False,
log_exceptions: bool = True,
logger: logging.Logger | None = None,
max_exceeded: str = DEFAULT_MAX_EXCEEDED,
@ -1476,8 +1488,6 @@ class Script:
self._parallel_scripts: dict[int, list[Script]] = {}
self._sequence_scripts: dict[int, Script] = {}
self.variables = variables
self._variables_dynamic = template.is_complex(variables)
self._copy_variables_on_run = copy_variables
@property
def change_listener(self) -> Callable[..., Any] | None:
@ -1755,25 +1765,19 @@ class Script:
if self.top_level:
if self.variables:
try:
variables = self.variables.async_render(
run_variables = self.variables.async_render(
self._hass,
run_variables,
)
except exceptions.TemplateError as err:
self._log("Error rendering variables: %s", err, level=logging.ERROR)
raise
elif run_variables:
variables = dict(run_variables)
else:
variables = {}
variables = ScriptRunVariables.create_top_level(run_variables)
variables["context"] = context
elif self._copy_variables_on_run:
# This is not the top level script, variables have been turned to a dict
variables = cast(dict[str, Any], copy(run_variables))
else:
# This is not the top level script, variables have been turned to a dict
variables = cast(dict[str, Any], run_variables)
# This is not the top level script, run_variables is an instance of ScriptRunVariables
variables = cast(ScriptRunVariables, run_variables)
# Prevent non-allowed recursive calls which will cause deadlocks when we try to
# stop (restart) or wait for (queued) our own script run.
@ -1999,7 +2003,6 @@ class Script:
max_runs=self.max_runs,
logger=self._logger,
top_level=False,
copy_variables=True,
)
parallel_script.change_listener = partial(
self._chain_change_listener, parallel_script

View File

@ -2,8 +2,10 @@
from __future__ import annotations
from collections import ChainMap, UserDict
from collections.abc import Mapping
from typing import Any
from dataclasses import dataclass, field
from typing import Any, cast
from homeassistant.core import HomeAssistant, callback
@ -24,30 +26,23 @@ class ScriptVariables:
hass: HomeAssistant,
run_variables: Mapping[str, Any] | None,
*,
render_as_defaults: bool = True,
limited: bool = False,
) -> dict[str, Any]:
"""Render script variables.
The run variables are used to compute the static variables.
If `render_as_defaults` is True, the run variables will not be overridden.
The run variables are included in the result.
The run variables are used to compute the rendered variable values.
The run variables will not be overridden.
The rendering happens one at a time, with previous results influencing the next.
"""
if self._has_template is None:
self._has_template = template.is_complex(self.variables)
if not self._has_template:
if render_as_defaults:
rendered_variables = dict(self.variables)
rendered_variables = dict(self.variables)
if run_variables is not None:
rendered_variables.update(run_variables)
else:
rendered_variables = (
{} if run_variables is None else dict(run_variables)
)
rendered_variables.update(self.variables)
if run_variables is not None:
rendered_variables.update(run_variables)
return rendered_variables
@ -56,7 +51,7 @@ class ScriptVariables:
for key, value in self.variables.items():
# We can skip if we're going to override this key with
# run variables anyway
if render_as_defaults and key in rendered_variables:
if key in rendered_variables:
continue
rendered_variables[key] = template.render_complex(
@ -65,6 +60,197 @@ class ScriptVariables:
return rendered_variables
@callback
def async_simple_render(self, run_variables: Mapping[str, Any]) -> dict[str, Any]:
"""Render script variables.
Simply renders the variables, the run variables are not included in the result.
The run variables are used to compute the rendered variable values.
The rendering happens one at a time, with previous results influencing the next.
"""
if self._has_template is None:
self._has_template = template.is_complex(self.variables)
if not self._has_template:
return self.variables
run_variables = dict(run_variables)
rendered_variables = {}
for key, value in self.variables.items():
rendered_variable = template.render_complex(value, run_variables)
rendered_variables[key] = rendered_variable
run_variables[key] = rendered_variable
return rendered_variables
def as_dict(self) -> dict[str, Any]:
"""Return dict version of this class."""
return self.variables
@dataclass
class _ParallelData:
"""Data used in each parallel sequence."""
# `protected` is for variables that need special protection in parallel sequences.
# What this means is that such a variable defined in one parallel sequence will not be
# clobbered by the variable with the same name assigned in another parallel sequence.
# It also means that such a variable will not be visible in the outer scope.
# Currently the only such variable is `wait`.
protected: dict[str, Any] = field(default_factory=dict)
# `outer_scope_writes` is for variables that are written to the outer scope from
# a parallel sequence. This is used for generating correct traces of changed variables
# for each of the parallel sequences, isolating them from one another.
outer_scope_writes: dict[str, Any] = field(default_factory=dict)
@dataclass(kw_only=True)
class ScriptRunVariables(UserDict[str, Any]):
"""Class to hold script run variables.
The purpose of this class is to provide proper variable scoping semantics for scripts.
Each instance institutes a new local scope, in which variables can be defined.
Each instance has a reference to the previous instance, except for the top-level instance.
The instances therefore form a chain, in which variable lookup and assignment is performed.
The variables defined lower in the chain naturally override those defined higher up.
"""
# _previous is the previous ScriptRunVariables in the chain
_previous: ScriptRunVariables | None = None
# _parent is the previous non-empty ScriptRunVariables in the chain
_parent: ScriptRunVariables | None = None
# _local_data is the store for local variables
_local_data: dict[str, Any] | None = None
# _parallel_data is used for each parallel sequence
_parallel_data: _ParallelData | None = None
# _non_parallel_scope includes all scopes all the way to the most recent parallel split
_non_parallel_scope: ChainMap[str, Any]
# _full_scope includes all scopes (all the way to the top-level)
_full_scope: ChainMap[str, Any]
@classmethod
def create_top_level(
cls,
initial_data: Mapping[str, Any] | None = None,
) -> ScriptRunVariables:
"""Create a new top-level ScriptRunVariables."""
local_data: dict[str, Any] = {}
non_parallel_scope = full_scope = ChainMap(local_data)
self = cls(
_local_data=local_data,
_non_parallel_scope=non_parallel_scope,
_full_scope=full_scope,
)
if initial_data is not None:
self.update(initial_data)
return self
def enter_scope(self, *, parallel: bool = False) -> ScriptRunVariables:
"""Return a new child scope.
:param parallel: Whether the new scope starts a parallel sequence.
"""
if self._local_data is not None or self._parallel_data is not None:
parent = self
else:
parent = cast( # top level always has local data, so we can cast safely
ScriptRunVariables, self._parent
)
parallel_data: _ParallelData | None
if not parallel:
parallel_data = None
non_parallel_scope = self._non_parallel_scope
full_scope = self._full_scope
else:
parallel_data = _ParallelData()
non_parallel_scope = ChainMap(
parallel_data.protected, parallel_data.outer_scope_writes
)
full_scope = self._full_scope.new_child(parallel_data.protected)
return ScriptRunVariables(
_previous=self,
_parent=parent,
_parallel_data=parallel_data,
_non_parallel_scope=non_parallel_scope,
_full_scope=full_scope,
)
def exit_scope(self) -> ScriptRunVariables:
"""Exit the current scope.
Does no clean-up, but simply returns the previous scope.
"""
if self._previous is None:
raise ValueError("Cannot exit top-level scope")
return self._previous
def __delitem__(self, key: str) -> None:
"""Delete a variable (disallowed)."""
raise TypeError("Deleting items is not allowed in ScriptRunVariables.")
def __setitem__(self, key: str, value: Any) -> None:
"""Assign value to a variable."""
self._assign(key, value, parallel_protected=False)
def assign_parallel_protected(self, key: str, value: Any) -> None:
"""Assign value to a variable which is to be protected in parallel sequences."""
self._assign(key, value, parallel_protected=True)
def _assign(self, key: str, value: Any, *, parallel_protected: bool) -> None:
"""Assign value to a variable.
Value is always assigned to the variable in the nearest scope, in which it is defined.
If the variable is not defined at all, it is created in the top-level scope.
:param parallel_protected: Whether variable is to be protected in parallel sequences.
"""
if self._local_data is not None and key in self._local_data:
self._local_data[key] = value
return
if self._parent is None:
assert self._local_data is not None # top level always has local data
self._local_data[key] = value
return
if self._parallel_data is not None:
if parallel_protected:
self._parallel_data.protected[key] = value
return
self._parallel_data.protected.pop(key, None)
self._parallel_data.outer_scope_writes[key] = value
self._parent._assign(key, value, parallel_protected=parallel_protected) # noqa: SLF001
def define_local(self, key: str, value: Any) -> None:
"""Define a local variable and assign value to it."""
if self._local_data is None:
self._local_data = {}
self._non_parallel_scope = self._non_parallel_scope.new_child(
self._local_data
)
self._full_scope = self._full_scope.new_child(self._local_data)
self._local_data[key] = value
@property
def data(self) -> Mapping[str, Any]: # type: ignore[override]
"""Return variables in full scope.
Defined here for UserDict compatibility.
"""
return self._full_scope
@property
def non_parallel_scope(self) -> Mapping[str, Any]:
"""Return variables in non-parallel scope."""
return self._non_parallel_scope
@property
def local_scope(self) -> Mapping[str, Any]:
"""Return variables in local scope."""
return self._local_data if self._local_data is not None else {}

View File

@ -452,6 +452,68 @@ async def test_service_response_data_errors(
await script_obj.async_run(context=context)
async def test_calling_service_response_data_in_scopes(hass: HomeAssistant) -> None:
"""Test response variable is still set after scopes end."""
expected_var = {"data": "value-12345"}
def mock_service(call: ServiceCall) -> ServiceResponse:
"""Mock service call."""
if call.return_response:
return expected_var
return None
hass.services.async_register(
"test", "script", mock_service, supports_response=SupportsResponse.OPTIONAL
)
sequence = cv.SCRIPT_SCHEMA(
{
"parallel": [
{
"alias": "Sequential group",
"sequence": [
{
"alias": "variables",
"variables": {"state": "off"},
},
{
"alias": "service step1",
"action": "test.script",
"response_variable": "my_response",
},
],
}
]
}
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
result = await script_obj.async_run(context=Context())
assert result.variables["my_response"] == expected_var
expected_trace = {
"0": [{"variables": {"my_response": expected_var}}],
"0/parallel/0/sequence/0": [{"variables": {"state": "off"}}],
"0/parallel/0/sequence/1": [
{
"result": {
"params": {
"domain": "test",
"service": "script",
"service_data": {},
"target": {},
},
"running_script": False,
},
"variables": {"my_response": expected_var},
}
],
}
assert_action_trace(expected_trace)
async def test_data_template_with_templated_key(hass: HomeAssistant) -> None:
"""Test the calling of a service with a data_template with a templated key."""
context = Context()
@ -1706,6 +1768,90 @@ async def test_wait_variables_out(hass: HomeAssistant, mode, action_type) -> Non
assert float(remaining) == 0.0
async def test_wait_in_sequence(hass: HomeAssistant) -> None:
"""Test wait variable is still set after sequence ends."""
sequence = cv.SCRIPT_SCHEMA(
[
{
"alias": "Sequential group",
"sequence": [
{
"alias": "variables",
"variables": {"state": "off"},
},
{
"alias": "wait template",
"wait_template": "{{ state == 'off' }}",
},
],
},
]
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
result = await script_obj.async_run(context=Context())
expected_var = {"completed": True, "remaining": None}
assert result.variables["wait"] == expected_var
expected_trace = {
"0": [{"variables": {"wait": expected_var}}],
"0/sequence/0": [{"variables": {"state": "off"}}],
"0/sequence/1": [
{
"result": {"wait": expected_var},
"variables": {"wait": expected_var},
}
],
}
assert_action_trace(expected_trace)
async def test_wait_in_parallel(hass: HomeAssistant) -> None:
"""Test wait variable is not set after parallel ends."""
sequence = cv.SCRIPT_SCHEMA(
{
"parallel": [
{
"alias": "Sequential group",
"sequence": [
{
"alias": "variables",
"variables": {"state": "off"},
},
{
"alias": "wait template",
"wait_template": "{{ state == 'off' }}",
},
],
}
]
}
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
result = await script_obj.async_run(context=Context())
expected_var = {"completed": True, "remaining": None}
assert "wait" not in result.variables
expected_trace = {
"0": [{}],
"0/parallel/0/sequence/0": [{"variables": {"state": "off"}}],
"0/parallel/0/sequence/1": [
{
"result": {"wait": expected_var},
"variables": {"wait": expected_var},
}
],
}
assert_action_trace(expected_trace)
async def test_wait_for_trigger_bad(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:

View File

@ -5,12 +5,13 @@ import pytest
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.script_variables import ScriptRunVariables, ScriptVariables
async def test_static_vars() -> None:
"""Test static vars."""
orig = {"hello": "world"}
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
var = ScriptVariables(orig)
rendered = var.async_render(None, None)
assert rendered is not orig
assert rendered == orig
@ -20,31 +21,28 @@ async def test_static_vars_run_args() -> None:
"""Test static vars."""
orig = {"hello": "world"}
orig_copy = dict(orig)
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
var = ScriptVariables(orig)
rendered = var.async_render(None, {"hello": "override", "run": "var"})
assert rendered == {"hello": "override", "run": "var"}
# Make sure we don't change original vars
assert orig == orig_copy
async def test_static_vars_no_default() -> None:
async def test_static_vars_simple() -> None:
"""Test static vars."""
orig = {"hello": "world"}
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
rendered = var.async_render(None, None, render_as_defaults=False)
assert rendered is not orig
assert rendered == orig
var = ScriptVariables(orig)
rendered = var.async_simple_render({})
assert rendered is orig
async def test_static_vars_run_args_no_default() -> None:
async def test_static_vars_run_args_simple() -> None:
"""Test static vars."""
orig = {"hello": "world"}
orig_copy = dict(orig)
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
rendered = var.async_render(
None, {"hello": "override", "run": "var"}, render_as_defaults=False
)
assert rendered == {"hello": "world", "run": "var"}
var = ScriptVariables(orig)
rendered = var.async_simple_render({"hello": "override", "run": "var"})
assert rendered is orig
# Make sure we don't change original vars
assert orig == orig_copy
@ -78,14 +76,14 @@ async def test_template_vars_run_args(hass: HomeAssistant) -> None:
}
async def test_template_vars_no_default(hass: HomeAssistant) -> None:
async def test_template_vars_simple(hass: HomeAssistant) -> None:
"""Test template vars."""
var = cv.SCRIPT_VARIABLES_SCHEMA({"hello": "{{ 1 + 1 }}"})
rendered = var.async_render(hass, None, render_as_defaults=False)
rendered = var.async_simple_render({})
assert rendered == {"hello": 2}
async def test_template_vars_run_args_no_default(hass: HomeAssistant) -> None:
async def test_template_vars_run_args_simple(hass: HomeAssistant) -> None:
"""Test template vars."""
var = cv.SCRIPT_VARIABLES_SCHEMA(
{
@ -93,16 +91,13 @@ async def test_template_vars_run_args_no_default(hass: HomeAssistant) -> None:
"something_2": "{{ run_var_ex + 1 }}",
}
)
rendered = var.async_render(
hass,
rendered = var.async_simple_render(
{
"run_var_ex": 5,
"something_2": 1,
},
render_as_defaults=False,
}
)
assert rendered == {
"run_var_ex": 5,
"something": 6,
"something_2": 6,
}
@ -113,3 +108,90 @@ async def test_template_vars_error(hass: HomeAssistant) -> None:
var = cv.SCRIPT_VARIABLES_SCHEMA({"hello": "{{ canont.work }}"})
with pytest.raises(TemplateError):
var.async_render(hass, None)
async def test_script_vars_exit_top_level() -> None:
"""Test exiting top level script run variables."""
script_vars = ScriptRunVariables.create_top_level()
with pytest.raises(ValueError):
script_vars.exit_scope()
async def test_script_vars_delete_var() -> None:
"""Test deleting from script run variables."""
script_vars = ScriptRunVariables.create_top_level({"x": 1, "y": 2})
with pytest.raises(TypeError):
del script_vars["x"]
with pytest.raises(TypeError):
script_vars.pop("y")
assert script_vars._full_scope == {"x": 1, "y": 2}
async def test_script_vars_scopes() -> None:
"""Test script run variables scopes."""
script_vars = ScriptRunVariables.create_top_level()
script_vars["x"] = 1
script_vars["y"] = 1
assert script_vars["x"] == 1
assert script_vars["y"] == 1
script_vars_2 = script_vars.enter_scope()
script_vars_2.define_local("x", 2)
assert script_vars_2["x"] == 2
assert script_vars_2["y"] == 1
script_vars_3 = script_vars_2.enter_scope()
script_vars_3["x"] = 3
script_vars_3["y"] = 3
assert script_vars_3["x"] == 3
assert script_vars_3["y"] == 3
script_vars_4 = script_vars_3.enter_scope()
assert script_vars_4["x"] == 3
assert script_vars_4["y"] == 3
assert script_vars_4.exit_scope() is script_vars_3
assert script_vars_3._full_scope == {"x": 3, "y": 3}
assert script_vars_3.local_scope == {}
assert script_vars_3.exit_scope() is script_vars_2
assert script_vars_2._full_scope == {"x": 3, "y": 3}
assert script_vars_2.local_scope == {"x": 3}
assert script_vars_2.exit_scope() is script_vars
assert script_vars._full_scope == {"x": 1, "y": 3}
assert script_vars.local_scope == {"x": 1, "y": 3}
async def test_script_vars_parallel() -> None:
"""Test script run variables parallel support."""
script_vars = ScriptRunVariables.create_top_level({"x": 1, "y": 1, "z": 1})
script_vars_2a = script_vars.enter_scope(parallel=True)
script_vars_3a = script_vars_2a.enter_scope()
script_vars_2b = script_vars.enter_scope(parallel=True)
script_vars_3b = script_vars_2b.enter_scope()
script_vars_3a["x"] = "a"
script_vars_3a.assign_parallel_protected("y", "a")
script_vars_3b["x"] = "b"
script_vars_3b.assign_parallel_protected("y", "b")
assert script_vars_3a._full_scope == {"x": "b", "y": "a", "z": 1}
assert script_vars_3a.non_parallel_scope == {"x": "a", "y": "a"}
assert script_vars_3b._full_scope == {"x": "b", "y": "b", "z": 1}
assert script_vars_3b.non_parallel_scope == {"x": "b", "y": "b"}
assert script_vars_3a.exit_scope() is script_vars_2a
assert script_vars_2a.exit_scope() is script_vars
assert script_vars_3b.exit_scope() is script_vars_2b
assert script_vars_2b.exit_scope() is script_vars
assert script_vars._full_scope == {"x": "b", "y": 1, "z": 1}
assert script_vars.local_scope == {"x": "b", "y": 1, "z": 1}