From 19cd29affa5690692c375a81af82288a244e9de1 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Sun, 28 Feb 2021 21:19:27 +0100 Subject: [PATCH] Fix MQTT trigger where wanted payload may be parsed as an integer (#47162) --- .../components/mqtt/device_trigger.py | 12 ++- homeassistant/components/mqtt/trigger.py | 6 +- tests/components/mqtt/test_device_trigger.py | 75 +++++++++++++++++++ tests/components/mqtt/test_trigger.py | 25 +++++++ 4 files changed, 114 insertions(+), 4 deletions(-) diff --git a/homeassistant/components/mqtt/device_trigger.py b/homeassistant/components/mqtt/device_trigger.py index 8969072553c..d6e2ee0fc65 100644 --- a/homeassistant/components/mqtt/device_trigger.py +++ b/homeassistant/components/mqtt/device_trigger.py @@ -13,6 +13,7 @@ from homeassistant.const import ( CONF_DOMAIN, CONF_PLATFORM, CONF_TYPE, + CONF_VALUE_TEMPLATE, ) from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError @@ -66,10 +67,11 @@ TRIGGER_DISCOVERY_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend( { vol.Required(CONF_AUTOMATION_TYPE): str, vol.Required(CONF_DEVICE): MQTT_ENTITY_DEVICE_INFO_SCHEMA, - vol.Required(CONF_TOPIC): mqtt.valid_subscribe_topic, vol.Optional(CONF_PAYLOAD, default=None): vol.Any(None, cv.string), - vol.Required(CONF_TYPE): cv.string, vol.Required(CONF_SUBTYPE): cv.string, + vol.Required(CONF_TOPIC): cv.string, + vol.Required(CONF_TYPE): cv.string, + vol.Optional(CONF_VALUE_TEMPLATE, default=None): vol.Any(None, cv.string), }, validate_device_has_at_least_one_identifier, ) @@ -96,6 +98,8 @@ class TriggerInstance: } if self.trigger.payload: mqtt_config[CONF_PAYLOAD] = self.trigger.payload + if self.trigger.value_template: + mqtt_config[CONF_VALUE_TEMPLATE] = self.trigger.value_template mqtt_config = mqtt_trigger.TRIGGER_SCHEMA(mqtt_config) if self.remove: @@ -121,6 +125,7 @@ class Trigger: subtype: str = attr.ib() topic: str = attr.ib() type: str = attr.ib() + value_template: str = attr.ib() trigger_instances: List[TriggerInstance] = attr.ib(factory=list) async def add_trigger(self, action, automation_info): @@ -153,6 +158,7 @@ class Trigger: self.qos = config[CONF_QOS] topic_changed = self.topic != config[CONF_TOPIC] self.topic = config[CONF_TOPIC] + self.value_template = config[CONF_VALUE_TEMPLATE] # Unsubscribe+subscribe if this trigger is in use and topic has changed # If topic is same unsubscribe+subscribe will execute in the wrong order @@ -245,6 +251,7 @@ async def async_setup_trigger(hass, config, config_entry, discovery_data): payload=config[CONF_PAYLOAD], qos=config[CONF_QOS], remove_signal=remove_signal, + value_template=config[CONF_VALUE_TEMPLATE], ) else: await hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger( @@ -325,6 +332,7 @@ async def async_attach_trigger( topic=None, payload=None, qos=None, + value_template=None, ) return await hass.data[DEVICE_TRIGGERS][discovery_id].add_trigger( action, automation_info diff --git a/homeassistant/components/mqtt/trigger.py b/homeassistant/components/mqtt/trigger.py index 459adabd418..82f7885b85d 100644 --- a/homeassistant/components/mqtt/trigger.py +++ b/homeassistant/components/mqtt/trigger.py @@ -48,11 +48,13 @@ async def async_attach_trigger(hass, config, action, automation_info): template.attach(hass, wanted_payload) if wanted_payload: - wanted_payload = wanted_payload.async_render(variables, limited=True) + wanted_payload = wanted_payload.async_render( + variables, limited=True, parse_result=False + ) template.attach(hass, topic) if isinstance(topic, template.Template): - topic = topic.async_render(variables, limited=True) + topic = topic.async_render(variables, limited=True, parse_result=False) topic = mqtt.util.valid_subscribe_topic(topic) template.attach(hass, value_template) diff --git a/tests/components/mqtt/test_device_trigger.py b/tests/components/mqtt/test_device_trigger.py index f200de6a274..210dac19e0c 100644 --- a/tests/components/mqtt/test_device_trigger.py +++ b/tests/components/mqtt/test_device_trigger.py @@ -290,6 +290,81 @@ async def test_if_fires_on_mqtt_message(hass, device_reg, calls, mqtt_mock): assert calls[1].data["some"] == "long_press" +async def test_if_fires_on_mqtt_message_template(hass, device_reg, calls, mqtt_mock): + """Test triggers firing.""" + data1 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + " \"payload\": \"{{ 'foo_press'|regex_replace('foo', 'short') }}\"," + ' "topic": "foobar/triggers/button{{ sqrt(16)|round }}",' + ' "type": "button_short_press",' + ' "subtype": "button_1",' + ' "value_template": "{{ value_json.button }}"}' + ) + data2 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + " \"payload\": \"{{ 'foo_press'|regex_replace('foo', 'long') }}\"," + ' "topic": "foobar/triggers/button{{ sqrt(16)|round }}",' + ' "type": "button_long_press",' + ' "subtype": "button_2",' + ' "value_template": "{{ value_json.button }}"}' + ) + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", data1) + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla2/config", data2) + await hass.async_block_till_done() + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}) + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: [ + { + "trigger": { + "platform": "device", + "domain": DOMAIN, + "device_id": device_entry.id, + "discovery_id": "bla1", + "type": "button_short_press", + "subtype": "button_1", + }, + "action": { + "service": "test.automation", + "data_template": {"some": ("short_press")}, + }, + }, + { + "trigger": { + "platform": "device", + "domain": DOMAIN, + "device_id": device_entry.id, + "discovery_id": "bla2", + "type": "button_1", + "subtype": "button_long_press", + }, + "action": { + "service": "test.automation", + "data_template": {"some": ("long_press")}, + }, + }, + ] + }, + ) + + # Fake short press. + async_fire_mqtt_message(hass, "foobar/triggers/button4", '{"button":"short_press"}') + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0].data["some"] == "short_press" + + # Fake long press. + async_fire_mqtt_message(hass, "foobar/triggers/button4", '{"button":"long_press"}') + await hass.async_block_till_done() + assert len(calls) == 2 + assert calls[1].data["some"] == "long_press" + + async def test_if_fires_on_mqtt_message_late_discover( hass, device_reg, calls, mqtt_mock ): diff --git a/tests/components/mqtt/test_trigger.py b/tests/components/mqtt/test_trigger.py index 23078b9ba23..d0a86e08655 100644 --- a/tests/components/mqtt/test_trigger.py +++ b/tests/components/mqtt/test_trigger.py @@ -81,6 +81,31 @@ async def test_if_fires_on_topic_and_payload_match(hass, calls): assert len(calls) == 1 +async def test_if_fires_on_topic_and_payload_match2(hass, calls): + """Test if message is fired on topic and payload match. + + Make sure a payload which would render as a non string can still be matched. + """ + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "trigger": { + "platform": "mqtt", + "topic": "test-topic", + "payload": "0", + }, + "action": {"service": "test.automation"}, + } + }, + ) + + async_fire_mqtt_message(hass, "test-topic", "0") + await hass.async_block_till_done() + assert len(calls) == 1 + + async def test_if_fires_on_templated_topic_and_payload_match(hass, calls): """Test if message is fired on templated topic and payload match.""" assert await async_setup_component(