mirror of
https://github.com/home-assistant/core.git
synced 2025-07-08 22:07:10 +00:00
Extract variable rendering (#39934)
This commit is contained in:
parent
b107e87d38
commit
8ef04268be
@ -45,6 +45,7 @@ from homeassistant.helpers.script import (
|
||||
Script,
|
||||
make_script_schema,
|
||||
)
|
||||
from homeassistant.helpers.script_variables import ScriptVariables
|
||||
from homeassistant.helpers.service import async_register_admin_service
|
||||
from homeassistant.helpers.trigger import async_initialize_triggers
|
||||
from homeassistant.helpers.typing import TemplateVarsType
|
||||
@ -256,8 +257,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
||||
self._referenced_entities: Optional[Set[str]] = None
|
||||
self._referenced_devices: Optional[Set[str]] = None
|
||||
self._logger = _LOGGER
|
||||
self._variables = variables
|
||||
self._variables_dynamic = template.is_complex(variables)
|
||||
self._variables: ScriptVariables = variables
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@ -334,9 +334,6 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
||||
"""Startup with initial state or previous state."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
if self._variables_dynamic:
|
||||
template.attach(cast(HomeAssistant, self.hass), self._variables)
|
||||
|
||||
self._logger = logging.getLogger(
|
||||
f"{__name__}.{split_entity_id(self.entity_id)[1]}"
|
||||
)
|
||||
@ -392,15 +389,13 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
||||
This method is a coroutine.
|
||||
"""
|
||||
if self._variables:
|
||||
if self._variables_dynamic:
|
||||
variables = template.render_complex(self._variables, run_variables)
|
||||
try:
|
||||
variables = self._variables.async_render(self.hass, run_variables)
|
||||
except template.TemplateError as err:
|
||||
self._logger.error("Error rendering variables: %s", err)
|
||||
return
|
||||
else:
|
||||
variables = dict(self._variables)
|
||||
else:
|
||||
variables = {}
|
||||
|
||||
if run_variables:
|
||||
variables.update(run_variables)
|
||||
variables = run_variables
|
||||
|
||||
if (
|
||||
not skip_condition
|
||||
|
@ -77,7 +77,7 @@ CONFIG_SCHEMA = vol.Schema(
|
||||
|
||||
SCRIPT_SERVICE_SCHEMA = vol.Schema(dict)
|
||||
SCRIPT_TURN_ONOFF_SCHEMA = make_entity_service_schema(
|
||||
{vol.Optional(ATTR_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA}
|
||||
{vol.Optional(ATTR_VARIABLES): {str: cv.match_all}}
|
||||
)
|
||||
RELOAD_SERVICE_SCHEMA = vol.Schema({})
|
||||
|
||||
|
@ -81,7 +81,10 @@ from homeassistant.const import (
|
||||
)
|
||||
from homeassistant.core import split_entity_id, valid_entity_id
|
||||
from homeassistant.exceptions import TemplateError
|
||||
from homeassistant.helpers import template as template_helper
|
||||
from homeassistant.helpers import (
|
||||
script_variables as script_variables_helper,
|
||||
template as template_helper,
|
||||
)
|
||||
from homeassistant.helpers.logging import KeywordStyleAdapter
|
||||
from homeassistant.util import slugify as util_slugify
|
||||
import homeassistant.util.dt as dt_util
|
||||
@ -863,7 +866,11 @@ def make_entity_service_schema(
|
||||
)
|
||||
|
||||
|
||||
SCRIPT_VARIABLES_SCHEMA = vol.Schema({str: template_complex})
|
||||
SCRIPT_VARIABLES_SCHEMA = vol.All(
|
||||
vol.Schema({str: template_complex}),
|
||||
# pylint: disable=unnecessary-lambda
|
||||
lambda val: script_variables_helper.ScriptVariables(val),
|
||||
)
|
||||
|
||||
|
||||
def script_action(value: Any) -> dict:
|
||||
|
@ -55,6 +55,7 @@ from homeassistant.const import (
|
||||
from homeassistant.core import SERVICE_CALL_LIMIT, Context, HomeAssistant, callback
|
||||
from homeassistant.helpers import condition, config_validation as cv, template
|
||||
from homeassistant.helpers.event import async_call_later, async_track_template
|
||||
from homeassistant.helpers.script_variables import ScriptVariables
|
||||
from homeassistant.helpers.service import (
|
||||
CONF_SERVICE_DATA,
|
||||
async_prepare_call_from_config,
|
||||
@ -717,7 +718,7 @@ class Script:
|
||||
logger: Optional[logging.Logger] = None,
|
||||
log_exceptions: bool = True,
|
||||
top_level: bool = True,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
variables: Optional[ScriptVariables] = None,
|
||||
) -> None:
|
||||
"""Initialize the script."""
|
||||
all_scripts = hass.data.get(DATA_SCRIPTS)
|
||||
@ -900,15 +901,19 @@ class Script:
|
||||
# during the run back to the caller.
|
||||
if self._top_level:
|
||||
if self.variables:
|
||||
if self._variables_dynamic:
|
||||
variables = template.render_complex(self.variables, run_variables)
|
||||
else:
|
||||
variables = dict(self.variables)
|
||||
try:
|
||||
variables = self.variables.async_render(
|
||||
self._hass,
|
||||
run_variables,
|
||||
)
|
||||
except template.TemplateError as err:
|
||||
self._log("Error rendering variables: %s", err, level=logging.ERROR)
|
||||
raise
|
||||
elif run_variables:
|
||||
variables = dict(run_variables)
|
||||
else:
|
||||
variables = {}
|
||||
|
||||
if run_variables:
|
||||
variables.update(run_variables)
|
||||
variables["context"] = context
|
||||
else:
|
||||
variables = cast(dict, run_variables)
|
||||
|
57
homeassistant/helpers/script_variables.py
Normal file
57
homeassistant/helpers/script_variables.py
Normal file
@ -0,0 +1,57 @@
|
||||
"""Script variables."""
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
|
||||
from . import template
|
||||
|
||||
|
||||
class ScriptVariables:
|
||||
"""Class to hold and render script variables."""
|
||||
|
||||
def __init__(self, variables: Dict[str, Any]):
|
||||
"""Initialize script variables."""
|
||||
self.variables = variables
|
||||
self._has_template: Optional[bool] = None
|
||||
|
||||
@callback
|
||||
def async_render(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
run_variables: Optional[Mapping[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
"""Render script variables.
|
||||
|
||||
The run variables are used to compute the static variables, but afterwards will also
|
||||
be merged on top of the static variables.
|
||||
"""
|
||||
if self._has_template is None:
|
||||
self._has_template = template.is_complex(self.variables)
|
||||
template.attach(hass, self.variables)
|
||||
|
||||
if not self._has_template:
|
||||
rendered_variables = dict(self.variables)
|
||||
|
||||
if run_variables is not None:
|
||||
rendered_variables.update(run_variables)
|
||||
|
||||
return rendered_variables
|
||||
|
||||
rendered_variables = {} if run_variables is None else dict(run_variables)
|
||||
|
||||
for key, value in self.variables.items():
|
||||
# We can skip if we're going to override this key with
|
||||
# run variables anyway
|
||||
if key in rendered_variables:
|
||||
continue
|
||||
|
||||
rendered_variables[key] = template.render_complex(value, rendered_variables)
|
||||
|
||||
if run_variables:
|
||||
rendered_variables.update(run_variables)
|
||||
|
||||
return rendered_variables
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
"""Return dict version of this class."""
|
||||
return self.variables
|
@ -1136,7 +1136,7 @@ async def test_logbook_humanify_automation_triggered_event(hass):
|
||||
assert event2["entity_id"] == "automation.bye"
|
||||
|
||||
|
||||
async def test_automation_variables(hass):
|
||||
async def test_automation_variables(hass, caplog):
|
||||
"""Test automation variables."""
|
||||
calls = async_mock_service(hass, "test", "automation")
|
||||
|
||||
@ -1172,6 +1172,15 @@ async def test_automation_variables(hass):
|
||||
"service": "test.automation",
|
||||
},
|
||||
},
|
||||
{
|
||||
"variables": {
|
||||
"test_var": "{{ trigger.event.data.break + 1 }}",
|
||||
},
|
||||
"trigger": {"platform": "event", "event_type": "test_event_3"},
|
||||
"action": {
|
||||
"service": "test.automation",
|
||||
},
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
@ -1188,3 +1197,13 @@ async def test_automation_variables(hass):
|
||||
hass.bus.async_fire("test_event_2", {"pass_condition": True})
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 2
|
||||
|
||||
assert "Error rendering variables" not in caplog.text
|
||||
hass.bus.async_fire("test_event_3")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 2
|
||||
assert "Error rendering variables" in caplog.text
|
||||
|
||||
hass.bus.async_fire("test_event_3", {"break": 0})
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 3
|
||||
|
@ -17,6 +17,7 @@ from homeassistant.const import (
|
||||
)
|
||||
from homeassistant.core import Context, callback, split_entity_id
|
||||
from homeassistant.exceptions import ServiceNotFound
|
||||
from homeassistant.helpers import template
|
||||
from homeassistant.helpers.event import async_track_state_change
|
||||
from homeassistant.helpers.service import async_get_all_descriptions
|
||||
from homeassistant.loader import bind_hass
|
||||
@ -617,7 +618,7 @@ async def test_concurrent_script(hass, concurrently):
|
||||
assert not script.is_on(hass, "script.script2")
|
||||
|
||||
|
||||
async def test_script_variables(hass):
|
||||
async def test_script_variables(hass, caplog):
|
||||
"""Test defining scripts."""
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
@ -652,6 +653,19 @@ async def test_script_variables(hass):
|
||||
},
|
||||
],
|
||||
},
|
||||
"script3": {
|
||||
"variables": {
|
||||
"test_var": "{{ break + 1 }}",
|
||||
},
|
||||
"sequence": [
|
||||
{
|
||||
"service": "test.script",
|
||||
"data": {
|
||||
"value": "{{ test_var }}",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
@ -681,3 +695,14 @@ async def test_script_variables(hass):
|
||||
|
||||
assert len(mock_calls) == 3
|
||||
assert mock_calls[2].data["value"] == "from_service"
|
||||
|
||||
assert "Error rendering variables" not in caplog.text
|
||||
with pytest.raises(template.TemplateError):
|
||||
await hass.services.async_call("script", "script3", blocking=True)
|
||||
assert "Error rendering variables" in caplog.text
|
||||
assert len(mock_calls) == 3
|
||||
|
||||
await hass.services.async_call("script", "script3", {"break": 0}, blocking=True)
|
||||
|
||||
assert len(mock_calls) == 4
|
||||
assert mock_calls[3].data["value"] == "1"
|
||||
|
60
tests/helpers/test_script_variables.py
Normal file
60
tests/helpers/test_script_variables.py
Normal file
@ -0,0 +1,60 @@
|
||||
"""Test script variables."""
|
||||
import pytest
|
||||
|
||||
from homeassistant.helpers import config_validation as cv, template
|
||||
|
||||
|
||||
async def test_static_vars():
|
||||
"""Test static vars."""
|
||||
orig = {"hello": "world"}
|
||||
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
|
||||
rendered = var.async_render(None, None)
|
||||
assert rendered is not orig
|
||||
assert rendered == orig
|
||||
|
||||
|
||||
async def test_static_vars_run_args():
|
||||
"""Test static vars."""
|
||||
orig = {"hello": "world"}
|
||||
orig_copy = dict(orig)
|
||||
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
|
||||
rendered = var.async_render(None, {"hello": "override", "run": "var"})
|
||||
assert rendered == {"hello": "override", "run": "var"}
|
||||
# Make sure we don't change original vars
|
||||
assert orig == orig_copy
|
||||
|
||||
|
||||
async def test_template_vars(hass):
|
||||
"""Test template vars."""
|
||||
var = cv.SCRIPT_VARIABLES_SCHEMA({"hello": "{{ 1 + 1 }}"})
|
||||
rendered = var.async_render(hass, None)
|
||||
assert rendered == {"hello": "2"}
|
||||
|
||||
|
||||
async def test_template_vars_run_args(hass):
|
||||
"""Test template vars."""
|
||||
var = cv.SCRIPT_VARIABLES_SCHEMA(
|
||||
{
|
||||
"something": "{{ run_var_ex + 1 }}",
|
||||
"something_2": "{{ run_var_ex + 1 }}",
|
||||
}
|
||||
)
|
||||
rendered = var.async_render(
|
||||
hass,
|
||||
{
|
||||
"run_var_ex": 5,
|
||||
"something_2": 1,
|
||||
},
|
||||
)
|
||||
assert rendered == {
|
||||
"run_var_ex": 5,
|
||||
"something": "6",
|
||||
"something_2": 1,
|
||||
}
|
||||
|
||||
|
||||
async def test_template_vars_error(hass):
|
||||
"""Test template vars."""
|
||||
var = cv.SCRIPT_VARIABLES_SCHEMA({"hello": "{{ canont.work }}"})
|
||||
with pytest.raises(template.TemplateError):
|
||||
var.async_render(hass, None)
|
Loading…
x
Reference in New Issue
Block a user