"""Script variables.""" from __future__ import annotations from collections import ChainMap, UserDict from collections.abc import Mapping from dataclasses import dataclass, field from typing import Any, cast from homeassistant.core import HomeAssistant, callback from . import template class ScriptVariables: """Class to hold and render script variables.""" def __init__(self, variables: dict[str, Any]) -> None: """Initialize script variables.""" self.variables = variables self._has_template: bool | None = None @callback def async_render( self, hass: HomeAssistant, run_variables: Mapping[str, Any] | None, *, limited: bool = False, ) -> dict[str, Any]: """Render script variables. The run variables are included in the result. The run variables are used to compute the rendered variable values. The run variables will not be overridden. The rendering happens one at a time, with previous results influencing the next. """ if self._has_template is None: self._has_template = template.is_complex(self.variables) if not self._has_template: rendered_variables = dict(self.variables) if run_variables is not None: rendered_variables.update(run_variables) return rendered_variables rendered_variables = {} if run_variables is None else dict(run_variables) for key, value in self.variables.items(): # We can skip if we're going to override this key with # run variables anyway if key in rendered_variables: continue rendered_variables[key] = template.render_complex( value, rendered_variables, limited ) return rendered_variables @callback def async_simple_render(self, run_variables: Mapping[str, Any]) -> dict[str, Any]: """Render script variables. Simply renders the variables, the run variables are not included in the result. The run variables are used to compute the rendered variable values. The rendering happens one at a time, with previous results influencing the next. """ if self._has_template is None: self._has_template = template.is_complex(self.variables) if not self._has_template: return self.variables run_variables = dict(run_variables) rendered_variables = {} for key, value in self.variables.items(): rendered_variable = template.render_complex(value, run_variables) rendered_variables[key] = rendered_variable run_variables[key] = rendered_variable return rendered_variables def as_dict(self) -> dict[str, Any]: """Return dict version of this class.""" return self.variables @dataclass class _ParallelData: """Data used in each parallel sequence.""" # `protected` is for variables that need special protection in parallel sequences. # What this means is that such a variable defined in one parallel sequence will not be # clobbered by the variable with the same name assigned in another parallel sequence. # It also means that such a variable will not be visible in the outer scope. # Currently the only such variable is `wait`. protected: dict[str, Any] = field(default_factory=dict) # `outer_scope_writes` is for variables that are written to the outer scope from # a parallel sequence. This is used for generating correct traces of changed variables # for each of the parallel sequences, isolating them from one another. outer_scope_writes: dict[str, Any] = field(default_factory=dict) @dataclass(kw_only=True) class ScriptRunVariables(UserDict[str, Any]): """Class to hold script run variables. The purpose of this class is to provide proper variable scoping semantics for scripts. Each instance institutes a new local scope, in which variables can be defined. Each instance has a reference to the previous instance, except for the top-level instance. The instances therefore form a chain, in which variable lookup and assignment is performed. The variables defined lower in the chain naturally override those defined higher up. """ # _previous is the previous ScriptRunVariables in the chain _previous: ScriptRunVariables | None = None # _parent is the previous non-empty ScriptRunVariables in the chain _parent: ScriptRunVariables | None = None # _local_data is the store for local variables _local_data: dict[str, Any] | None = None # _parallel_data is used for each parallel sequence _parallel_data: _ParallelData | None = None # _non_parallel_scope includes all scopes all the way to the most recent parallel split _non_parallel_scope: ChainMap[str, Any] # _full_scope includes all scopes (all the way to the top-level) _full_scope: ChainMap[str, Any] @classmethod def create_top_level( cls, initial_data: Mapping[str, Any] | None = None, ) -> ScriptRunVariables: """Create a new top-level ScriptRunVariables.""" local_data: dict[str, Any] = {} non_parallel_scope = full_scope = ChainMap(local_data) self = cls( _local_data=local_data, _non_parallel_scope=non_parallel_scope, _full_scope=full_scope, ) if initial_data is not None: self.update(initial_data) return self def enter_scope(self, *, parallel: bool = False) -> ScriptRunVariables: """Return a new child scope. :param parallel: Whether the new scope starts a parallel sequence. """ if self._local_data is not None or self._parallel_data is not None: parent = self else: parent = cast( # top level always has local data, so we can cast safely ScriptRunVariables, self._parent ) parallel_data: _ParallelData | None if not parallel: parallel_data = None non_parallel_scope = self._non_parallel_scope full_scope = self._full_scope else: parallel_data = _ParallelData() non_parallel_scope = ChainMap( parallel_data.protected, parallel_data.outer_scope_writes ) full_scope = self._full_scope.new_child(parallel_data.protected) return ScriptRunVariables( _previous=self, _parent=parent, _parallel_data=parallel_data, _non_parallel_scope=non_parallel_scope, _full_scope=full_scope, ) def exit_scope(self) -> ScriptRunVariables: """Exit the current scope. Does no clean-up, but simply returns the previous scope. """ if self._previous is None: raise ValueError("Cannot exit top-level scope") return self._previous def __delitem__(self, key: str) -> None: """Delete a variable (disallowed).""" raise TypeError("Deleting items is not allowed in ScriptRunVariables.") def __setitem__(self, key: str, value: Any) -> None: """Assign value to a variable.""" self._assign(key, value, parallel_protected=False) def assign_parallel_protected(self, key: str, value: Any) -> None: """Assign value to a variable which is to be protected in parallel sequences.""" self._assign(key, value, parallel_protected=True) def _assign(self, key: str, value: Any, *, parallel_protected: bool) -> None: """Assign value to a variable. Value is always assigned to the variable in the nearest scope, in which it is defined. If the variable is not defined at all, it is created in the top-level scope. :param parallel_protected: Whether variable is to be protected in parallel sequences. """ if self._local_data is not None and key in self._local_data: self._local_data[key] = value return if self._parent is None: assert self._local_data is not None # top level always has local data self._local_data[key] = value return if self._parallel_data is not None: if parallel_protected: self._parallel_data.protected[key] = value return self._parallel_data.protected.pop(key, None) self._parallel_data.outer_scope_writes[key] = value self._parent._assign(key, value, parallel_protected=parallel_protected) # noqa: SLF001 def define_local(self, key: str, value: Any) -> None: """Define a local variable and assign value to it.""" if self._local_data is None: self._local_data = {} self._non_parallel_scope = self._non_parallel_scope.new_child( self._local_data ) self._full_scope = self._full_scope.new_child(self._local_data) self._local_data[key] = value @property def data(self) -> Mapping[str, Any]: # type: ignore[override] """Return variables in full scope. Defined here for UserDict compatibility. """ return self._full_scope @property def non_parallel_scope(self) -> Mapping[str, Any]: """Return variables in non-parallel scope.""" return self._non_parallel_scope @property def local_scope(self) -> Mapping[str, Any]: """Return variables in local scope.""" return self._local_data if self._local_data is not None else {}