mirror of
https://github.com/home-assistant/core.git
synced 2025-04-26 02:07:54 +00:00
Support templating MQTT triggers (#45614)
* Add support for limited templates (no HASS access) * Pass variables to automation triggers * Support templates in MQTT triggers * Spelling * Handle trigger referenced by variables * Raise on unsupported function in limited templates * Validate MQTT trigger schema in MQTT device trigger * Add trigger_variables to automation config schema * Don't print stacktrace when setting up trigger throws * Make pylint happy * Add trigger_variables to variables * Add debug prints, document limited template * Add tests * Validate MQTT trigger topic early when possible * Improve valid_subscribe_topic_template
This commit is contained in:
parent
b9b1caf4d7
commit
047f16772f
@ -60,6 +60,7 @@ from .const import (
|
||||
CONF_ACTION,
|
||||
CONF_INITIAL_STATE,
|
||||
CONF_TRIGGER,
|
||||
CONF_TRIGGER_VARIABLES,
|
||||
DEFAULT_INITIAL_STATE,
|
||||
DOMAIN,
|
||||
LOGGER,
|
||||
@ -221,6 +222,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
||||
action_script,
|
||||
initial_state,
|
||||
variables,
|
||||
trigger_variables,
|
||||
):
|
||||
"""Initialize an automation entity."""
|
||||
self._id = automation_id
|
||||
@ -236,6 +238,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
||||
self._referenced_devices: Optional[Set[str]] = None
|
||||
self._logger = LOGGER
|
||||
self._variables: ScriptVariables = variables
|
||||
self._trigger_variables: ScriptVariables = trigger_variables
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@ -471,6 +474,16 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
||||
def log_cb(level, msg, **kwargs):
|
||||
self._logger.log(level, "%s %s", msg, self._name, **kwargs)
|
||||
|
||||
variables = None
|
||||
if self._trigger_variables:
|
||||
try:
|
||||
variables = self._trigger_variables.async_render(
|
||||
cast(HomeAssistant, self.hass), None, limited=True
|
||||
)
|
||||
except template.TemplateError as err:
|
||||
self._logger.error("Error rendering trigger variables: %s", err)
|
||||
return None
|
||||
|
||||
return await async_initialize_triggers(
|
||||
cast(HomeAssistant, self.hass),
|
||||
self._trigger_config,
|
||||
@ -479,6 +492,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
||||
self._name,
|
||||
log_cb,
|
||||
home_assistant_start,
|
||||
variables,
|
||||
)
|
||||
|
||||
@property
|
||||
@ -556,6 +570,18 @@ async def _async_process_config(
|
||||
else:
|
||||
cond_func = None
|
||||
|
||||
# Add trigger variables to variables
|
||||
variables = None
|
||||
if CONF_TRIGGER_VARIABLES in config_block:
|
||||
variables = ScriptVariables(
|
||||
dict(config_block[CONF_TRIGGER_VARIABLES].as_dict())
|
||||
)
|
||||
if CONF_VARIABLES in config_block:
|
||||
if variables:
|
||||
variables.variables.update(config_block[CONF_VARIABLES].as_dict())
|
||||
else:
|
||||
variables = config_block[CONF_VARIABLES]
|
||||
|
||||
entity = AutomationEntity(
|
||||
automation_id,
|
||||
name,
|
||||
@ -563,7 +589,8 @@ async def _async_process_config(
|
||||
cond_func,
|
||||
action_script,
|
||||
initial_state,
|
||||
config_block.get(CONF_VARIABLES),
|
||||
variables,
|
||||
config_block.get(CONF_TRIGGER_VARIABLES),
|
||||
)
|
||||
|
||||
entities.append(entity)
|
||||
|
@ -21,6 +21,7 @@ from .const import (
|
||||
CONF_HIDE_ENTITY,
|
||||
CONF_INITIAL_STATE,
|
||||
CONF_TRIGGER,
|
||||
CONF_TRIGGER_VARIABLES,
|
||||
DOMAIN,
|
||||
)
|
||||
from .helpers import async_get_blueprints
|
||||
@ -43,6 +44,7 @@ PLATFORM_SCHEMA = vol.All(
|
||||
vol.Required(CONF_TRIGGER): cv.TRIGGER_SCHEMA,
|
||||
vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA,
|
||||
vol.Optional(CONF_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA,
|
||||
vol.Optional(CONF_TRIGGER_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA,
|
||||
vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA,
|
||||
},
|
||||
script.SCRIPT_MODE_SINGLE,
|
||||
|
@ -3,6 +3,7 @@ import logging
|
||||
|
||||
CONF_ACTION = "action"
|
||||
CONF_TRIGGER = "trigger"
|
||||
CONF_TRIGGER_VARIABLES = "trigger_variables"
|
||||
DOMAIN = "automation"
|
||||
|
||||
CONF_DESCRIPTION = "description"
|
||||
|
@ -89,12 +89,14 @@ class TriggerInstance:
|
||||
async def async_attach_trigger(self):
|
||||
"""Attach MQTT trigger."""
|
||||
mqtt_config = {
|
||||
mqtt_trigger.CONF_PLATFORM: mqtt.DOMAIN,
|
||||
mqtt_trigger.CONF_TOPIC: self.trigger.topic,
|
||||
mqtt_trigger.CONF_ENCODING: DEFAULT_ENCODING,
|
||||
mqtt_trigger.CONF_QOS: self.trigger.qos,
|
||||
}
|
||||
if self.trigger.payload:
|
||||
mqtt_config[CONF_PAYLOAD] = self.trigger.payload
|
||||
mqtt_config = mqtt_trigger.TRIGGER_SCHEMA(mqtt_config)
|
||||
|
||||
if self.remove:
|
||||
self.remove()
|
||||
|
@ -1,11 +1,12 @@
|
||||
"""Offer MQTT listening automation rules."""
|
||||
import json
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import CONF_PAYLOAD, CONF_PLATFORM
|
||||
from homeassistant.core import HassJob, callback
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers import config_validation as cv, template
|
||||
|
||||
from .. import mqtt
|
||||
|
||||
@ -20,8 +21,8 @@ DEFAULT_QOS = 0
|
||||
TRIGGER_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_PLATFORM): mqtt.DOMAIN,
|
||||
vol.Required(CONF_TOPIC): mqtt.util.valid_subscribe_topic,
|
||||
vol.Optional(CONF_PAYLOAD): cv.string,
|
||||
vol.Required(CONF_TOPIC): mqtt.util.valid_subscribe_topic_template,
|
||||
vol.Optional(CONF_PAYLOAD): cv.template,
|
||||
vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): cv.string,
|
||||
vol.Optional(CONF_QOS, default=DEFAULT_QOS): vol.All(
|
||||
vol.Coerce(int), vol.In([0, 1, 2])
|
||||
@ -29,6 +30,8 @@ TRIGGER_SCHEMA = vol.Schema(
|
||||
}
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_attach_trigger(hass, config, action, automation_info):
|
||||
"""Listen for state changes based on configuration."""
|
||||
@ -37,6 +40,18 @@ async def async_attach_trigger(hass, config, action, automation_info):
|
||||
encoding = config[CONF_ENCODING] or None
|
||||
qos = config[CONF_QOS]
|
||||
job = HassJob(action)
|
||||
variables = None
|
||||
if automation_info:
|
||||
variables = automation_info.get("variables")
|
||||
|
||||
template.attach(hass, payload)
|
||||
if payload:
|
||||
payload = payload.async_render(variables, limited=True)
|
||||
|
||||
template.attach(hass, topic)
|
||||
if isinstance(topic, template.Template):
|
||||
topic = topic.async_render(variables, limited=True)
|
||||
topic = mqtt.util.valid_subscribe_topic(topic)
|
||||
|
||||
@callback
|
||||
def mqtt_automation_listener(mqttmsg):
|
||||
@ -57,6 +72,10 @@ async def async_attach_trigger(hass, config, action, automation_info):
|
||||
|
||||
hass.async_run_hass_job(job, {"trigger": data})
|
||||
|
||||
_LOGGER.debug(
|
||||
"Attaching MQTT trigger for topic: '%s', payload: '%s'", topic, payload
|
||||
)
|
||||
|
||||
remove = await mqtt.async_subscribe(
|
||||
hass, topic, mqtt_automation_listener, encoding=encoding, qos=qos
|
||||
)
|
||||
|
@ -4,7 +4,7 @@ from typing import Any
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import CONF_PAYLOAD
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers import config_validation as cv, template
|
||||
|
||||
from .const import (
|
||||
ATTR_PAYLOAD,
|
||||
@ -61,6 +61,16 @@ def valid_subscribe_topic(value: Any) -> str:
|
||||
return value
|
||||
|
||||
|
||||
def valid_subscribe_topic_template(value: Any) -> template.Template:
|
||||
"""Validate either a jinja2 template or a valid MQTT subscription topic."""
|
||||
tpl = template.Template(value)
|
||||
|
||||
if tpl.is_static:
|
||||
valid_subscribe_topic(value)
|
||||
|
||||
return tpl
|
||||
|
||||
|
||||
def valid_publish_topic(value: Any) -> str:
|
||||
"""Validate that we can publish using this MQTT topic."""
|
||||
value = valid_topic(value)
|
||||
|
@ -572,7 +572,7 @@ def dynamic_template(value: Optional[Any]) -> template_helper.Template:
|
||||
if isinstance(value, (list, dict, template_helper.Template)):
|
||||
raise vol.Invalid("template value should be a string")
|
||||
if not template_helper.is_template_string(str(value)):
|
||||
raise vol.Invalid("template value does not contain a dynmamic template")
|
||||
raise vol.Invalid("template value does not contain a dynamic template")
|
||||
|
||||
template_value = template_helper.Template(str(value)) # type: ignore
|
||||
try:
|
||||
|
@ -21,6 +21,7 @@ class ScriptVariables:
|
||||
run_variables: Optional[Mapping[str, Any]],
|
||||
*,
|
||||
render_as_defaults: bool = True,
|
||||
limited: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Render script variables.
|
||||
|
||||
@ -55,7 +56,9 @@ class ScriptVariables:
|
||||
if render_as_defaults and key in rendered_variables:
|
||||
continue
|
||||
|
||||
rendered_variables[key] = template.render_complex(value, rendered_variables)
|
||||
rendered_variables[key] = template.render_complex(
|
||||
value, rendered_variables, limited
|
||||
)
|
||||
|
||||
return rendered_variables
|
||||
|
||||
|
@ -84,7 +84,9 @@ def attach(hass: HomeAssistantType, obj: Any) -> None:
|
||||
obj.hass = hass
|
||||
|
||||
|
||||
def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
|
||||
def render_complex(
|
||||
value: Any, variables: TemplateVarsType = None, limited: bool = False
|
||||
) -> Any:
|
||||
"""Recursive template creator helper function."""
|
||||
if isinstance(value, list):
|
||||
return [render_complex(item, variables) for item in value]
|
||||
@ -94,7 +96,7 @@ def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
|
||||
for key, item in value.items()
|
||||
}
|
||||
if isinstance(value, Template):
|
||||
return value.async_render(variables)
|
||||
return value.async_render(variables, limited=limited)
|
||||
|
||||
return value
|
||||
|
||||
@ -279,6 +281,7 @@ class Template:
|
||||
"is_static",
|
||||
"_compiled_code",
|
||||
"_compiled",
|
||||
"_limited",
|
||||
)
|
||||
|
||||
def __init__(self, template, hass=None):
|
||||
@ -291,10 +294,11 @@ class Template:
|
||||
self._compiled: Optional[Template] = None
|
||||
self.hass = hass
|
||||
self.is_static = not is_template_string(template)
|
||||
self._limited = None
|
||||
|
||||
@property
|
||||
def _env(self) -> "TemplateEnvironment":
|
||||
if self.hass is None:
|
||||
if self.hass is None or self._limited:
|
||||
return _NO_HASS_ENV
|
||||
ret: Optional[TemplateEnvironment] = self.hass.data.get(_ENVIRONMENT)
|
||||
if ret is None:
|
||||
@ -315,9 +319,13 @@ class Template:
|
||||
self,
|
||||
variables: TemplateVarsType = None,
|
||||
parse_result: bool = True,
|
||||
limited: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Render given template."""
|
||||
"""Render given template.
|
||||
|
||||
If limited is True, the template is not allowed to access any function or filter depending on hass or the state machine.
|
||||
"""
|
||||
if self.is_static:
|
||||
if self.hass.config.legacy_templates or not parse_result:
|
||||
return self.template
|
||||
@ -325,7 +333,7 @@ class Template:
|
||||
|
||||
return run_callback_threadsafe(
|
||||
self.hass.loop,
|
||||
partial(self.async_render, variables, parse_result, **kwargs),
|
||||
partial(self.async_render, variables, parse_result, limited, **kwargs),
|
||||
).result()
|
||||
|
||||
@callback
|
||||
@ -333,18 +341,21 @@ class Template:
|
||||
self,
|
||||
variables: TemplateVarsType = None,
|
||||
parse_result: bool = True,
|
||||
limited: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Render given template.
|
||||
|
||||
This method must be run in the event loop.
|
||||
|
||||
If limited is True, the template is not allowed to access any function or filter depending on hass or the state machine.
|
||||
"""
|
||||
if self.is_static:
|
||||
if self.hass.config.legacy_templates or not parse_result:
|
||||
return self.template
|
||||
return self._parse_result(self.template)
|
||||
|
||||
compiled = self._compiled or self._ensure_compiled()
|
||||
compiled = self._compiled or self._ensure_compiled(limited)
|
||||
|
||||
if variables is not None:
|
||||
kwargs.update(variables)
|
||||
@ -519,12 +530,16 @@ class Template:
|
||||
)
|
||||
return value if error_value is _SENTINEL else error_value
|
||||
|
||||
def _ensure_compiled(self) -> "Template":
|
||||
def _ensure_compiled(self, limited: bool = False) -> "Template":
|
||||
"""Bind a template to a specific hass instance."""
|
||||
self.ensure_valid()
|
||||
|
||||
assert self.hass is not None, "hass variable not set on template"
|
||||
assert (
|
||||
self._limited is None or self._limited == limited
|
||||
), "can't change between limited and non limited template"
|
||||
|
||||
self._limited = limited
|
||||
env = self._env
|
||||
|
||||
self._compiled = cast(
|
||||
@ -1352,6 +1367,31 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
|
||||
self.globals["strptime"] = strptime
|
||||
self.globals["urlencode"] = urlencode
|
||||
if hass is None:
|
||||
|
||||
def unsupported(name):
|
||||
def warn_unsupported(*args, **kwargs):
|
||||
raise TemplateError(
|
||||
f"Use of '{name}' is not supported in limited templates"
|
||||
)
|
||||
|
||||
return warn_unsupported
|
||||
|
||||
hass_globals = [
|
||||
"closest",
|
||||
"distance",
|
||||
"expand",
|
||||
"is_state",
|
||||
"is_state_attr",
|
||||
"state_attr",
|
||||
"states",
|
||||
"utcnow",
|
||||
"now",
|
||||
]
|
||||
hass_filters = ["closest", "expand"]
|
||||
for glob in hass_globals:
|
||||
self.globals[glob] = unsupported(glob)
|
||||
for filt in hass_filters:
|
||||
self.filters[filt] = unsupported(filt)
|
||||
return
|
||||
|
||||
# We mark these as a context functions to ensure they get
|
||||
|
@ -8,6 +8,7 @@ import voluptuous as vol
|
||||
|
||||
from homeassistant.const import CONF_PLATFORM
|
||||
from homeassistant.core import CALLBACK_TYPE, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
|
||||
from homeassistant.loader import IntegrationNotFound, async_get_integration
|
||||
|
||||
@ -79,7 +80,9 @@ async def async_initialize_triggers(
|
||||
removes = []
|
||||
|
||||
for result in attach_results:
|
||||
if isinstance(result, Exception):
|
||||
if isinstance(result, HomeAssistantError):
|
||||
log_cb(logging.ERROR, f"Got error '{result}' when setting up triggers for")
|
||||
elif isinstance(result, Exception):
|
||||
log_cb(logging.ERROR, "Error setting up trigger", exc_info=result)
|
||||
elif result is None:
|
||||
log_cb(
|
||||
|
@ -1237,6 +1237,94 @@ async def test_automation_variables(hass, caplog):
|
||||
assert len(calls) == 3
|
||||
|
||||
|
||||
async def test_automation_trigger_variables(hass, caplog):
|
||||
"""Test automation trigger variables."""
|
||||
calls = async_mock_service(hass, "test", "automation")
|
||||
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
automation.DOMAIN,
|
||||
{
|
||||
automation.DOMAIN: [
|
||||
{
|
||||
"variables": {
|
||||
"event_type": "{{ trigger.event.event_type }}",
|
||||
},
|
||||
"trigger_variables": {
|
||||
"test_var": "defined_in_config",
|
||||
},
|
||||
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||
"action": {
|
||||
"service": "test.automation",
|
||||
"data": {
|
||||
"value": "{{ test_var }}",
|
||||
"event_type": "{{ event_type }}",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"variables": {
|
||||
"event_type": "{{ trigger.event.event_type }}",
|
||||
"test_var": "overridden_in_config",
|
||||
},
|
||||
"trigger_variables": {
|
||||
"test_var": "defined_in_config",
|
||||
},
|
||||
"trigger": {"platform": "event", "event_type": "test_event_2"},
|
||||
"action": {
|
||||
"service": "test.automation",
|
||||
"data": {
|
||||
"value": "{{ test_var }}",
|
||||
"event_type": "{{ event_type }}",
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
hass.bus.async_fire("test_event")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
assert calls[0].data["value"] == "defined_in_config"
|
||||
assert calls[0].data["event_type"] == "test_event"
|
||||
|
||||
hass.bus.async_fire("test_event_2")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 2
|
||||
assert calls[1].data["value"] == "overridden_in_config"
|
||||
assert calls[1].data["event_type"] == "test_event_2"
|
||||
|
||||
assert "Error rendering variables" not in caplog.text
|
||||
|
||||
|
||||
async def test_automation_bad_trigger_variables(hass, caplog):
|
||||
"""Test automation trigger variables accessing hass is rejected."""
|
||||
calls = async_mock_service(hass, "test", "automation")
|
||||
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
automation.DOMAIN,
|
||||
{
|
||||
automation.DOMAIN: [
|
||||
{
|
||||
"trigger_variables": {
|
||||
"test_var": "{{ states('foo.bar') }}",
|
||||
},
|
||||
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||
"action": {
|
||||
"service": "test.automation",
|
||||
},
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
hass.bus.async_fire("test_event")
|
||||
assert "Use of 'states' is not supported in limited templates" in caplog.text
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 0
|
||||
|
||||
|
||||
async def test_blueprint_automation(hass, calls):
|
||||
"""Test blueprint automation."""
|
||||
assert await async_setup_component(
|
||||
|
@ -81,6 +81,58 @@ async def test_if_fires_on_topic_and_payload_match(hass, calls):
|
||||
assert len(calls) == 1
|
||||
|
||||
|
||||
async def test_if_fires_on_templated_topic_and_payload_match(hass, calls):
|
||||
"""Test if message is fired on templated topic and payload match."""
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
automation.DOMAIN,
|
||||
{
|
||||
automation.DOMAIN: {
|
||||
"trigger": {
|
||||
"platform": "mqtt",
|
||||
"topic": "test-topic-{{ sqrt(16)|round }}",
|
||||
"payload": '{{ "foo"|regex_replace("foo", "bar") }}',
|
||||
},
|
||||
"action": {"service": "test.automation"},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async_fire_mqtt_message(hass, "test-topic-", "foo")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 0
|
||||
|
||||
async_fire_mqtt_message(hass, "test-topic-4", "foo")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 0
|
||||
|
||||
async_fire_mqtt_message(hass, "test-topic-4", "bar")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
|
||||
|
||||
async def test_non_allowed_templates(hass, calls, caplog):
|
||||
"""Test non allowed function in template."""
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
automation.DOMAIN,
|
||||
{
|
||||
automation.DOMAIN: {
|
||||
"trigger": {
|
||||
"platform": "mqtt",
|
||||
"topic": "test-topic-{{ states() }}",
|
||||
},
|
||||
"action": {"service": "test.automation"},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert (
|
||||
"Got error 'TemplateError: str: Use of 'states' is not supported in limited templates' when setting up triggers"
|
||||
in caplog.text
|
||||
)
|
||||
|
||||
|
||||
async def test_if_not_fires_on_topic_but_no_payload_match(hass, calls):
|
||||
"""Test if message is not fired on topic but no payload."""
|
||||
assert await async_setup_component(
|
||||
|
Loading…
x
Reference in New Issue
Block a user