mirror of
https://github.com/home-assistant/core.git
synced 2025-04-26 02:07:54 +00:00
257 lines
9.5 KiB
Python
257 lines
9.5 KiB
Python
"""Script variables."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections import ChainMap, UserDict
|
|
from collections.abc import Mapping
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, cast
|
|
|
|
from homeassistant.core import HomeAssistant, callback
|
|
|
|
from . import template
|
|
|
|
|
|
class ScriptVariables:
    """Class to hold and render script variables.

    The wrapped mapping may contain plain values or templates. Whether any
    value needs rendering is determined lazily on the first render call and
    cached in ``_has_template``.
    """

    def __init__(self, variables: dict[str, Any]) -> None:
        """Initialize script variables.

        :param variables: Variable definitions; they are rendered in the
            iteration order of this mapping.
        """
        self.variables = variables
        # None means "not checked yet"; set by the first render call.
        self._has_template: bool | None = None

    def _contains_template(self) -> bool:
        """Return whether the variables need rendering, caching the answer."""
        if self._has_template is None:
            self._has_template = template.is_complex(self.variables)
        return self._has_template

    @callback
    def async_render(
        self,
        hass: HomeAssistant,
        run_variables: Mapping[str, Any] | None,
        *,
        limited: bool = False,
    ) -> dict[str, Any]:
        """Render script variables.

        The run variables are included in the result.
        The run variables are used to compute the rendered variable values.
        The run variables will not be overridden.
        The rendering happens one at a time, with previous results influencing the next.

        :param hass: Not used by this method's body.
        :param run_variables: Variables of the current run, or None.
        :param limited: Forwarded as-is to ``template.render_complex``.
        :return: A new dict holding the run variables plus the rendered variables.
        """
        if not self._contains_template():
            # Nothing to render: copy the static values, then let the run
            # variables take precedence over them.
            rendered_variables = dict(self.variables)
            if run_variables is not None:
                rendered_variables.update(run_variables)
            return rendered_variables

        rendered_variables = {} if run_variables is None else dict(run_variables)

        for key, value in self.variables.items():
            # We can skip if we're going to override this key with
            # run variables anyway
            if key in rendered_variables:
                continue

            # Results accumulated so far are visible to later templates.
            rendered_variables[key] = template.render_complex(
                value, rendered_variables, limited
            )

        return rendered_variables

    @callback
    def async_simple_render(self, run_variables: Mapping[str, Any]) -> dict[str, Any]:
        """Render script variables.

        Simply renders the variables, the run variables are not included in the result.
        The run variables are used to compute the rendered variable values.
        The rendering happens one at a time, with previous results influencing the next.

        :param run_variables: Variables of the current run; not modified.
        :return: A new dict holding only the rendered variables.
        """
        if not self._contains_template():
            # Return a shallow copy so that every code path hands out a fresh
            # dict and callers cannot mutate the stored definitions by accident.
            return dict(self.variables)

        # Work on a private copy: each rendered value is fed back in so that
        # later templates can refer to earlier results.
        render_context = dict(run_variables)
        rendered_variables: dict[str, Any] = {}

        for key, value in self.variables.items():
            rendered_variable = template.render_complex(value, render_context)
            rendered_variables[key] = rendered_variable
            render_context[key] = rendered_variable

        return rendered_variables

    def as_dict(self) -> dict[str, Any]:
        """Return dict version of this class.

        This is the stored mapping itself, not a copy.
        """
        return self.variables
|
|
|
|
|
|
@dataclass
class _ParallelData:
    """Data used in each parallel sequence.

    Attributes:
        protected: Variables that need special protection in parallel
            sequences. Such a variable defined in one parallel sequence will
            not be clobbered by a variable with the same name assigned in
            another parallel sequence, and it will not be visible in the
            outer scope. Currently the only such variable is `wait`.
        outer_scope_writes: Variables that are written to the outer scope
            from a parallel sequence. This is used for generating correct
            traces of changed variables for each of the parallel sequences,
            isolating them from one another.

    """

    # Each instance gets its own dicts (default_factory), never a shared one.
    protected: dict[str, Any] = field(default_factory=dict)
    outer_scope_writes: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass(kw_only=True)
|
|
class ScriptRunVariables(UserDict[str, Any]):
|
|
"""Class to hold script run variables.
|
|
|
|
The purpose of this class is to provide proper variable scoping semantics for scripts.
|
|
Each instance institutes a new local scope, in which variables can be defined.
|
|
Each instance has a reference to the previous instance, except for the top-level instance.
|
|
The instances therefore form a chain, in which variable lookup and assignment is performed.
|
|
The variables defined lower in the chain naturally override those defined higher up.
|
|
"""
|
|
|
|
# _previous is the previous ScriptRunVariables in the chain
|
|
_previous: ScriptRunVariables | None = None
|
|
# _parent is the previous non-empty ScriptRunVariables in the chain
|
|
_parent: ScriptRunVariables | None = None
|
|
|
|
# _local_data is the store for local variables
|
|
_local_data: dict[str, Any] | None = None
|
|
# _parallel_data is used for each parallel sequence
|
|
_parallel_data: _ParallelData | None = None
|
|
|
|
# _non_parallel_scope includes all scopes all the way to the most recent parallel split
|
|
_non_parallel_scope: ChainMap[str, Any]
|
|
# _full_scope includes all scopes (all the way to the top-level)
|
|
_full_scope: ChainMap[str, Any]
|
|
|
|
@classmethod
|
|
def create_top_level(
|
|
cls,
|
|
initial_data: Mapping[str, Any] | None = None,
|
|
) -> ScriptRunVariables:
|
|
"""Create a new top-level ScriptRunVariables."""
|
|
local_data: dict[str, Any] = {}
|
|
non_parallel_scope = full_scope = ChainMap(local_data)
|
|
self = cls(
|
|
_local_data=local_data,
|
|
_non_parallel_scope=non_parallel_scope,
|
|
_full_scope=full_scope,
|
|
)
|
|
if initial_data is not None:
|
|
self.update(initial_data)
|
|
return self
|
|
|
|
def enter_scope(self, *, parallel: bool = False) -> ScriptRunVariables:
|
|
"""Return a new child scope.
|
|
|
|
:param parallel: Whether the new scope starts a parallel sequence.
|
|
"""
|
|
if self._local_data is not None or self._parallel_data is not None:
|
|
parent = self
|
|
else:
|
|
parent = cast( # top level always has local data, so we can cast safely
|
|
ScriptRunVariables, self._parent
|
|
)
|
|
|
|
parallel_data: _ParallelData | None
|
|
if not parallel:
|
|
parallel_data = None
|
|
non_parallel_scope = self._non_parallel_scope
|
|
full_scope = self._full_scope
|
|
else:
|
|
parallel_data = _ParallelData()
|
|
non_parallel_scope = ChainMap(
|
|
parallel_data.protected, parallel_data.outer_scope_writes
|
|
)
|
|
full_scope = self._full_scope.new_child(parallel_data.protected)
|
|
|
|
return ScriptRunVariables(
|
|
_previous=self,
|
|
_parent=parent,
|
|
_parallel_data=parallel_data,
|
|
_non_parallel_scope=non_parallel_scope,
|
|
_full_scope=full_scope,
|
|
)
|
|
|
|
def exit_scope(self) -> ScriptRunVariables:
|
|
"""Exit the current scope.
|
|
|
|
Does no clean-up, but simply returns the previous scope.
|
|
"""
|
|
if self._previous is None:
|
|
raise ValueError("Cannot exit top-level scope")
|
|
return self._previous
|
|
|
|
def __delitem__(self, key: str) -> None:
|
|
"""Delete a variable (disallowed)."""
|
|
raise TypeError("Deleting items is not allowed in ScriptRunVariables.")
|
|
|
|
def __setitem__(self, key: str, value: Any) -> None:
|
|
"""Assign value to a variable."""
|
|
self._assign(key, value, parallel_protected=False)
|
|
|
|
def assign_parallel_protected(self, key: str, value: Any) -> None:
|
|
"""Assign value to a variable which is to be protected in parallel sequences."""
|
|
self._assign(key, value, parallel_protected=True)
|
|
|
|
def _assign(self, key: str, value: Any, *, parallel_protected: bool) -> None:
|
|
"""Assign value to a variable.
|
|
|
|
Value is always assigned to the variable in the nearest scope, in which it is defined.
|
|
If the variable is not defined at all, it is created in the top-level scope.
|
|
|
|
:param parallel_protected: Whether variable is to be protected in parallel sequences.
|
|
"""
|
|
if self._local_data is not None and key in self._local_data:
|
|
self._local_data[key] = value
|
|
return
|
|
|
|
if self._parent is None:
|
|
assert self._local_data is not None # top level always has local data
|
|
self._local_data[key] = value
|
|
return
|
|
|
|
if self._parallel_data is not None:
|
|
if parallel_protected:
|
|
self._parallel_data.protected[key] = value
|
|
return
|
|
self._parallel_data.protected.pop(key, None)
|
|
self._parallel_data.outer_scope_writes[key] = value
|
|
|
|
self._parent._assign(key, value, parallel_protected=parallel_protected) # noqa: SLF001
|
|
|
|
def define_local(self, key: str, value: Any) -> None:
|
|
"""Define a local variable and assign value to it."""
|
|
if self._local_data is None:
|
|
self._local_data = {}
|
|
self._non_parallel_scope = self._non_parallel_scope.new_child(
|
|
self._local_data
|
|
)
|
|
self._full_scope = self._full_scope.new_child(self._local_data)
|
|
self._local_data[key] = value
|
|
|
|
@property
|
|
def data(self) -> Mapping[str, Any]: # type: ignore[override]
|
|
"""Return variables in full scope.
|
|
|
|
Defined here for UserDict compatibility.
|
|
"""
|
|
return self._full_scope
|
|
|
|
@property
|
|
def non_parallel_scope(self) -> Mapping[str, Any]:
|
|
"""Return variables in non-parallel scope."""
|
|
return self._non_parallel_scope
|
|
|
|
@property
|
|
def local_scope(self) -> Mapping[str, Any]:
|
|
"""Return variables in local scope."""
|
|
return self._local_data if self._local_data is not None else {}
|