Extract variable rendering (#39934)

This commit is contained in:
Paulus Schoutsen 2020-09-11 12:24:16 +02:00
parent b107e87d38
commit 8ef04268be
8 changed files with 193 additions and 25 deletions

View File

@ -45,6 +45,7 @@ from homeassistant.helpers.script import (
Script,
make_script_schema,
)
from homeassistant.helpers.script_variables import ScriptVariables
from homeassistant.helpers.service import async_register_admin_service
from homeassistant.helpers.trigger import async_initialize_triggers
from homeassistant.helpers.typing import TemplateVarsType
@ -256,8 +257,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self._referenced_entities: Optional[Set[str]] = None
self._referenced_devices: Optional[Set[str]] = None
self._logger = _LOGGER
self._variables = variables
self._variables_dynamic = template.is_complex(variables)
self._variables: ScriptVariables = variables
@property
def name(self):
@ -334,9 +334,6 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
"""Startup with initial state or previous state."""
await super().async_added_to_hass()
if self._variables_dynamic:
template.attach(cast(HomeAssistant, self.hass), self._variables)
self._logger = logging.getLogger(
f"{__name__}.{split_entity_id(self.entity_id)[1]}"
)
@ -392,15 +389,13 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
This method is a coroutine.
"""
if self._variables:
if self._variables_dynamic:
variables = template.render_complex(self._variables, run_variables)
else:
variables = dict(self._variables)
try:
variables = self._variables.async_render(self.hass, run_variables)
except template.TemplateError as err:
self._logger.error("Error rendering variables: %s", err)
return
else:
variables = {}
if run_variables:
variables.update(run_variables)
variables = run_variables
if (
not skip_condition

View File

@ -77,7 +77,7 @@ CONFIG_SCHEMA = vol.Schema(
SCRIPT_SERVICE_SCHEMA = vol.Schema(dict)
SCRIPT_TURN_ONOFF_SCHEMA = make_entity_service_schema(
{vol.Optional(ATTR_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA}
{vol.Optional(ATTR_VARIABLES): {str: cv.match_all}}
)
RELOAD_SERVICE_SCHEMA = vol.Schema({})

View File

@ -81,7 +81,10 @@ from homeassistant.const import (
)
from homeassistant.core import split_entity_id, valid_entity_id
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import template as template_helper
from homeassistant.helpers import (
script_variables as script_variables_helper,
template as template_helper,
)
from homeassistant.helpers.logging import KeywordStyleAdapter
from homeassistant.util import slugify as util_slugify
import homeassistant.util.dt as dt_util
@ -863,7 +866,11 @@ def make_entity_service_schema(
)
SCRIPT_VARIABLES_SCHEMA = vol.Schema({str: template_complex})
SCRIPT_VARIABLES_SCHEMA = vol.All(
vol.Schema({str: template_complex}),
# pylint: disable=unnecessary-lambda
lambda val: script_variables_helper.ScriptVariables(val),
)
def script_action(value: Any) -> dict:

View File

@ -55,6 +55,7 @@ from homeassistant.const import (
from homeassistant.core import SERVICE_CALL_LIMIT, Context, HomeAssistant, callback
from homeassistant.helpers import condition, config_validation as cv, template
from homeassistant.helpers.event import async_call_later, async_track_template
from homeassistant.helpers.script_variables import ScriptVariables
from homeassistant.helpers.service import (
CONF_SERVICE_DATA,
async_prepare_call_from_config,
@ -717,7 +718,7 @@ class Script:
logger: Optional[logging.Logger] = None,
log_exceptions: bool = True,
top_level: bool = True,
variables: Optional[Dict[str, Any]] = None,
variables: Optional[ScriptVariables] = None,
) -> None:
"""Initialize the script."""
all_scripts = hass.data.get(DATA_SCRIPTS)
@ -900,15 +901,19 @@ class Script:
# during the run back to the caller.
if self._top_level:
if self.variables:
if self._variables_dynamic:
variables = template.render_complex(self.variables, run_variables)
else:
variables = dict(self.variables)
try:
variables = self.variables.async_render(
self._hass,
run_variables,
)
except template.TemplateError as err:
self._log("Error rendering variables: %s", err, level=logging.ERROR)
raise
elif run_variables:
variables = dict(run_variables)
else:
variables = {}
if run_variables:
variables.update(run_variables)
variables["context"] = context
else:
variables = cast(dict, run_variables)

View File

@ -0,0 +1,57 @@
"""Script variables."""
from typing import Any, Dict, Mapping, Optional
from homeassistant.core import HomeAssistant, callback
from . import template
class ScriptVariables:
"""Class to hold and render script variables."""
def __init__(self, variables: Dict[str, Any]):
"""Initialize script variables."""
self.variables = variables
self._has_template: Optional[bool] = None
@callback
def async_render(
self,
hass: HomeAssistant,
run_variables: Optional[Mapping[str, Any]],
) -> Dict[str, Any]:
"""Render script variables.
The run variables are used to compute the static variables, but afterwards will also
be merged on top of the static variables.
"""
if self._has_template is None:
self._has_template = template.is_complex(self.variables)
template.attach(hass, self.variables)
if not self._has_template:
rendered_variables = dict(self.variables)
if run_variables is not None:
rendered_variables.update(run_variables)
return rendered_variables
rendered_variables = {} if run_variables is None else dict(run_variables)
for key, value in self.variables.items():
# We can skip if we're going to override this key with
# run variables anyway
if key in rendered_variables:
continue
rendered_variables[key] = template.render_complex(value, rendered_variables)
if run_variables:
rendered_variables.update(run_variables)
return rendered_variables
def as_dict(self) -> dict:
"""Return dict version of this class."""
return self.variables

View File

@ -1136,7 +1136,7 @@ async def test_logbook_humanify_automation_triggered_event(hass):
assert event2["entity_id"] == "automation.bye"
async def test_automation_variables(hass):
async def test_automation_variables(hass, caplog):
"""Test automation variables."""
calls = async_mock_service(hass, "test", "automation")
@ -1172,6 +1172,15 @@ async def test_automation_variables(hass):
"service": "test.automation",
},
},
{
"variables": {
"test_var": "{{ trigger.event.data.break + 1 }}",
},
"trigger": {"platform": "event", "event_type": "test_event_3"},
"action": {
"service": "test.automation",
},
},
]
},
)
@ -1188,3 +1197,13 @@ async def test_automation_variables(hass):
hass.bus.async_fire("test_event_2", {"pass_condition": True})
await hass.async_block_till_done()
assert len(calls) == 2
assert "Error rendering variables" not in caplog.text
hass.bus.async_fire("test_event_3")
await hass.async_block_till_done()
assert len(calls) == 2
assert "Error rendering variables" in caplog.text
hass.bus.async_fire("test_event_3", {"break": 0})
await hass.async_block_till_done()
assert len(calls) == 3

View File

@ -17,6 +17,7 @@ from homeassistant.const import (
)
from homeassistant.core import Context, callback, split_entity_id
from homeassistant.exceptions import ServiceNotFound
from homeassistant.helpers import template
from homeassistant.helpers.event import async_track_state_change
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.loader import bind_hass
@ -617,7 +618,7 @@ async def test_concurrent_script(hass, concurrently):
assert not script.is_on(hass, "script.script2")
async def test_script_variables(hass):
async def test_script_variables(hass, caplog):
"""Test defining scripts."""
assert await async_setup_component(
hass,
@ -652,6 +653,19 @@ async def test_script_variables(hass):
},
],
},
"script3": {
"variables": {
"test_var": "{{ break + 1 }}",
},
"sequence": [
{
"service": "test.script",
"data": {
"value": "{{ test_var }}",
},
},
],
},
}
},
)
@ -681,3 +695,14 @@ async def test_script_variables(hass):
assert len(mock_calls) == 3
assert mock_calls[2].data["value"] == "from_service"
assert "Error rendering variables" not in caplog.text
with pytest.raises(template.TemplateError):
await hass.services.async_call("script", "script3", blocking=True)
assert "Error rendering variables" in caplog.text
assert len(mock_calls) == 3
await hass.services.async_call("script", "script3", {"break": 0}, blocking=True)
assert len(mock_calls) == 4
assert mock_calls[3].data["value"] == "1"

View File

@ -0,0 +1,60 @@
"""Test script variables."""
import pytest
from homeassistant.helpers import config_validation as cv, template
async def test_static_vars():
"""Test static vars."""
orig = {"hello": "world"}
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
rendered = var.async_render(None, None)
assert rendered is not orig
assert rendered == orig
async def test_static_vars_run_args():
"""Test static vars."""
orig = {"hello": "world"}
orig_copy = dict(orig)
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
rendered = var.async_render(None, {"hello": "override", "run": "var"})
assert rendered == {"hello": "override", "run": "var"}
# Make sure we don't change original vars
assert orig == orig_copy
async def test_template_vars(hass):
"""Test template vars."""
var = cv.SCRIPT_VARIABLES_SCHEMA({"hello": "{{ 1 + 1 }}"})
rendered = var.async_render(hass, None)
assert rendered == {"hello": "2"}
async def test_template_vars_run_args(hass):
"""Test template vars."""
var = cv.SCRIPT_VARIABLES_SCHEMA(
{
"something": "{{ run_var_ex + 1 }}",
"something_2": "{{ run_var_ex + 1 }}",
}
)
rendered = var.async_render(
hass,
{
"run_var_ex": 5,
"something_2": 1,
},
)
assert rendered == {
"run_var_ex": 5,
"something": "6",
"something_2": 1,
}
async def test_template_vars_error(hass):
"""Test template vars."""
var = cv.SCRIPT_VARIABLES_SCHEMA({"hello": "{{ canont.work }}"})
with pytest.raises(template.TemplateError):
var.async_render(hass, None)