From 8ef04268be1cb2c5d2bc489d62bef79e1b51214f Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 11 Sep 2020 12:24:16 +0200 Subject: [PATCH] Extract variable rendering (#39934) --- .../components/automation/__init__.py | 21 +++---- homeassistant/components/script/__init__.py | 2 +- homeassistant/helpers/config_validation.py | 11 +++- homeassistant/helpers/script.py | 19 +++--- homeassistant/helpers/script_variables.py | 57 ++++++++++++++++++ tests/components/automation/test_init.py | 21 ++++++- tests/components/script/test_init.py | 27 ++++++++- tests/helpers/test_script_variables.py | 60 +++++++++++++++++++ 8 files changed, 193 insertions(+), 25 deletions(-) create mode 100644 homeassistant/helpers/script_variables.py create mode 100644 tests/helpers/test_script_variables.py diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 392ca710000..dff751956a7 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -45,6 +45,7 @@ from homeassistant.helpers.script import ( Script, make_script_schema, ) +from homeassistant.helpers.script_variables import ScriptVariables from homeassistant.helpers.service import async_register_admin_service from homeassistant.helpers.trigger import async_initialize_triggers from homeassistant.helpers.typing import TemplateVarsType @@ -256,8 +257,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity): self._referenced_entities: Optional[Set[str]] = None self._referenced_devices: Optional[Set[str]] = None self._logger = _LOGGER - self._variables = variables - self._variables_dynamic = template.is_complex(variables) + self._variables: ScriptVariables = variables @property def name(self): @@ -334,9 +334,6 @@ class AutomationEntity(ToggleEntity, RestoreEntity): """Startup with initial state or previous state.""" await super().async_added_to_hass() - if self._variables_dynamic: - template.attach(cast(HomeAssistant, self.hass), self._variables) - self._logger = logging.getLogger( f"{__name__}.{split_entity_id(self.entity_id)[1]}" ) @@ -392,15 +389,13 @@ class AutomationEntity(ToggleEntity, RestoreEntity): This method is a coroutine. """ if self._variables: - if self._variables_dynamic: - variables = template.render_complex(self._variables, run_variables) - else: - variables = dict(self._variables) + try: + variables = self._variables.async_render(self.hass, run_variables) + except template.TemplateError as err: + self._logger.error("Error rendering variables: %s", err) + return else: - variables = {} - - if run_variables: - variables.update(run_variables) + variables = run_variables if ( not skip_condition diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index 1e0fad9be5d..eab30e01ee2 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -77,7 +77,7 @@ CONFIG_SCHEMA = vol.Schema( SCRIPT_SERVICE_SCHEMA = vol.Schema(dict) SCRIPT_TURN_ONOFF_SCHEMA = make_entity_service_schema( - {vol.Optional(ATTR_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA} + {vol.Optional(ATTR_VARIABLES): {str: cv.match_all}} ) RELOAD_SERVICE_SCHEMA = vol.Schema({}) diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index a54f97ec7e5..602a8ebfd2a 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -81,7 +81,10 @@ from homeassistant.const import ( ) from homeassistant.core import split_entity_id, valid_entity_id from homeassistant.exceptions import TemplateError -from homeassistant.helpers import template as template_helper +from homeassistant.helpers import ( + script_variables as script_variables_helper, + template as template_helper, +) from homeassistant.helpers.logging import KeywordStyleAdapter from homeassistant.util import slugify as util_slugify import homeassistant.util.dt as dt_util @@ -863,7 +866,11 @@ def make_entity_service_schema( ) -SCRIPT_VARIABLES_SCHEMA = vol.Schema({str: template_complex}) +SCRIPT_VARIABLES_SCHEMA = vol.All( + vol.Schema({str: template_complex}), + # pylint: disable=unnecessary-lambda + lambda val: script_variables_helper.ScriptVariables(val), +) def script_action(value: Any) -> dict: diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index cd664974431..bd1442587eb 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -55,6 +55,7 @@ from homeassistant.const import ( from homeassistant.core import SERVICE_CALL_LIMIT, Context, HomeAssistant, callback from homeassistant.helpers import condition, config_validation as cv, template from homeassistant.helpers.event import async_call_later, async_track_template +from homeassistant.helpers.script_variables import ScriptVariables from homeassistant.helpers.service import ( CONF_SERVICE_DATA, async_prepare_call_from_config, @@ -717,7 +718,7 @@ class Script: logger: Optional[logging.Logger] = None, log_exceptions: bool = True, top_level: bool = True, - variables: Optional[Dict[str, Any]] = None, + variables: Optional[ScriptVariables] = None, ) -> None: """Initialize the script.""" all_scripts = hass.data.get(DATA_SCRIPTS) @@ -900,15 +901,19 @@ class Script: # during the run back to the caller. if self._top_level: if self.variables: - if self._variables_dynamic: - variables = template.render_complex(self.variables, run_variables) - else: - variables = dict(self.variables) + try: + variables = self.variables.async_render( + self._hass, + run_variables, + ) + except template.TemplateError as err: + self._log("Error rendering variables: %s", err, level=logging.ERROR) + raise + elif run_variables: + variables = dict(run_variables) else: variables = {} - if run_variables: - variables.update(run_variables) variables["context"] = context else: variables = cast(dict, run_variables) diff --git a/homeassistant/helpers/script_variables.py b/homeassistant/helpers/script_variables.py new file mode 100644 index 00000000000..001c3b8667c --- /dev/null +++ b/homeassistant/helpers/script_variables.py @@ -0,0 +1,57 @@ +"""Script variables.""" +from typing import Any, Dict, Mapping, Optional + +from homeassistant.core import HomeAssistant, callback + +from . import template + + +class ScriptVariables: + """Class to hold and render script variables.""" + + def __init__(self, variables: Dict[str, Any]): + """Initialize script variables.""" + self.variables = variables + self._has_template: Optional[bool] = None + + @callback + def async_render( + self, + hass: HomeAssistant, + run_variables: Optional[Mapping[str, Any]], + ) -> Dict[str, Any]: + """Render script variables. + + The run variables are used to compute the static variables, but afterwards will also + be merged on top of the static variables. + """ + if self._has_template is None: + self._has_template = template.is_complex(self.variables) + template.attach(hass, self.variables) + + if not self._has_template: + rendered_variables = dict(self.variables) + + if run_variables is not None: + rendered_variables.update(run_variables) + + return rendered_variables + + rendered_variables = {} if run_variables is None else dict(run_variables) + + for key, value in self.variables.items(): + # We can skip if we're going to override this key with + # run variables anyway + if key in rendered_variables: + continue + + rendered_variables[key] = template.render_complex(value, rendered_variables) + + if run_variables: + rendered_variables.update(run_variables) + + return rendered_variables + + def as_dict(self) -> dict: + """Return dict version of this class.""" + return self.variables diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index 5ee0ff62af2..9c38574945d 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -1136,7 +1136,7 @@ async def test_logbook_humanify_automation_triggered_event(hass): assert event2["entity_id"] == "automation.bye" -async def test_automation_variables(hass): +async def test_automation_variables(hass, caplog): """Test automation variables.""" calls = async_mock_service(hass, "test", "automation") @@ -1172,6 +1172,15 @@ async def test_automation_variables(hass): "service": "test.automation", }, }, + { + "variables": { + "test_var": "{{ trigger.event.data.break + 1 }}", + }, + "trigger": {"platform": "event", "event_type": "test_event_3"}, + "action": { + "service": "test.automation", + }, + }, ] }, ) @@ -1188,3 +1197,13 @@ async def test_automation_variables(hass): hass.bus.async_fire("test_event_2", {"pass_condition": True}) await hass.async_block_till_done() assert len(calls) == 2 + + assert "Error rendering variables" not in caplog.text + hass.bus.async_fire("test_event_3") + await hass.async_block_till_done() + assert len(calls) == 2 + assert "Error rendering variables" in caplog.text + + hass.bus.async_fire("test_event_3", {"break": 0}) + await hass.async_block_till_done() + assert len(calls) == 3 diff --git a/tests/components/script/test_init.py b/tests/components/script/test_init.py index 5fb832d0f36..152c74d8fe9 100644 --- a/tests/components/script/test_init.py +++ b/tests/components/script/test_init.py @@ -17,6 +17,7 @@ from homeassistant.const import ( ) from homeassistant.core import Context, callback, split_entity_id from homeassistant.exceptions import ServiceNotFound +from homeassistant.helpers import template from homeassistant.helpers.event import async_track_state_change from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.loader import bind_hass @@ -617,7 +618,7 @@ async def test_concurrent_script(hass, concurrently): assert not script.is_on(hass, "script.script2") -async def test_script_variables(hass): +async def test_script_variables(hass, caplog): """Test defining scripts.""" assert await async_setup_component( hass, @@ -652,6 +653,19 @@ async def test_script_variables(hass): }, ], }, + "script3": { + "variables": { + "test_var": "{{ break + 1 }}", + }, + "sequence": [ + { + "service": "test.script", + "data": { + "value": "{{ test_var }}", + }, + }, + ], + }, } }, ) @@ -681,3 +695,14 @@ async def test_script_variables(hass): assert len(mock_calls) == 3 assert mock_calls[2].data["value"] == "from_service" + + assert "Error rendering variables" not in caplog.text + with pytest.raises(template.TemplateError): + await hass.services.async_call("script", "script3", blocking=True) + assert "Error rendering variables" in caplog.text + assert len(mock_calls) == 3 + + await hass.services.async_call("script", "script3", {"break": 0}, blocking=True) + + assert len(mock_calls) == 4 + assert mock_calls[3].data["value"] == "1" diff --git a/tests/helpers/test_script_variables.py b/tests/helpers/test_script_variables.py new file mode 100644 index 00000000000..6e671d14a23 --- /dev/null +++ b/tests/helpers/test_script_variables.py @@ -0,0 +1,60 @@ +"""Test script variables.""" +import pytest + +from homeassistant.helpers import config_validation as cv, template + + +async def test_static_vars(): + """Test static vars.""" + orig = {"hello": "world"} + var = cv.SCRIPT_VARIABLES_SCHEMA(orig) + rendered = var.async_render(None, None) + assert rendered is not orig + assert rendered == orig + + +async def test_static_vars_run_args(): + """Test static vars.""" + orig = {"hello": "world"} + orig_copy = dict(orig) + var = cv.SCRIPT_VARIABLES_SCHEMA(orig) + rendered = var.async_render(None, {"hello": "override", "run": "var"}) + assert rendered == {"hello": "override", "run": "var"} + # Make sure we don't change original vars + assert orig == orig_copy + + +async def test_template_vars(hass): + """Test template vars.""" + var = cv.SCRIPT_VARIABLES_SCHEMA({"hello": "{{ 1 + 1 }}"}) + rendered = var.async_render(hass, None) + assert rendered == {"hello": "2"} + + +async def test_template_vars_run_args(hass): + """Test template vars.""" + var = cv.SCRIPT_VARIABLES_SCHEMA( + { + "something": "{{ run_var_ex + 1 }}", + "something_2": "{{ run_var_ex + 1 }}", + } + ) + rendered = var.async_render( + hass, + { + "run_var_ex": 5, + "something_2": 1, + }, + ) + assert rendered == { + "run_var_ex": 5, + "something": "6", + "something_2": 1, + } + + +async def test_template_vars_error(hass): + """Test template vars.""" + var = cv.SCRIPT_VARIABLES_SCHEMA({"hello": "{{ canont.work }}"}) + with pytest.raises(template.TemplateError): + var.async_render(hass, None)