core/homeassistant/helpers/script_variables.py
Artur Pragacz b964bc58be
Fix variable scopes in scripts (#138883)
Co-authored-by: Erik <erik@montnemery.com>
2025-02-26 16:19:19 +01:00

257 lines
9.5 KiB
Python

"""Script variables."""
from __future__ import annotations
from collections import ChainMap, UserDict
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any, cast
from homeassistant.core import HomeAssistant, callback
from . import template
class ScriptVariables:
    """Render a script's configured variables against run-time values.

    Whether the definitions contain any templates is worked out once, on the
    first render, and cached; definitions without templates are never passed
    to the template renderer.
    """

    def __init__(self, variables: dict[str, Any]) -> None:
        """Keep the raw variable definitions.

        :param variables: Variable name to raw value (possibly a template).
        """
        self.variables = variables
        # None until the first render decides whether templates are present.
        self._has_template: bool | None = None

    def _ensure_template_flag(self) -> bool:
        """Return whether the definitions contain templates, checking only once."""
        if self._has_template is None:
            self._has_template = template.is_complex(self.variables)
        return self._has_template

    @callback
    def async_render(
        self,
        hass: HomeAssistant,
        run_variables: Mapping[str, Any] | None,
        *,
        limited: bool = False,
    ) -> dict[str, Any]:
        """Render the variables and merge them with the run variables.

        Every run variable is copied into the returned dict and is never
        overwritten by a definition of the same name. The remaining definitions
        are rendered in order, each one seeing the run variables plus every
        value rendered before it.

        :param hass: Accepted for API compatibility; not used by this method.
        :param run_variables: Values supplied for this run, or None.
        :param limited: Forwarded unchanged to ``template.render_complex``.
        :returns: A new dict holding run variables and rendered definitions.
        """
        if not self._ensure_template_flag():
            # No templates: static definitions, overridden by the run variables.
            merged = dict(self.variables)
            if run_variables is not None:
                merged.update(run_variables)
            return merged

        result: dict[str, Any] = (
            dict(run_variables) if run_variables is not None else {}
        )
        for name, raw_value in self.variables.items():
            # A run variable of the same name wins, so rendering it is skipped.
            if name not in result:
                result[name] = template.render_complex(raw_value, result, limited)
        return result

    @callback
    def async_simple_render(self, run_variables: Mapping[str, Any]) -> dict[str, Any]:
        """Render only the configured variables.

        The run variables serve as rendering context but are not part of the
        result. Definitions are rendered in order, and each rendered value is
        added to the context used for the definitions that follow it.

        Note: when the definitions contain no templates, the stored
        ``self.variables`` dict itself is returned, not a copy.
        """
        if not self._ensure_template_flag():
            return self.variables

        context = dict(run_variables)
        rendered: dict[str, Any] = {}
        for name, raw_value in self.variables.items():
            value = template.render_complex(raw_value, context)
            rendered[name] = context[name] = value
        return rendered

    def as_dict(self) -> dict[str, Any]:
        """Return the raw definitions (the stored dict itself, not a copy)."""
        return self.variables
@dataclass
class _ParallelData:
    """State owned by one parallel sequence of a script run.

    Attributes:
        protected: Variables private to this parallel sequence. A value stored
            here is not clobbered by an assignment to the same name made in a
            sibling parallel sequence, and it is not visible in the outer
            scope. Currently ``wait`` is the only variable treated this way.
        outer_scope_writes: Variables this parallel sequence has written to
            the outer scope. They are tracked per sequence so that traces of
            changed variables can be produced for each sequence in isolation.
    """

    protected: dict[str, Any] = field(default_factory=dict)
    outer_scope_writes: dict[str, Any] = field(default_factory=dict)
@dataclass(kw_only=True)
class ScriptRunVariables(UserDict[str, Any]):
"""Class to hold script run variables.
The purpose of this class is to provide proper variable scoping semantics for scripts.
Each instance institutes a new local scope, in which variables can be defined.
Each instance has a reference to the previous instance, except for the top-level instance.
The instances therefore form a chain, in which variable lookup and assignment is performed.
The variables defined lower in the chain naturally override those defined higher up.
"""
# _previous is the previous ScriptRunVariables in the chain
_previous: ScriptRunVariables | None = None
# _parent is the previous non-empty ScriptRunVariables in the chain
_parent: ScriptRunVariables | None = None
# _local_data is the store for local variables
_local_data: dict[str, Any] | None = None
# _parallel_data is used for each parallel sequence
_parallel_data: _ParallelData | None = None
# _non_parallel_scope includes all scopes all the way to the most recent parallel split
_non_parallel_scope: ChainMap[str, Any]
# _full_scope includes all scopes (all the way to the top-level)
_full_scope: ChainMap[str, Any]
@classmethod
def create_top_level(
cls,
initial_data: Mapping[str, Any] | None = None,
) -> ScriptRunVariables:
"""Create a new top-level ScriptRunVariables."""
local_data: dict[str, Any] = {}
non_parallel_scope = full_scope = ChainMap(local_data)
self = cls(
_local_data=local_data,
_non_parallel_scope=non_parallel_scope,
_full_scope=full_scope,
)
if initial_data is not None:
self.update(initial_data)
return self
def enter_scope(self, *, parallel: bool = False) -> ScriptRunVariables:
"""Return a new child scope.
:param parallel: Whether the new scope starts a parallel sequence.
"""
if self._local_data is not None or self._parallel_data is not None:
parent = self
else:
parent = cast( # top level always has local data, so we can cast safely
ScriptRunVariables, self._parent
)
parallel_data: _ParallelData | None
if not parallel:
parallel_data = None
non_parallel_scope = self._non_parallel_scope
full_scope = self._full_scope
else:
parallel_data = _ParallelData()
non_parallel_scope = ChainMap(
parallel_data.protected, parallel_data.outer_scope_writes
)
full_scope = self._full_scope.new_child(parallel_data.protected)
return ScriptRunVariables(
_previous=self,
_parent=parent,
_parallel_data=parallel_data,
_non_parallel_scope=non_parallel_scope,
_full_scope=full_scope,
)
def exit_scope(self) -> ScriptRunVariables:
"""Exit the current scope.
Does no clean-up, but simply returns the previous scope.
"""
if self._previous is None:
raise ValueError("Cannot exit top-level scope")
return self._previous
def __delitem__(self, key: str) -> None:
"""Delete a variable (disallowed)."""
raise TypeError("Deleting items is not allowed in ScriptRunVariables.")
def __setitem__(self, key: str, value: Any) -> None:
"""Assign value to a variable."""
self._assign(key, value, parallel_protected=False)
def assign_parallel_protected(self, key: str, value: Any) -> None:
"""Assign value to a variable which is to be protected in parallel sequences."""
self._assign(key, value, parallel_protected=True)
def _assign(self, key: str, value: Any, *, parallel_protected: bool) -> None:
"""Assign value to a variable.
Value is always assigned to the variable in the nearest scope, in which it is defined.
If the variable is not defined at all, it is created in the top-level scope.
:param parallel_protected: Whether variable is to be protected in parallel sequences.
"""
if self._local_data is not None and key in self._local_data:
self._local_data[key] = value
return
if self._parent is None:
assert self._local_data is not None # top level always has local data
self._local_data[key] = value
return
if self._parallel_data is not None:
if parallel_protected:
self._parallel_data.protected[key] = value
return
self._parallel_data.protected.pop(key, None)
self._parallel_data.outer_scope_writes[key] = value
self._parent._assign(key, value, parallel_protected=parallel_protected) # noqa: SLF001
def define_local(self, key: str, value: Any) -> None:
"""Define a local variable and assign value to it."""
if self._local_data is None:
self._local_data = {}
self._non_parallel_scope = self._non_parallel_scope.new_child(
self._local_data
)
self._full_scope = self._full_scope.new_child(self._local_data)
self._local_data[key] = value
@property
def data(self) -> Mapping[str, Any]: # type: ignore[override]
"""Return variables in full scope.
Defined here for UserDict compatibility.
"""
return self._full_scope
@property
def non_parallel_scope(self) -> Mapping[str, Any]:
"""Return variables in non-parallel scope."""
return self._non_parallel_scope
@property
def local_scope(self) -> Mapping[str, Any]:
"""Return variables in local scope."""
return self._local_data if self._local_data is not None else {}