diff --git a/homeassistant/components/mqtt/select.py b/homeassistant/components/mqtt/select.py index 3857184a330..c43593dfc4b 100644 --- a/homeassistant/components/mqtt/select.py +++ b/homeassistant/components/mqtt/select.py @@ -19,6 +19,8 @@ from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, from .debug_info import log_messages from .mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity, async_setup_entry_helper +CONF_COMMAND_TEMPLATE = "command_template" + _LOGGER = logging.getLogger(__name__) CONF_OPTIONS = "options" @@ -43,6 +45,7 @@ def validate_config(config): _PLATFORM_SCHEMA_BASE = mqtt.MQTT_RW_PLATFORM_SCHEMA.extend( { + vol.Optional(CONF_COMMAND_TEMPLATE): cv.template, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_OPTIMISTIC, default=DEFAULT_OPTIMISTIC): cv.boolean, vol.Required(CONF_OPTIONS): cv.ensure_list, @@ -110,9 +113,16 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity): self._optimistic = config[CONF_OPTIMISTIC] self._attr_options = config[CONF_OPTIONS] - value_template = self._config.get(CONF_VALUE_TEMPLATE) - if value_template is not None: - value_template.hass = self.hass + self._templates = { + CONF_COMMAND_TEMPLATE: config.get(CONF_COMMAND_TEMPLATE), + CONF_VALUE_TEMPLATE: config.get(CONF_VALUE_TEMPLATE), + } + for key, tpl in self._templates.items(): + if tpl is None: + self._templates[key] = lambda value: value + else: + tpl.hass = self.hass + self._templates[key] = tpl.async_render_with_possible_json_value async def _subscribe_topics(self): """(Re)Subscribe to topics.""" @@ -121,10 +131,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity): @log_messages(self.hass, self.entity_id) def message_received(msg): """Handle new MQTT messages.""" - payload = msg.payload - value_template = self._config.get(CONF_VALUE_TEMPLATE) - if value_template is not None: - payload = value_template.async_render_with_possible_json_value(payload) + payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload) if payload.lower() == "none": payload = None @@ -162,6 +169,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity): async def async_select_option(self, option: str) -> None: """Update the current value.""" + payload = self._templates[CONF_COMMAND_TEMPLATE](option) if self._optimistic: self._attr_current_option = option self.async_write_ha_state() @@ -169,7 +177,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity): await mqtt.async_publish( self.hass, self._config[CONF_COMMAND_TOPIC], - option, + payload, self._config[CONF_QOS], self._config[CONF_RETAIN], ) diff --git a/tests/components/mqtt/test_select.py b/tests/components/mqtt/test_select.py index 4843631f98b..731aca2e178 100644 --- a/tests/components/mqtt/test_select.py +++ b/tests/components/mqtt/test_select.py @@ -171,6 +171,50 @@ async def test_run_select_service_optimistic(hass, mqtt_mock): assert state.state == "beer" +async def test_run_select_service_optimistic_with_command_template(hass, mqtt_mock): + """Test that set_value service works in optimistic mode and with a command_template.""" + topic = "test/select" + + fake_state = ha.State("select.test", "milk") + + with patch( + "homeassistant.helpers.restore_state.RestoreEntity.async_get_last_state", + return_value=fake_state, + ): + assert await async_setup_component( + hass, + select.DOMAIN, + { + "select": { + "platform": "mqtt", + "command_topic": topic, + "name": "Test Select", + "options": ["milk", "beer"], + "command_template": '{"option": "{{ value }}"}', + } + }, + ) + await hass.async_block_till_done() + + state = hass.states.get("select.test_select") + assert state.state == "milk" + assert state.attributes.get(ATTR_ASSUMED_STATE) + + await hass.services.async_call( + SELECT_DOMAIN, + SERVICE_SELECT_OPTION, + {ATTR_ENTITY_ID: "select.test_select", ATTR_OPTION: "beer"}, + blocking=True, + ) + + mqtt_mock.async_publish.assert_called_once_with( + topic, '{"option": "beer"}', 0, False + ) + mqtt_mock.async_publish.reset_mock() + state = hass.states.get("select.test_select") + assert state.state == "beer" + + async def test_run_select_service(hass, mqtt_mock): """Test that set_value service works in non optimistic mode.""" cmd_topic = "test/select/set" @@ -206,6 +250,42 @@ async def test_run_select_service(hass, mqtt_mock): assert state.state == "beer" +async def test_run_select_service_with_command_template(hass, mqtt_mock): + """Test that set_value service works in non optimistic mode and with a command_template.""" + cmd_topic = "test/select/set" + state_topic = "test/select" + + assert await async_setup_component( + hass, + select.DOMAIN, + { + "select": { + "platform": "mqtt", + "command_topic": cmd_topic, + "state_topic": state_topic, + "name": "Test Select", + "options": ["milk", "beer"], + "command_template": '{"option": "{{ value }}"}', + } + }, + ) + await hass.async_block_till_done() + + async_fire_mqtt_message(hass, state_topic, "beer") + state = hass.states.get("select.test_select") + assert state.state == "beer" + + await hass.services.async_call( + SELECT_DOMAIN, + SERVICE_SELECT_OPTION, + {ATTR_ENTITY_ID: "select.test_select", ATTR_OPTION: "milk"}, + blocking=True, + ) + mqtt_mock.async_publish.assert_called_once_with( + cmd_topic, '{"option": "milk"}', 0, False + ) + + async def test_availability_when_connection_lost(hass, mqtt_mock): """Test availability after MQTT disconnection.""" await help_test_availability_when_connection_lost(