diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 7e93c26a887..fb9b4707ac8 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -1,6 +1,7 @@ """Support for MQTT message handling.""" from __future__ import annotations +from ast import literal_eval import asyncio from dataclasses import dataclass import datetime as dt @@ -250,6 +251,55 @@ MQTT_PUBLISH_SCHEMA = vol.All( SubscribePayloadType = Union[str, bytes] # Only bytes if encoding is None +class MqttCommandTemplate: + """Class for rendering MQTT payload with command templates.""" + + def __init__( + self, + command_template: template.Template | None, + hass: HomeAssistant, + ) -> None: + """Instantiate a command template.""" + self._attr_command_template = command_template + if command_template is None: + return + + command_template.hass = hass + + @callback + def async_render( + self, + value: PublishPayloadType = None, + variables: template.TemplateVarsType = None, + ) -> PublishPayloadType: + """Render or convert the command template with given value or variables.""" + + def _convert_outgoing_payload( + payload: PublishPayloadType, + ) -> PublishPayloadType: + """Ensure correct raw MQTT payload is passed as bytes for publishing.""" + if isinstance(payload, str): + try: + native_object = literal_eval(payload) + if isinstance(native_object, bytes): + return native_object + + except (ValueError, TypeError, SyntaxError, MemoryError): + pass + + return payload + + if self._attr_command_template is None: + return value + + values = {"value": value} + if variables is not None: + values.update(variables) + return _convert_outgoing_payload( + self._attr_command_template.async_render(values, parse_result=False) + ) + + @dataclass class MqttServiceInfo(BaseServiceInfo): """Prepared info from mqtt entries.""" @@ -295,7 +345,9 @@ async def async_publish( hass: HomeAssistant, topic: Any, payload, qos=0, retain=False ) -> None: """Publish message to an MQTT topic.""" - await hass.data[DATA_MQTT].async_publish(topic, str(payload), qos, retain) + await hass.data[DATA_MQTT].async_publish( + topic, str(payload) if not isinstance(payload, bytes) else payload, qos, retain + ) AsyncDeprecatedMessageCallbackType = Callable[ @@ -523,9 +575,9 @@ async def async_setup_entry(hass, entry): if payload_template is not None: try: - payload = template.Template(payload_template, hass).async_render( - parse_result=False - ) + payload = MqttCommandTemplate( + template.Template(payload_template), hass + ).async_render() except (template.jinja2.TemplateError, TemplateError) as exc: _LOGGER.error( "Unable to publish to %s: rendering payload template of " diff --git a/homeassistant/components/mqtt/alarm_control_panel.py b/homeassistant/components/mqtt/alarm_control_panel.py index 3c324c0789b..5076ceade65 100644 --- a/homeassistant/components/mqtt/alarm_control_panel.py +++ b/homeassistant/components/mqtt/alarm_control_panel.py @@ -34,7 +34,7 @@ import homeassistant.helpers.config_validation as cv from homeassistant.helpers.reload import async_setup_reload_service from homeassistant.helpers.typing import ConfigType -from . import PLATFORMS, subscription +from . import PLATFORMS, MqttCommandTemplate, subscription from .. import mqtt from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN from .debug_info import log_messages @@ -150,8 +150,9 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity): value_template = self._config.get(CONF_VALUE_TEMPLATE) if value_template is not None: value_template.hass = self.hass - command_template = self._config[CONF_COMMAND_TEMPLATE] - command_template.hass = self.hass + self._command_template = MqttCommandTemplate( + self._config[CONF_COMMAND_TEMPLATE], self.hass + ).async_render async def _subscribe_topics(self): """(Re)Subscribe to topics.""" @@ -306,9 +307,8 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity): async def _publish(self, code, action): """Publish via mqtt.""" - command_template = self._config[CONF_COMMAND_TEMPLATE] - values = {"action": action, "code": code} - payload = command_template.async_render(**values, parse_result=False) + variables = {"action": action, "code": code} + payload = self._command_template(None, variables=variables) await mqtt.async_publish( self.hass, self._config[CONF_COMMAND_TOPIC], diff --git a/homeassistant/components/mqtt/climate.py b/homeassistant/components/mqtt/climate.py index e1f63252495..c1b104440cb 100644 --- a/homeassistant/components/mqtt/climate.py +++ b/homeassistant/components/mqtt/climate.py @@ -52,7 +52,7 @@ import homeassistant.helpers.config_validation as cv from homeassistant.helpers.reload import async_setup_reload_service from homeassistant.helpers.typing import ConfigType -from . import MQTT_BASE_PLATFORM_SCHEMA, PLATFORMS, subscription +from . import MQTT_BASE_PLATFORM_SCHEMA, PLATFORMS, MqttCommandTemplate, subscription from .. import mqtt from .const import CONF_QOS, CONF_RETAIN, DOMAIN from .debug_info import log_messages @@ -377,11 +377,10 @@ class MqttClimate(MqttEntity, ClimateEntity): command_templates = {} for key in COMMAND_TEMPLATE_KEYS: - command_templates[key] = lambda value: value - for key in COMMAND_TEMPLATE_KEYS & config.keys(): - tpl = config[key] - command_templates[key] = tpl.async_render_with_possible_json_value - tpl.hass = self.hass + command_templates[key] = MqttCommandTemplate( + config.get(key), self.hass + ).async_render + self._command_templates = command_templates async def _subscribe_topics(self): # noqa: C901 diff --git a/homeassistant/components/mqtt/cover.py b/homeassistant/components/mqtt/cover.py index c2800cc8239..8b8a0764aca 100644 --- a/homeassistant/components/mqtt/cover.py +++ b/homeassistant/components/mqtt/cover.py @@ -36,7 +36,7 @@ import homeassistant.helpers.config_validation as cv from homeassistant.helpers.reload import async_setup_reload_service from homeassistant.helpers.typing import ConfigType -from . import PLATFORMS, subscription +from . import PLATFORMS, MqttCommandTemplate, subscription from .. import mqtt from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN from .debug_info import log_messages @@ -288,17 +288,17 @@ class MqttCover(MqttEntity, CoverEntity): if value_template is not None: value_template.hass = self.hass - set_position_template = self._config.get(CONF_SET_POSITION_TEMPLATE) - if set_position_template is not None: - set_position_template.hass = self.hass + self._set_position_template = MqttCommandTemplate( + self._config.get(CONF_SET_POSITION_TEMPLATE), self.hass + ).async_render get_position_template = self._config.get(CONF_GET_POSITION_TEMPLATE) if get_position_template is not None: get_position_template.hass = self.hass - set_tilt_template = self._config.get(CONF_TILT_COMMAND_TEMPLATE) - if set_tilt_template is not None: - set_tilt_template.hass = self.hass + self._set_tilt_template = MqttCommandTemplate( + self._config.get(CONF_TILT_COMMAND_TEMPLATE), self.hass + ).async_render tilt_status_template = self._config.get(CONF_TILT_STATUS_TEMPLATE) if tilt_status_template is not None: @@ -611,21 +611,19 @@ class MqttCover(MqttEntity, CoverEntity): async def async_set_cover_tilt_position(self, **kwargs): """Move the cover tilt to a specific position.""" - template = self._config.get(CONF_TILT_COMMAND_TEMPLATE) tilt = kwargs[ATTR_TILT_POSITION] percentage_tilt = tilt tilt = self.find_in_range_from_percent(tilt) # Handover the tilt after calculated from percent would make it more consistent with receiving templates - if template is not None: - variables = { - "tilt_position": percentage_tilt, - "entity_id": self.entity_id, - "position_open": self._config[CONF_POSITION_OPEN], - "position_closed": self._config[CONF_POSITION_CLOSED], - "tilt_min": self._config[CONF_TILT_MIN], - "tilt_max": self._config[CONF_TILT_MAX], - } - tilt = template.async_render(parse_result=False, variables=variables) + variables = { + "tilt_position": percentage_tilt, + "entity_id": self.entity_id, + "position_open": self._config.get(CONF_POSITION_OPEN), + "position_closed": self._config.get(CONF_POSITION_CLOSED), + "tilt_min": self._config.get(CONF_TILT_MIN), + "tilt_max": self._config.get(CONF_TILT_MAX), + } + tilt = self._set_tilt_template(tilt, variables=variables) await mqtt.async_publish( self.hass, @@ -641,20 +639,18 @@ class MqttCover(MqttEntity, CoverEntity): async def async_set_cover_position(self, **kwargs): """Move the cover to a specific position.""" - template = self._config.get(CONF_SET_POSITION_TEMPLATE) position = kwargs[ATTR_POSITION] percentage_position = position position = self.find_in_range_from_percent(position, COVER_PAYLOAD) - if template is not None: - variables = { - "position": percentage_position, - "entity_id": self.entity_id, - "position_open": self._config[CONF_POSITION_OPEN], - "position_closed": self._config[CONF_POSITION_CLOSED], - "tilt_min": self._config[CONF_TILT_MIN], - "tilt_max": self._config[CONF_TILT_MAX], - } - position = template.async_render(parse_result=False, variables=variables) + variables = { + "position": percentage_position, + "entity_id": self.entity_id, + "position_open": self._config[CONF_POSITION_OPEN], + "position_closed": self._config[CONF_POSITION_CLOSED], + "tilt_min": self._config[CONF_TILT_MIN], + "tilt_max": self._config[CONF_TILT_MAX], + } + position = self._set_position_template(position, variables=variables) await mqtt.async_publish( self.hass, diff --git a/homeassistant/components/mqtt/fan.py b/homeassistant/components/mqtt/fan.py index 9d0c954f3ab..aecf94bdd42 100644 --- a/homeassistant/components/mqtt/fan.py +++ b/homeassistant/components/mqtt/fan.py @@ -36,7 +36,7 @@ from homeassistant.util.percentage import ( ranged_value_to_percentage, ) -from . import PLATFORMS, subscription +from . import PLATFORMS, MqttCommandTemplate, subscription from .. import mqtt from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN from .debug_info import log_messages @@ -332,13 +332,17 @@ class MqttFan(MqttEntity, FanEntity): if self._feature_preset_mode: self._supported_features |= SUPPORT_PRESET_MODE - for tpl_dict in (self._command_templates, self._value_templates): - for key, tpl in tpl_dict.items(): - if tpl is None: - tpl_dict[key] = lambda value: value - else: - tpl.hass = self.hass - tpl_dict[key] = tpl.async_render_with_possible_json_value + for key, tpl in self._command_templates.items(): + self._command_templates[key] = MqttCommandTemplate( + tpl, self.hass + ).async_render + + for key, tpl in self._value_templates.items(): + if tpl is None: + self._value_templates[key] = lambda value: value + else: + tpl.hass = self.hass + self._value_templates[key] = tpl.async_render_with_possible_json_value async def _subscribe_topics(self): """(Re)Subscribe to topics.""" diff --git a/homeassistant/components/mqtt/humidifier.py b/homeassistant/components/mqtt/humidifier.py index a5346c0bf58..fcafa185509 100644 --- a/homeassistant/components/mqtt/humidifier.py +++ b/homeassistant/components/mqtt/humidifier.py @@ -27,7 +27,7 @@ import homeassistant.helpers.config_validation as cv from homeassistant.helpers.reload import async_setup_reload_service from homeassistant.helpers.typing import ConfigType -from . import PLATFORMS, subscription +from . import PLATFORMS, MqttCommandTemplate, subscription from .. import mqtt from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN from .debug_info import log_messages @@ -237,13 +237,17 @@ class MqttHumidifier(MqttEntity, HumidifierEntity): ) self._optimistic_mode = optimistic or self._topic[CONF_MODE_STATE_TOPIC] is None - for tpl_dict in (self._command_templates, self._value_templates): - for key, tpl in tpl_dict.items(): - if tpl is None: - tpl_dict[key] = lambda value: value - else: - tpl.hass = self.hass - tpl_dict[key] = tpl.async_render_with_possible_json_value + for key, tpl in self._command_templates.items(): + self._command_templates[key] = MqttCommandTemplate( + tpl, self.hass + ).async_render + + for key, tpl in self._value_templates.items(): + if tpl is None: + self._value_templates[key] = lambda value: value + else: + tpl.hass = self.hass + self._value_templates[key] = tpl.async_render_with_possible_json_value async def _subscribe_topics(self): """(Re)Subscribe to topics.""" diff --git a/homeassistant/components/mqtt/number.py b/homeassistant/components/mqtt/number.py index 29f6b4bad21..2c2c23606d4 100644 --- a/homeassistant/components/mqtt/number.py +++ b/homeassistant/components/mqtt/number.py @@ -25,7 +25,7 @@ from homeassistant.helpers.reload import async_setup_reload_service from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import ConfigType -from . import PLATFORMS, subscription +from . import PLATFORMS, MqttCommandTemplate, subscription from .. import mqtt from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN from .debug_info import log_messages @@ -138,15 +138,20 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity): self._optimistic = config[CONF_OPTIMISTIC] self._templates = { - CONF_COMMAND_TEMPLATE: config.get(CONF_COMMAND_TEMPLATE), + CONF_COMMAND_TEMPLATE: MqttCommandTemplate( + config.get(CONF_COMMAND_TEMPLATE), self.hass + ).async_render, CONF_VALUE_TEMPLATE: config.get(CONF_VALUE_TEMPLATE), } - for key, tpl in self._templates.items(): - if tpl is None: - self._templates[key] = lambda value: value - else: - tpl.hass = self.hass - self._templates[key] = tpl.async_render_with_possible_json_value + + value_template = self._templates[CONF_VALUE_TEMPLATE] + if value_template is None: + self._templates[CONF_VALUE_TEMPLATE] = lambda value: value + else: + value_template.hass = self.hass + self._templates[ + CONF_VALUE_TEMPLATE + ] = value_template.async_render_with_possible_json_value async def _subscribe_topics(self): """(Re)Subscribe to topics.""" diff --git a/homeassistant/components/mqtt/select.py b/homeassistant/components/mqtt/select.py index 7b374ba8955..e51800953c0 100644 --- a/homeassistant/components/mqtt/select.py +++ b/homeassistant/components/mqtt/select.py @@ -13,7 +13,7 @@ from homeassistant.helpers.reload import async_setup_reload_service from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import ConfigType -from . import PLATFORMS, subscription +from . import PLATFORMS, MqttCommandTemplate, subscription from .. import mqtt from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN from .debug_info import log_messages @@ -102,15 +102,20 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity): self._attr_options = config[CONF_OPTIONS] self._templates = { - CONF_COMMAND_TEMPLATE: config.get(CONF_COMMAND_TEMPLATE), + CONF_COMMAND_TEMPLATE: MqttCommandTemplate( + config.get(CONF_COMMAND_TEMPLATE), self.hass + ).async_render, CONF_VALUE_TEMPLATE: config.get(CONF_VALUE_TEMPLATE), } - for key, tpl in self._templates.items(): - if tpl is None: - self._templates[key] = lambda value: value - else: - tpl.hass = self.hass - self._templates[key] = tpl.async_render_with_possible_json_value + + value_template = self._templates[CONF_VALUE_TEMPLATE] + if value_template is None: + self._templates[CONF_VALUE_TEMPLATE] = lambda value: value + else: + value_template.hass = self.hass + self._templates[ + CONF_VALUE_TEMPLATE + ] = value_template.async_render_with_possible_json_value async def _subscribe_topics(self): """(Re)Subscribe to topics.""" diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index a4a3c1d6909..32528881d64 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -18,7 +18,7 @@ from homeassistant.const import ( ) from homeassistant.core import CoreState, callback from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import device_registry as dr +from homeassistant.helpers import device_registry as dr, template from homeassistant.setup import async_setup_component from homeassistant.util.dt import utcnow @@ -91,7 +91,7 @@ async def test_mqtt_disconnects_on_home_assistant_stop(hass, mqtt_mock): assert mqtt_mock.async_disconnect.called -async def test_publish_(hass, mqtt_mock): +async def test_publish(hass, mqtt_mock): """Test the publish function.""" await mqtt.async_publish(hass, "test-topic", "test-payload") await hass.async_block_till_done() @@ -137,6 +137,57 @@ async def test_publish_(hass, mqtt_mock): ) mqtt_mock.reset_mock() + # test binary pass-through + mqtt.publish( + hass, + "test-topic3", + b"\xde\xad\xbe\xef", + 0, + False, + ) + await hass.async_block_till_done() + assert mqtt_mock.async_publish.called + assert mqtt_mock.async_publish.call_args[0] == ( + "test-topic3", + b"\xde\xad\xbe\xef", + 0, + False, + ) + mqtt_mock.reset_mock() + + +async def test_convert_outgoing_payload(hass): + """Test the converting of outgoing MQTT payloads without template.""" + command_template = mqtt.MqttCommandTemplate(None, hass) + assert command_template.async_render(b"\xde\xad\xbe\xef") == b"\xde\xad\xbe\xef" + + assert ( + command_template.async_render("b'\\xde\\xad\\xbe\\xef'") + == "b'\\xde\\xad\\xbe\\xef'" + ) + + assert command_template.async_render(1234) == 1234 + + assert command_template.async_render(1234.56) == 1234.56 + + assert command_template.async_render(None) is None + + +async def test_command_template_value(hass): + """Test the rendering of MQTT command template.""" + + variables = {"id": 1234, "some_var": "beer"} + + # test rendering value + tpl = template.Template("{{ value + 1 }}", hass) + cmd_tpl = mqtt.MqttCommandTemplate(tpl, hass) + assert cmd_tpl.async_render(4321) == "4322" + + # test variables at rendering + tpl = template.Template("{{ some_var }}", hass) + cmd_tpl = mqtt.MqttCommandTemplate(tpl, hass) + assert cmd_tpl.async_render(None, variables=variables) == "beer" + async def test_service_call_without_topic_does_not_publish(hass, mqtt_mock): """Test the service call if topic is missing.""" @@ -260,6 +311,20 @@ async def test_service_call_with_template_payload_renders_template(hass, mqtt_mo ) assert mqtt_mock.async_publish.called assert mqtt_mock.async_publish.call_args[0][1] == "8" + mqtt_mock.reset_mock() + + await hass.services.async_call( + mqtt.DOMAIN, + mqtt.SERVICE_PUBLISH, + { + mqtt.ATTR_TOPIC: "test/topic", + mqtt.ATTR_PAYLOAD_TEMPLATE: "{{ (4+4) | pack('B') }}", + }, + blocking=True, + ) + assert mqtt_mock.async_publish.called + assert mqtt_mock.async_publish.call_args[0][1] == b"\x08" + mqtt_mock.reset_mock() async def test_service_call_with_bad_template(hass, mqtt_mock):