mirror of
https://github.com/home-assistant/core.git
synced 2025-11-13 04:50:17 +00:00
Fix variable scopes in scripts (#138883)
Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
@@ -2,8 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import ChainMap, UserDict
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, cast
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
|
||||
@@ -24,30 +26,23 @@ class ScriptVariables:
|
||||
hass: HomeAssistant,
|
||||
run_variables: Mapping[str, Any] | None,
|
||||
*,
|
||||
render_as_defaults: bool = True,
|
||||
limited: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Render script variables.
|
||||
|
||||
The run variables are used to compute the static variables.
|
||||
|
||||
If `render_as_defaults` is True, the run variables will not be overridden.
|
||||
|
||||
The run variables are included in the result.
|
||||
The run variables are used to compute the rendered variable values.
|
||||
The run variables will not be overridden.
|
||||
The rendering happens one at a time, with previous results influencing the next.
|
||||
"""
|
||||
if self._has_template is None:
|
||||
self._has_template = template.is_complex(self.variables)
|
||||
|
||||
if not self._has_template:
|
||||
if render_as_defaults:
|
||||
rendered_variables = dict(self.variables)
|
||||
rendered_variables = dict(self.variables)
|
||||
|
||||
if run_variables is not None:
|
||||
rendered_variables.update(run_variables)
|
||||
else:
|
||||
rendered_variables = (
|
||||
{} if run_variables is None else dict(run_variables)
|
||||
)
|
||||
rendered_variables.update(self.variables)
|
||||
if run_variables is not None:
|
||||
rendered_variables.update(run_variables)
|
||||
|
||||
return rendered_variables
|
||||
|
||||
@@ -56,7 +51,7 @@ class ScriptVariables:
|
||||
for key, value in self.variables.items():
|
||||
# We can skip if we're going to override this key with
|
||||
# run variables anyway
|
||||
if render_as_defaults and key in rendered_variables:
|
||||
if key in rendered_variables:
|
||||
continue
|
||||
|
||||
rendered_variables[key] = template.render_complex(
|
||||
@@ -65,6 +60,197 @@ class ScriptVariables:
|
||||
|
||||
return rendered_variables
|
||||
|
||||
@callback
|
||||
def async_simple_render(self, run_variables: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"""Render script variables.
|
||||
|
||||
Simply renders the variables, the run variables are not included in the result.
|
||||
The run variables are used to compute the rendered variable values.
|
||||
The rendering happens one at a time, with previous results influencing the next.
|
||||
"""
|
||||
if self._has_template is None:
|
||||
self._has_template = template.is_complex(self.variables)
|
||||
|
||||
if not self._has_template:
|
||||
return self.variables
|
||||
|
||||
run_variables = dict(run_variables)
|
||||
rendered_variables = {}
|
||||
|
||||
for key, value in self.variables.items():
|
||||
rendered_variable = template.render_complex(value, run_variables)
|
||||
rendered_variables[key] = rendered_variable
|
||||
run_variables[key] = rendered_variable
|
||||
|
||||
return rendered_variables
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return dict version of this class."""
|
||||
return self.variables
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ParallelData:
|
||||
"""Data used in each parallel sequence."""
|
||||
|
||||
# `protected` is for variables that need special protection in parallel sequences.
|
||||
# What this means is that such a variable defined in one parallel sequence will not be
|
||||
# clobbered by the variable with the same name assigned in another parallel sequence.
|
||||
# It also means that such a variable will not be visible in the outer scope.
|
||||
# Currently the only such variable is `wait`.
|
||||
protected: dict[str, Any] = field(default_factory=dict)
|
||||
# `outer_scope_writes` is for variables that are written to the outer scope from
|
||||
# a parallel sequence. This is used for generating correct traces of changed variables
|
||||
# for each of the parallel sequences, isolating them from one another.
|
||||
outer_scope_writes: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ScriptRunVariables(UserDict[str, Any]):
|
||||
"""Class to hold script run variables.
|
||||
|
||||
The purpose of this class is to provide proper variable scoping semantics for scripts.
|
||||
Each instance institutes a new local scope, in which variables can be defined.
|
||||
Each instance has a reference to the previous instance, except for the top-level instance.
|
||||
The instances therefore form a chain, in which variable lookup and assignment is performed.
|
||||
The variables defined lower in the chain naturally override those defined higher up.
|
||||
"""
|
||||
|
||||
# _previous is the previous ScriptRunVariables in the chain
|
||||
_previous: ScriptRunVariables | None = None
|
||||
# _parent is the previous non-empty ScriptRunVariables in the chain
|
||||
_parent: ScriptRunVariables | None = None
|
||||
|
||||
# _local_data is the store for local variables
|
||||
_local_data: dict[str, Any] | None = None
|
||||
# _parallel_data is used for each parallel sequence
|
||||
_parallel_data: _ParallelData | None = None
|
||||
|
||||
# _non_parallel_scope includes all scopes all the way to the most recent parallel split
|
||||
_non_parallel_scope: ChainMap[str, Any]
|
||||
# _full_scope includes all scopes (all the way to the top-level)
|
||||
_full_scope: ChainMap[str, Any]
|
||||
|
||||
@classmethod
|
||||
def create_top_level(
|
||||
cls,
|
||||
initial_data: Mapping[str, Any] | None = None,
|
||||
) -> ScriptRunVariables:
|
||||
"""Create a new top-level ScriptRunVariables."""
|
||||
local_data: dict[str, Any] = {}
|
||||
non_parallel_scope = full_scope = ChainMap(local_data)
|
||||
self = cls(
|
||||
_local_data=local_data,
|
||||
_non_parallel_scope=non_parallel_scope,
|
||||
_full_scope=full_scope,
|
||||
)
|
||||
if initial_data is not None:
|
||||
self.update(initial_data)
|
||||
return self
|
||||
|
||||
def enter_scope(self, *, parallel: bool = False) -> ScriptRunVariables:
|
||||
"""Return a new child scope.
|
||||
|
||||
:param parallel: Whether the new scope starts a parallel sequence.
|
||||
"""
|
||||
if self._local_data is not None or self._parallel_data is not None:
|
||||
parent = self
|
||||
else:
|
||||
parent = cast( # top level always has local data, so we can cast safely
|
||||
ScriptRunVariables, self._parent
|
||||
)
|
||||
|
||||
parallel_data: _ParallelData | None
|
||||
if not parallel:
|
||||
parallel_data = None
|
||||
non_parallel_scope = self._non_parallel_scope
|
||||
full_scope = self._full_scope
|
||||
else:
|
||||
parallel_data = _ParallelData()
|
||||
non_parallel_scope = ChainMap(
|
||||
parallel_data.protected, parallel_data.outer_scope_writes
|
||||
)
|
||||
full_scope = self._full_scope.new_child(parallel_data.protected)
|
||||
|
||||
return ScriptRunVariables(
|
||||
_previous=self,
|
||||
_parent=parent,
|
||||
_parallel_data=parallel_data,
|
||||
_non_parallel_scope=non_parallel_scope,
|
||||
_full_scope=full_scope,
|
||||
)
|
||||
|
||||
def exit_scope(self) -> ScriptRunVariables:
|
||||
"""Exit the current scope.
|
||||
|
||||
Does no clean-up, but simply returns the previous scope.
|
||||
"""
|
||||
if self._previous is None:
|
||||
raise ValueError("Cannot exit top-level scope")
|
||||
return self._previous
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
"""Delete a variable (disallowed)."""
|
||||
raise TypeError("Deleting items is not allowed in ScriptRunVariables.")
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
"""Assign value to a variable."""
|
||||
self._assign(key, value, parallel_protected=False)
|
||||
|
||||
def assign_parallel_protected(self, key: str, value: Any) -> None:
|
||||
"""Assign value to a variable which is to be protected in parallel sequences."""
|
||||
self._assign(key, value, parallel_protected=True)
|
||||
|
||||
def _assign(self, key: str, value: Any, *, parallel_protected: bool) -> None:
|
||||
"""Assign value to a variable.
|
||||
|
||||
Value is always assigned to the variable in the nearest scope, in which it is defined.
|
||||
If the variable is not defined at all, it is created in the top-level scope.
|
||||
|
||||
:param parallel_protected: Whether variable is to be protected in parallel sequences.
|
||||
"""
|
||||
if self._local_data is not None and key in self._local_data:
|
||||
self._local_data[key] = value
|
||||
return
|
||||
|
||||
if self._parent is None:
|
||||
assert self._local_data is not None # top level always has local data
|
||||
self._local_data[key] = value
|
||||
return
|
||||
|
||||
if self._parallel_data is not None:
|
||||
if parallel_protected:
|
||||
self._parallel_data.protected[key] = value
|
||||
return
|
||||
self._parallel_data.protected.pop(key, None)
|
||||
self._parallel_data.outer_scope_writes[key] = value
|
||||
|
||||
self._parent._assign(key, value, parallel_protected=parallel_protected) # noqa: SLF001
|
||||
|
||||
def define_local(self, key: str, value: Any) -> None:
|
||||
"""Define a local variable and assign value to it."""
|
||||
if self._local_data is None:
|
||||
self._local_data = {}
|
||||
self._non_parallel_scope = self._non_parallel_scope.new_child(
|
||||
self._local_data
|
||||
)
|
||||
self._full_scope = self._full_scope.new_child(self._local_data)
|
||||
self._local_data[key] = value
|
||||
|
||||
@property
|
||||
def data(self) -> Mapping[str, Any]: # type: ignore[override]
|
||||
"""Return variables in full scope.
|
||||
|
||||
Defined here for UserDict compatibility.
|
||||
"""
|
||||
return self._full_scope
|
||||
|
||||
@property
|
||||
def non_parallel_scope(self) -> Mapping[str, Any]:
|
||||
"""Return variables in non-parallel scope."""
|
||||
return self._non_parallel_scope
|
||||
|
||||
@property
|
||||
def local_scope(self) -> Mapping[str, Any]:
|
||||
"""Return variables in local scope."""
|
||||
return self._local_data if self._local_data is not None else {}
|
||||
|
||||
Reference in New Issue
Block a user