diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 94f2cedd58e..ae8c71b4fb8 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -60,6 +60,7 @@ from .const import ( CONF_ACTION, CONF_INITIAL_STATE, CONF_TRIGGER, + CONF_TRIGGER_VARIABLES, DEFAULT_INITIAL_STATE, DOMAIN, LOGGER, @@ -221,6 +222,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity): action_script, initial_state, variables, + trigger_variables, ): """Initialize an automation entity.""" self._id = automation_id @@ -236,6 +238,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity): self._referenced_devices: Optional[Set[str]] = None self._logger = LOGGER self._variables: ScriptVariables = variables + self._trigger_variables: ScriptVariables = trigger_variables @property def name(self): @@ -471,6 +474,16 @@ class AutomationEntity(ToggleEntity, RestoreEntity): def log_cb(level, msg, **kwargs): self._logger.log(level, "%s %s", msg, self._name, **kwargs) + variables = None + if self._trigger_variables: + try: + variables = self._trigger_variables.async_render( + cast(HomeAssistant, self.hass), None, limited=True + ) + except template.TemplateError as err: + self._logger.error("Error rendering trigger variables: %s", err) + return None + return await async_initialize_triggers( cast(HomeAssistant, self.hass), self._trigger_config, @@ -479,6 +492,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity): self._name, log_cb, home_assistant_start, + variables, ) @property @@ -556,6 +570,18 @@ async def _async_process_config( else: cond_func = None + # Add trigger variables to variables + variables = None + if CONF_TRIGGER_VARIABLES in config_block: + variables = ScriptVariables( + dict(config_block[CONF_TRIGGER_VARIABLES].as_dict()) + ) + if CONF_VARIABLES in config_block: + if variables: + variables.variables.update(config_block[CONF_VARIABLES].as_dict()) + else: + variables = config_block[CONF_VARIABLES] + entity = AutomationEntity( automation_id, name, @@ -563,7 +589,8 @@ async def _async_process_config( cond_func, action_script, initial_state, - config_block.get(CONF_VARIABLES), + variables, + config_block.get(CONF_TRIGGER_VARIABLES), ) entities.append(entity) diff --git a/homeassistant/components/automation/config.py b/homeassistant/components/automation/config.py index 89d5e184748..32ad92cb86e 100644 --- a/homeassistant/components/automation/config.py +++ b/homeassistant/components/automation/config.py @@ -21,6 +21,7 @@ from .const import ( CONF_HIDE_ENTITY, CONF_INITIAL_STATE, CONF_TRIGGER, + CONF_TRIGGER_VARIABLES, DOMAIN, ) from .helpers import async_get_blueprints @@ -43,6 +44,7 @@ PLATFORM_SCHEMA = vol.All( vol.Required(CONF_TRIGGER): cv.TRIGGER_SCHEMA, vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA, vol.Optional(CONF_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA, + vol.Optional(CONF_TRIGGER_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA, vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA, }, script.SCRIPT_MODE_SINGLE, diff --git a/homeassistant/components/automation/const.py b/homeassistant/components/automation/const.py index ffb89ba0907..829f78590e0 100644 --- a/homeassistant/components/automation/const.py +++ b/homeassistant/components/automation/const.py @@ -3,6 +3,7 @@ import logging CONF_ACTION = "action" CONF_TRIGGER = "trigger" +CONF_TRIGGER_VARIABLES = "trigger_variables" DOMAIN = "automation" CONF_DESCRIPTION = "description" diff --git a/homeassistant/components/mqtt/device_trigger.py b/homeassistant/components/mqtt/device_trigger.py index 6a04fd48049..8969072553c 100644 --- a/homeassistant/components/mqtt/device_trigger.py +++ b/homeassistant/components/mqtt/device_trigger.py @@ -89,12 +89,14 @@ class TriggerInstance: async def async_attach_trigger(self): """Attach MQTT trigger.""" mqtt_config = { + mqtt_trigger.CONF_PLATFORM: mqtt.DOMAIN, mqtt_trigger.CONF_TOPIC: self.trigger.topic, mqtt_trigger.CONF_ENCODING: DEFAULT_ENCODING, mqtt_trigger.CONF_QOS: self.trigger.qos, } if self.trigger.payload: mqtt_config[CONF_PAYLOAD] = self.trigger.payload + mqtt_config = mqtt_trigger.TRIGGER_SCHEMA(mqtt_config) if self.remove: self.remove() diff --git a/homeassistant/components/mqtt/trigger.py b/homeassistant/components/mqtt/trigger.py index 1c96b3de266..a82ea355343 100644 --- a/homeassistant/components/mqtt/trigger.py +++ b/homeassistant/components/mqtt/trigger.py @@ -1,11 +1,12 @@ """Offer MQTT listening automation rules.""" import json +import logging import voluptuous as vol from homeassistant.const import CONF_PAYLOAD, CONF_PLATFORM from homeassistant.core import HassJob, callback -import homeassistant.helpers.config_validation as cv +from homeassistant.helpers import config_validation as cv, template from .. import mqtt @@ -20,8 +21,8 @@ DEFAULT_QOS = 0 TRIGGER_SCHEMA = vol.Schema( { vol.Required(CONF_PLATFORM): mqtt.DOMAIN, - vol.Required(CONF_TOPIC): mqtt.util.valid_subscribe_topic, - vol.Optional(CONF_PAYLOAD): cv.string, + vol.Required(CONF_TOPIC): mqtt.util.valid_subscribe_topic_template, + vol.Optional(CONF_PAYLOAD): cv.template, vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): cv.string, vol.Optional(CONF_QOS, default=DEFAULT_QOS): vol.All( vol.Coerce(int), vol.In([0, 1, 2]) @@ -29,6 +30,8 @@ TRIGGER_SCHEMA = vol.Schema( } ) +_LOGGER = logging.getLogger(__name__) + async def async_attach_trigger(hass, config, action, automation_info): """Listen for state changes based on configuration.""" @@ -37,6 +40,18 @@ async def async_attach_trigger(hass, config, action, automation_info): encoding = config[CONF_ENCODING] or None qos = config[CONF_QOS] job = HassJob(action) + variables = None + if automation_info: + variables = automation_info.get("variables") + + template.attach(hass, payload) + if payload: + payload = payload.async_render(variables, limited=True) + + template.attach(hass, topic) + if isinstance(topic, template.Template): + topic = topic.async_render(variables, limited=True) + topic = mqtt.util.valid_subscribe_topic(topic) @callback def mqtt_automation_listener(mqttmsg): @@ -57,6 +72,10 @@ async def async_attach_trigger(hass, config, action, automation_info): hass.async_run_hass_job(job, {"trigger": data}) + _LOGGER.debug( + "Attaching MQTT trigger for topic: '%s', payload: '%s'", topic, payload + ) + remove = await mqtt.async_subscribe( hass, topic, mqtt_automation_listener, encoding=encoding, qos=qos ) diff --git a/homeassistant/components/mqtt/util.py b/homeassistant/components/mqtt/util.py index 651fe48fe3d..b8fca50a153 100644 --- a/homeassistant/components/mqtt/util.py +++ b/homeassistant/components/mqtt/util.py @@ -4,7 +4,7 @@ from typing import Any import voluptuous as vol from homeassistant.const import CONF_PAYLOAD -from homeassistant.helpers import config_validation as cv +from homeassistant.helpers import config_validation as cv, template from .const import ( ATTR_PAYLOAD, @@ -61,6 +61,16 @@ def valid_subscribe_topic(value: Any) -> str: return value +def valid_subscribe_topic_template(value: Any) -> template.Template: + """Validate either a jinja2 template or a valid MQTT subscription topic.""" + tpl = template.Template(value) + + if tpl.is_static: + valid_subscribe_topic(value) + + return tpl + + def valid_publish_topic(value: Any) -> str: """Validate that we can publish using this MQTT topic.""" value = valid_topic(value) diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index d47ba30c114..4af4744e509 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -572,7 +572,7 @@ def dynamic_template(value: Optional[Any]) -> template_helper.Template: if isinstance(value, (list, dict, template_helper.Template)): raise vol.Invalid("template value should be a string") if not template_helper.is_template_string(str(value)): - raise vol.Invalid("template value does not contain a dynmamic template") + raise vol.Invalid("template value does not contain a dynamic template") template_value = template_helper.Template(str(value)) # type: ignore try: diff --git a/homeassistant/helpers/script_variables.py b/homeassistant/helpers/script_variables.py index 3140fc4dced..818263c9dd5 100644 --- a/homeassistant/helpers/script_variables.py +++ b/homeassistant/helpers/script_variables.py @@ -21,6 +21,7 @@ class ScriptVariables: run_variables: Optional[Mapping[str, Any]], *, render_as_defaults: bool = True, + limited: bool = False, ) -> Dict[str, Any]: """Render script variables. @@ -55,7 +56,9 @@ class ScriptVariables: if render_as_defaults and key in rendered_variables: continue - rendered_variables[key] = template.render_complex(value, rendered_variables) + rendered_variables[key] = template.render_complex( + value, rendered_variables, limited + ) return rendered_variables diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index 5f506c02eef..af63cab10eb 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -84,7 +84,9 @@ def attach(hass: HomeAssistantType, obj: Any) -> None: obj.hass = hass -def render_complex(value: Any, variables: TemplateVarsType = None) -> Any: +def render_complex( + value: Any, variables: TemplateVarsType = None, limited: bool = False +) -> Any: """Recursive template creator helper function.""" if isinstance(value, list): return [render_complex(item, variables) for item in value] @@ -94,7 +96,7 @@ def render_complex(value: Any, variables: TemplateVarsType = None) -> Any: for key, item in value.items() } if isinstance(value, Template): - return value.async_render(variables) + return value.async_render(variables, limited=limited) return value @@ -279,6 +281,7 @@ class Template: "is_static", "_compiled_code", "_compiled", + "_limited", ) def __init__(self, template, hass=None): @@ -291,10 +294,11 @@ class Template: self._compiled: Optional[Template] = None self.hass = hass self.is_static = not is_template_string(template) + self._limited = None @property def _env(self) -> "TemplateEnvironment": - if self.hass is None: + if self.hass is None or self._limited: return _NO_HASS_ENV ret: Optional[TemplateEnvironment] = self.hass.data.get(_ENVIRONMENT) if ret is None: @@ -315,9 +319,13 @@ class Template: self, variables: TemplateVarsType = None, parse_result: bool = True, + limited: bool = False, **kwargs: Any, ) -> Any: - """Render given template.""" + """Render given template. + + If limited is True, the template is not allowed to access any function or filter depending on hass or the state machine. + """ if self.is_static: if self.hass.config.legacy_templates or not parse_result: return self.template @@ -325,7 +333,7 @@ class Template: return run_callback_threadsafe( self.hass.loop, - partial(self.async_render, variables, parse_result, **kwargs), + partial(self.async_render, variables, parse_result, limited, **kwargs), ).result() @callback @@ -333,18 +341,21 @@ class Template: self, variables: TemplateVarsType = None, parse_result: bool = True, + limited: bool = False, **kwargs: Any, ) -> Any: """Render given template. This method must be run in the event loop. + + If limited is True, the template is not allowed to access any function or filter depending on hass or the state machine. """ if self.is_static: if self.hass.config.legacy_templates or not parse_result: return self.template return self._parse_result(self.template) - compiled = self._compiled or self._ensure_compiled() + compiled = self._compiled or self._ensure_compiled(limited) if variables is not None: kwargs.update(variables) @@ -519,12 +530,16 @@ class Template: ) return value if error_value is _SENTINEL else error_value - def _ensure_compiled(self) -> "Template": + def _ensure_compiled(self, limited: bool = False) -> "Template": """Bind a template to a specific hass instance.""" self.ensure_valid() assert self.hass is not None, "hass variable not set on template" + assert ( + self._limited is None or self._limited == limited + ), "can't change between limited and non limited template" + self._limited = limited env = self._env self._compiled = cast( @@ -1352,6 +1367,31 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment): self.globals["strptime"] = strptime self.globals["urlencode"] = urlencode if hass is None: + + def unsupported(name): + def warn_unsupported(*args, **kwargs): + raise TemplateError( + f"Use of '{name}' is not supported in limited templates" + ) + + return warn_unsupported + + hass_globals = [ + "closest", + "distance", + "expand", + "is_state", + "is_state_attr", + "state_attr", + "states", + "utcnow", + "now", + ] + hass_filters = ["closest", "expand"] + for glob in hass_globals: + self.globals[glob] = unsupported(glob) + for filt in hass_filters: + self.filters[filt] = unsupported(filt) return # We mark these as a context functions to ensure they get diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 2c7275a9cc3..58ac71a515e 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -8,6 +8,7 @@ import voluptuous as vol from homeassistant.const import CONF_PLATFORM from homeassistant.core import CALLBACK_TYPE, callback +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.typing import ConfigType, HomeAssistantType from homeassistant.loader import IntegrationNotFound, async_get_integration @@ -79,7 +80,9 @@ async def async_initialize_triggers( removes = [] for result in attach_results: - if isinstance(result, Exception): + if isinstance(result, HomeAssistantError): + log_cb(logging.ERROR, f"Got error '{result}' when setting up triggers for") + elif isinstance(result, Exception): log_cb(logging.ERROR, "Error setting up trigger", exc_info=result) elif result is None: log_cb( diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index 0dbc4b2cc69..16d56c84cb0 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -1237,6 +1237,94 @@ async def test_automation_variables(hass, caplog): assert len(calls) == 3 +async def test_automation_trigger_variables(hass, caplog): + """Test automation trigger variables.""" + calls = async_mock_service(hass, "test", "automation") + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: [ + { + "variables": { + "event_type": "{{ trigger.event.event_type }}", + }, + "trigger_variables": { + "test_var": "defined_in_config", + }, + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": { + "service": "test.automation", + "data": { + "value": "{{ test_var }}", + "event_type": "{{ event_type }}", + }, + }, + }, + { + "variables": { + "event_type": "{{ trigger.event.event_type }}", + "test_var": "overridden_in_config", + }, + "trigger_variables": { + "test_var": "defined_in_config", + }, + "trigger": {"platform": "event", "event_type": "test_event_2"}, + "action": { + "service": "test.automation", + "data": { + "value": "{{ test_var }}", + "event_type": "{{ event_type }}", + }, + }, + }, + ] + }, + ) + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0].data["value"] == "defined_in_config" + assert calls[0].data["event_type"] == "test_event" + + hass.bus.async_fire("test_event_2") + await hass.async_block_till_done() + assert len(calls) == 2 + assert calls[1].data["value"] == "overridden_in_config" + assert calls[1].data["event_type"] == "test_event_2" + + assert "Error rendering variables" not in caplog.text + + +async def test_automation_bad_trigger_variables(hass, caplog): + """Test automation trigger variables accessing hass is rejected.""" + calls = async_mock_service(hass, "test", "automation") + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: [ + { + "trigger_variables": { + "test_var": "{{ states('foo.bar') }}", + }, + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": { + "service": "test.automation", + }, + }, + ] + }, + ) + hass.bus.async_fire("test_event") + assert "Use of 'states' is not supported in limited templates" in caplog.text + + await hass.async_block_till_done() + assert len(calls) == 0 + + async def test_blueprint_automation(hass, calls): """Test blueprint automation.""" assert await async_setup_component( diff --git a/tests/components/mqtt/test_trigger.py b/tests/components/mqtt/test_trigger.py index b27af2b9bd0..537a4f8dc64 100644 --- a/tests/components/mqtt/test_trigger.py +++ b/tests/components/mqtt/test_trigger.py @@ -81,6 +81,58 @@ async def test_if_fires_on_topic_and_payload_match(hass, calls): assert len(calls) == 1 +async def test_if_fires_on_templated_topic_and_payload_match(hass, calls): + """Test if message is fired on templated topic and payload match.""" + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "trigger": { + "platform": "mqtt", + "topic": "test-topic-{{ sqrt(16)|round }}", + "payload": '{{ "foo"|regex_replace("foo", "bar") }}', + }, + "action": {"service": "test.automation"}, + } + }, + ) + + async_fire_mqtt_message(hass, "test-topic-", "foo") + await hass.async_block_till_done() + assert len(calls) == 0 + + async_fire_mqtt_message(hass, "test-topic-4", "foo") + await hass.async_block_till_done() + assert len(calls) == 0 + + async_fire_mqtt_message(hass, "test-topic-4", "bar") + await hass.async_block_till_done() + assert len(calls) == 1 + + +async def test_non_allowed_templates(hass, calls, caplog): + """Test non allowed function in template.""" + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "trigger": { + "platform": "mqtt", + "topic": "test-topic-{{ states() }}", + }, + "action": {"service": "test.automation"}, + } + }, + ) + + assert ( + "Got error 'TemplateError: str: Use of 'states' is not supported in limited templates' when setting up triggers" + in caplog.text + ) + + async def test_if_not_fires_on_topic_but_no_payload_match(hass, calls): """Test if message is not fired on topic but no payload.""" assert await async_setup_component(