From 60ae85564e9ec0ccd89357099b6ff92f6cacecd5 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 18 Feb 2020 22:51:10 +0100 Subject: [PATCH] Add support for MQTT device triggers (#31679) * Add support for MQTT device triggers * Fix test, tweaks * Improve test coverage * Address review comments, improve tests * Tidy up exception handling * Fix abbreviations * Rewrite to handle update of attached triggers * Update abbreviation test * Refactor according to review comments * Refactor according to review comments * Improve trigger removal * Further refactoring --- homeassistant/components/automation/mqtt.py | 8 +- homeassistant/components/mqtt/__init__.py | 55 +- .../components/mqtt/abbreviations.py | 3 + homeassistant/components/mqtt/cover.py | 5 +- .../components/mqtt/device_automation.py | 44 + .../components/mqtt/device_trigger.py | 273 ++++++ homeassistant/components/mqtt/discovery.py | 14 +- homeassistant/components/mqtt/strings.json | 22 + tests/components/mqtt/test_device_trigger.py | 777 ++++++++++++++++++ tests/components/mqtt/test_discovery.py | 3 +- 10 files changed, 1170 insertions(+), 34 deletions(-) create mode 100644 homeassistant/components/mqtt/device_automation.py create mode 100644 homeassistant/components/mqtt/device_trigger.py create mode 100644 tests/components/mqtt/test_device_trigger.py diff --git a/homeassistant/components/automation/mqtt.py b/homeassistant/components/automation/mqtt.py index fb0073c78d5..046cbba2873 100644 --- a/homeassistant/components/automation/mqtt.py +++ b/homeassistant/components/automation/mqtt.py @@ -11,8 +11,10 @@ import homeassistant.helpers.config_validation as cv # mypy: allow-untyped-defs CONF_ENCODING = "encoding" +CONF_QOS = "qos" CONF_TOPIC = "topic" DEFAULT_ENCODING = "utf-8" +DEFAULT_QOS = 0 TRIGGER_SCHEMA = vol.Schema( { @@ -20,6 +22,9 @@ TRIGGER_SCHEMA = vol.Schema( vol.Required(CONF_TOPIC): mqtt.valid_subscribe_topic, vol.Optional(CONF_PAYLOAD): cv.string, vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): cv.string, + vol.Optional(CONF_QOS, default=DEFAULT_QOS): vol.All( + vol.Coerce(int), vol.In([0, 1, 2]) + ), } ) @@ -29,6 +34,7 @@ async def async_attach_trigger(hass, config, action, automation_info): topic = config[CONF_TOPIC] payload = config.get(CONF_PAYLOAD) encoding = config[CONF_ENCODING] or None + qos = config[CONF_QOS] @callback def mqtt_automation_listener(mqttmsg): @@ -49,6 +55,6 @@ async def async_attach_trigger(hass, config, action, automation_info): hass.async_run_job(action, {"trigger": data}) remove = await mqtt.async_subscribe( - hass, topic, mqtt_automation_listener, encoding=encoding + hass, topic, mqtt_automation_listener, encoding=encoding, qos=qos ) return remove diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 6780a33c7d7..540d09d7c9f 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -1194,6 +1194,34 @@ class MqttDiscoveryUpdate(Entity): ) +def device_info_from_config(config): + """Return a device description for device registry.""" + if not config: + return None + + info = { + "identifiers": {(DOMAIN, id_) for id_ in config[CONF_IDENTIFIERS]}, + "connections": {tuple(x) for x in config[CONF_CONNECTIONS]}, + } + + if CONF_MANUFACTURER in config: + info["manufacturer"] = config[CONF_MANUFACTURER] + + if CONF_MODEL in config: + info["model"] = config[CONF_MODEL] + + if CONF_NAME in config: + info["name"] = config[CONF_NAME] + + if CONF_SW_VERSION in config: + info["sw_version"] = config[CONF_SW_VERSION] + + if CONF_VIA_DEVICE in config: + info["via_device"] = (DOMAIN, config[CONF_VIA_DEVICE]) + + return info + + class MqttEntityDeviceInfo(Entity): """Mixin used for mqtt platforms that support the device registry.""" @@ -1216,32 +1244,7 @@ class MqttEntityDeviceInfo(Entity): @property def device_info(self): """Return a device description for device registry.""" - if not self._device_config: - return None - - info = { - "identifiers": { - (DOMAIN, id_) for id_ in self._device_config[CONF_IDENTIFIERS] - }, - "connections": {tuple(x) for x in self._device_config[CONF_CONNECTIONS]}, - } - - if CONF_MANUFACTURER in self._device_config: - info["manufacturer"] = self._device_config[CONF_MANUFACTURER] - - if CONF_MODEL in self._device_config: - info["model"] = self._device_config[CONF_MODEL] - - if CONF_NAME in self._device_config: - info["name"] = self._device_config[CONF_NAME] - - if CONF_SW_VERSION in self._device_config: - info["sw_version"] = self._device_config[CONF_SW_VERSION] - - if CONF_VIA_DEVICE in self._device_config: - info["via_device"] = (DOMAIN, self._device_config[CONF_VIA_DEVICE]) - - return info + return device_info_from_config(self._device_config) @websocket_api.async_response diff --git a/homeassistant/components/mqtt/abbreviations.py b/homeassistant/components/mqtt/abbreviations.py index acbc2731846..6cfab66c3f1 100644 --- a/homeassistant/components/mqtt/abbreviations.py +++ b/homeassistant/components/mqtt/abbreviations.py @@ -3,6 +3,7 @@ ABBREVIATIONS = { "act_t": "action_topic", "act_tpl": "action_template", + "atype": "automation_type", "aux_cmd_t": "aux_command_topic", "aux_stat_tpl": "aux_state_template", "aux_stat_t": "aux_state_topic", @@ -80,6 +81,7 @@ ABBREVIATIONS = { "osc_cmd_t": "oscillation_command_topic", "osc_stat_t": "oscillation_state_topic", "osc_val_tpl": "oscillation_value_template", + "pl": "payload", "pl_arm_away": "payload_arm_away", "pl_arm_home": "payload_arm_home", "pl_arm_nite": "payload_arm_night", @@ -142,6 +144,7 @@ ABBREVIATIONS = { "stat_t": "state_topic", "stat_tpl": "state_template", "stat_val_tpl": "state_value_template", + "stype": "subtype", "sup_feat": "supported_features", "swing_mode_cmd_t": "swing_mode_command_topic", "swing_mode_stat_tpl": "swing_mode_state_template", diff --git a/homeassistant/components/mqtt/cover.py b/homeassistant/components/mqtt/cover.py index 4f2f29f94fb..885343b7090 100644 --- a/homeassistant/components/mqtt/cover.py +++ b/homeassistant/components/mqtt/cover.py @@ -178,15 +178,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def async_discover(discovery_payload): """Discover and add an MQTT cover.""" + discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH) try: - discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH) config = PLATFORM_SCHEMA(discovery_payload) await _async_setup_entity( config, async_add_entities, config_entry, discovery_hash ) except Exception: - if discovery_hash: - clear_discovery_hash(hass, discovery_hash) + clear_discovery_hash(hass, discovery_hash) raise async_dispatcher_connect( diff --git a/homeassistant/components/mqtt/device_automation.py b/homeassistant/components/mqtt/device_automation.py new file mode 100644 index 00000000000..3f0889875d0 --- /dev/null +++ b/homeassistant/components/mqtt/device_automation.py @@ -0,0 +1,44 @@ +"""Provides device automations for MQTT.""" +import logging + +import voluptuous as vol + +from homeassistant.components import mqtt +from homeassistant.helpers.dispatcher import async_dispatcher_connect + +from . import ATTR_DISCOVERY_HASH, device_trigger +from .discovery import MQTT_DISCOVERY_NEW, clear_discovery_hash + +_LOGGER = logging.getLogger(__name__) + +AUTOMATION_TYPE_TRIGGER = "trigger" +AUTOMATION_TYPES = [AUTOMATION_TYPE_TRIGGER] +AUTOMATION_TYPES_SCHEMA = vol.In(AUTOMATION_TYPES) +CONF_AUTOMATION_TYPE = "automation_type" + +PLATFORM_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend( + {vol.Required(CONF_AUTOMATION_TYPE): AUTOMATION_TYPES_SCHEMA}, + extra=vol.ALLOW_EXTRA, +) + + +async def async_setup_entry(hass, config_entry): + """Set up MQTT device automation dynamically through MQTT discovery.""" + + async def async_discover(discovery_payload): + """Discover and add an MQTT device automation.""" + discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH) + try: + config = PLATFORM_SCHEMA(discovery_payload) + if config[CONF_AUTOMATION_TYPE] == AUTOMATION_TYPE_TRIGGER: + await device_trigger.async_setup_trigger( + hass, config, config_entry, discovery_hash + ) + except Exception: + if discovery_hash: + clear_discovery_hash(hass, discovery_hash) + raise + + async_dispatcher_connect( + hass, MQTT_DISCOVERY_NEW.format("device_automation", "mqtt"), async_discover + ) diff --git a/homeassistant/components/mqtt/device_trigger.py b/homeassistant/components/mqtt/device_trigger.py new file mode 100644 index 00000000000..2149024266d --- /dev/null +++ b/homeassistant/components/mqtt/device_trigger.py @@ -0,0 +1,273 @@ +"""Provides device automations for MQTT.""" +import logging +from typing import List + +import attr +import voluptuous as vol + +from homeassistant.components import mqtt +from homeassistant.components.automation import AutomationActionType +import homeassistant.components.automation.mqtt as automation_mqtt +from homeassistant.components.device_automation import TRIGGER_BASE_SCHEMA +from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM, CONF_TYPE +from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.dispatcher import async_dispatcher_connect +from homeassistant.helpers.typing import ConfigType, HomeAssistantType + +from . import ( + ATTR_DISCOVERY_HASH, + CONF_CONNECTIONS, + CONF_DEVICE, + CONF_IDENTIFIERS, + CONF_PAYLOAD, + CONF_QOS, + DOMAIN, +) +from .discovery import MQTT_DISCOVERY_UPDATED, clear_discovery_hash + +_LOGGER = logging.getLogger(__name__) + +CONF_AUTOMATION_TYPE = "automation_type" +CONF_DISCOVERY_ID = "discovery_id" +CONF_SUBTYPE = "subtype" +CONF_TOPIC = "topic" +DEFAULT_ENCODING = "utf-8" +DEVICE = "device" + +MQTT_TRIGGER_BASE = { + # Trigger when MQTT message is received + CONF_PLATFORM: DEVICE, + CONF_DOMAIN: DOMAIN, +} + +TRIGGER_SCHEMA = TRIGGER_BASE_SCHEMA.extend( + { + vol.Required(CONF_PLATFORM): DEVICE, + vol.Required(CONF_DOMAIN): DOMAIN, + vol.Required(CONF_DEVICE_ID): str, + vol.Required(CONF_DISCOVERY_ID): str, + vol.Required(CONF_TYPE): cv.string, + vol.Required(CONF_SUBTYPE): cv.string, + } +) + +TRIGGER_DISCOVERY_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend( + { + vol.Required(CONF_AUTOMATION_TYPE): str, + vol.Required(CONF_DEVICE): mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA, + vol.Required(CONF_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_PAYLOAD, default=None): vol.Any(None, cv.string), + vol.Required(CONF_TYPE): cv.string, + vol.Required(CONF_SUBTYPE): cv.string, + }, + mqtt.validate_device_has_at_least_one_identifier, +) + +DEVICE_TRIGGERS = "mqtt_device_triggers" + + +@attr.s(slots=True) +class TriggerInstance: + """Attached trigger settings.""" + + action = attr.ib(type=AutomationActionType) + automation_info = attr.ib(type=dict) + trigger = attr.ib(type="Trigger") + remove = attr.ib(type=CALLBACK_TYPE, default=None) + + async def async_attach_trigger(self): + """Attach MQTT trigger.""" + mqtt_config = { + automation_mqtt.CONF_TOPIC: self.trigger.topic, + automation_mqtt.CONF_ENCODING: DEFAULT_ENCODING, + automation_mqtt.CONF_QOS: self.trigger.qos, + } + if self.trigger.payload: + mqtt_config[CONF_PAYLOAD] = self.trigger.payload + + if self.remove: + self.remove() + self.remove = await automation_mqtt.async_attach_trigger( + self.trigger.hass, mqtt_config, self.action, self.automation_info, + ) + + +@attr.s(slots=True) +class Trigger: + """Device trigger settings.""" + + device_id = attr.ib(type=str) + hass = attr.ib(type=HomeAssistantType) + payload = attr.ib(type=str) + qos = attr.ib(type=int) + subtype = attr.ib(type=str) + topic = attr.ib(type=str) + type = attr.ib(type=str) + trigger_instances = attr.ib(type=[TriggerInstance], default=attr.Factory(list)) + + async def add_trigger(self, action, automation_info): + """Add MQTT trigger.""" + instance = TriggerInstance(action, automation_info, self) + self.trigger_instances.append(instance) + + if self.topic is not None: + # If we know about the trigger, subscribe to MQTT topic + await instance.async_attach_trigger() + + @callback + def async_remove() -> None: + """Remove trigger.""" + if instance not in self.trigger_instances: + raise HomeAssistantError("Can't remove trigger twice") + + if instance.remove: + instance.remove() + self.trigger_instances.remove(instance) + + return async_remove + + async def update_trigger(self, config): + """Update MQTT device trigger.""" + self.type = config[CONF_TYPE] + self.subtype = config[CONF_SUBTYPE] + self.topic = config[CONF_TOPIC] + self.payload = config[CONF_PAYLOAD] + self.qos = config[CONF_QOS] + + # Unsubscribe+subscribe if this trigger is in use + for trig in self.trigger_instances: + await trig.async_attach_trigger() + + def detach_trigger(self): + """Remove MQTT device trigger.""" + # Mark trigger as unknown + + self.topic = None + # Unsubscribe if this trigger is in use + for trig in self.trigger_instances: + if trig.remove: + trig.remove() + trig.remove = None + + +async def _update_device(hass, config_entry, config): + """Update device registry.""" + device_registry = await hass.helpers.device_registry.async_get_registry() + config_entry_id = config_entry.entry_id + device_info = mqtt.device_info_from_config(config[CONF_DEVICE]) + + if config_entry_id is not None and device_info is not None: + device_info["config_entry_id"] = config_entry_id + device_registry.async_get_or_create(**device_info) + + +async def async_setup_trigger(hass, config, config_entry, discovery_hash): + """Set up the MQTT device trigger.""" + config = TRIGGER_DISCOVERY_SCHEMA(config) + discovery_id = discovery_hash[1] + remove_signal = None + + async def discovery_update(payload): + """Handle discovery update.""" + _LOGGER.info( + "Got update for trigger with hash: %s '%s'", discovery_hash, payload + ) + if not payload: + # Empty payload: Remove trigger + _LOGGER.info("Removing trigger: %s", discovery_hash) + if discovery_id in hass.data[DEVICE_TRIGGERS]: + device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id] + device_trigger.detach_trigger() + clear_discovery_hash(hass, discovery_hash) + remove_signal() + else: + # Non-empty payload: Update trigger + _LOGGER.info("Updating trigger: %s", discovery_hash) + payload.pop(ATTR_DISCOVERY_HASH) + config = TRIGGER_DISCOVERY_SCHEMA(payload) + await _update_device(hass, config_entry, config) + device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id] + await device_trigger.update_trigger(config) + + remove_signal = async_dispatcher_connect( + hass, MQTT_DISCOVERY_UPDATED.format(discovery_hash), discovery_update + ) + + await _update_device(hass, config_entry, config) + + device_registry = await hass.helpers.device_registry.async_get_registry() + device = device_registry.async_get_device( + {(DOMAIN, id_) for id_ in config[CONF_DEVICE][CONF_IDENTIFIERS]}, + {tuple(x) for x in config[CONF_DEVICE][CONF_CONNECTIONS]}, + ) + + if device is None: + return + + if DEVICE_TRIGGERS not in hass.data: + hass.data[DEVICE_TRIGGERS] = {} + if discovery_id not in hass.data[DEVICE_TRIGGERS]: + hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger( + hass=hass, + device_id=device.id, + type=config[CONF_TYPE], + subtype=config[CONF_SUBTYPE], + topic=config[CONF_TOPIC], + payload=config[CONF_PAYLOAD], + qos=config[CONF_QOS], + ) + else: + await hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger(config) + + +async def async_get_triggers(hass: HomeAssistant, device_id: str) -> List[dict]: + """List device triggers for MQTT devices.""" + triggers = [] + + if DEVICE_TRIGGERS not in hass.data: + return triggers + + for discovery_id, trig in hass.data[DEVICE_TRIGGERS].items(): + if trig.device_id != device_id or trig.topic is None: + continue + + trigger = { + **MQTT_TRIGGER_BASE, + "device_id": device_id, + "type": trig.type, + "subtype": trig.subtype, + "discovery_id": discovery_id, + } + triggers.append(trigger) + + return triggers + + +async def async_attach_trigger( + hass: HomeAssistant, + config: ConfigType, + action: AutomationActionType, + automation_info: dict, +) -> CALLBACK_TYPE: + """Attach a trigger.""" + if DEVICE_TRIGGERS not in hass.data: + hass.data[DEVICE_TRIGGERS] = {} + config = TRIGGER_SCHEMA(config) + device_id = config[CONF_DEVICE_ID] + discovery_id = config[CONF_DISCOVERY_ID] + + if discovery_id not in hass.data[DEVICE_TRIGGERS]: + hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger( + hass=hass, + device_id=device_id, + type=config[CONF_TYPE], + subtype=config[CONF_SUBTYPE], + topic=None, + payload=None, + qos=None, + ) + return await hass.data[DEVICE_TRIGGERS][discovery_id].add_trigger( + action, automation_info + ) diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index f393c315793..418f648564d 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -26,6 +26,7 @@ SUPPORTED_COMPONENTS = [ "camera", "climate", "cover", + "device_automation", "fan", "light", "lock", @@ -40,6 +41,7 @@ CONFIG_ENTRY_COMPONENTS = [ "camera", "climate", "cover", + "device_automation", "fan", "light", "lock", @@ -197,9 +199,15 @@ async def async_start( config_entries_key = "{}.{}".format(component, "mqtt") async with hass.data[DATA_CONFIG_ENTRY_LOCK]: if config_entries_key not in hass.data[CONFIG_ENTRY_IS_SETUP]: - await hass.config_entries.async_forward_entry_setup( - config_entry, component - ) + if component == "device_automation": + # Local import to avoid circular dependencies + from . import device_automation + + await device_automation.async_setup_entry(hass, config_entry) + else: + await hass.config_entries.async_forward_entry_setup( + config_entry, component + ) hass.data[CONFIG_ENTRY_IS_SETUP].add(config_entries_key) async_dispatcher_send( diff --git a/homeassistant/components/mqtt/strings.json b/homeassistant/components/mqtt/strings.json index 8bacfa530bd..f0a38bcbc55 100644 --- a/homeassistant/components/mqtt/strings.json +++ b/homeassistant/components/mqtt/strings.json @@ -27,5 +27,27 @@ "error": { "cannot_connect": "Unable to connect to the broker." } + }, + "device_automation": { + "trigger_type": { + "button_short_press": "\"{subtype}\" pressed", + "button_short_release": "\"{subtype}\" released", + "button_long_press": "\"{subtype}\" continuously pressed", + "button_long_release": "\"{subtype}\" released after long press", + "button_double_press": "\"{subtype}\" double clicked", + "button_triple_press": "\"{subtype}\" triple clicked", + "button_quadruple_press": "\"{subtype}\" quadruple clicked", + "button_quintuple_press": "\"{subtype}\" quintuple clicked" + }, + "trigger_subtype": { + "turn_on": "Turn on", + "turn_off": "Turn off", + "button_1": "First button", + "button_2": "Second button", + "button_3": "Third button", + "button_4": "Fourth button", + "button_5": "Fifth button", + "button_6": "Sixth button" + } } } diff --git a/tests/components/mqtt/test_device_trigger.py b/tests/components/mqtt/test_device_trigger.py new file mode 100644 index 00000000000..c3ba6eebadd --- /dev/null +++ b/tests/components/mqtt/test_device_trigger.py @@ -0,0 +1,777 @@ +"""The tests for MQTT device triggers.""" +import json + +import pytest + +import homeassistant.components.automation as automation +from homeassistant.components.mqtt import DOMAIN +from homeassistant.components.mqtt.device_trigger import async_attach_trigger +from homeassistant.components.mqtt.discovery import async_start +from homeassistant.setup import async_setup_component + +from tests.common import ( + MockConfigEntry, + assert_lists_same, + async_fire_mqtt_message, + async_get_device_automations, + async_mock_service, + mock_device_registry, + mock_registry, +) + + +@pytest.fixture +def device_reg(hass): + """Return an empty, loaded, registry.""" + return mock_device_registry(hass) + + +@pytest.fixture +def entity_reg(hass): + """Return an empty, loaded, registry.""" + return mock_registry(hass) + + +@pytest.fixture +def calls(hass): + """Track calls to a mock service.""" + return async_mock_service(hass, "test", "automation") + + +async def test_get_triggers(hass, device_reg, entity_reg, mqtt_mock): + """Test we get the expected triggers from a discovered mqtt device.""" + config_entry = MockConfigEntry(domain=DOMAIN, data={}) + config_entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, config_entry) + + data1 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + ' "payload": "short_press",' + ' "topic": "foobar/triggers/button1",' + ' "type": "button_short_press",' + ' "subtype": "button_1" }' + ) + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla/config", data1) + await hass.async_block_till_done() + + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + expected_triggers = [ + { + "platform": "device", + "domain": DOMAIN, + "device_id": device_entry.id, + "discovery_id": "bla", + "type": "button_short_press", + "subtype": "button_1", + }, + ] + triggers = await async_get_device_automations(hass, "trigger", device_entry.id) + assert_lists_same(triggers, expected_triggers) + + +async def test_get_unknown_triggers(hass, device_reg, entity_reg, mqtt_mock): + """Test we don't get unknown triggers.""" + config_entry = MockConfigEntry(domain=DOMAIN, data={}) + config_entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, config_entry) + + # Discover a sensor (without device triggers) + data1 = ( + '{ "device":{"identifiers":["0AFFD2"]},' + ' "state_topic": "foobar/sensor",' + ' "unique_id": "unique" }' + ) + async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data1) + await hass.async_block_till_done() + + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: [ + { + "trigger": { + "platform": "device", + "domain": DOMAIN, + "device_id": device_entry.id, + "discovery_id": "bla1", + "type": "button_short_press", + "subtype": "button_1", + }, + "action": { + "service": "test.automation", + "data_template": {"some": ("short_press")}, + }, + }, + ] + }, + ) + + triggers = await async_get_device_automations(hass, "trigger", device_entry.id) + assert_lists_same(triggers, []) + + +async def test_get_non_existing_triggers(hass, device_reg, entity_reg, mqtt_mock): + """Test getting non existing triggers.""" + config_entry = MockConfigEntry(domain=DOMAIN, data={}) + config_entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, config_entry) + + # Discover a sensor (without device triggers) + data1 = ( + '{ "device":{"identifiers":["0AFFD2"]},' + ' "state_topic": "foobar/sensor",' + ' "unique_id": "unique" }' + ) + async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data1) + await hass.async_block_till_done() + + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + triggers = await async_get_device_automations(hass, "trigger", device_entry.id) + assert_lists_same(triggers, []) + + +async def test_discover_bad_triggers(hass, device_reg, entity_reg, mqtt_mock): + """Test bad discovery message.""" + config_entry = MockConfigEntry(domain=DOMAIN, data={}) + config_entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, config_entry) + + # Test sending bad data + data0 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + ' "payloads": "short_press",' + ' "topic": "foobar/triggers/button1",' + ' "type": "button_short_press",' + ' "subtype": "button_1" }' + ) + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla/config", data0) + await hass.async_block_till_done() + assert device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) is None + + # Test sending correct data + data1 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + ' "payload": "short_press",' + ' "topic": "foobar/triggers/button1",' + ' "type": "button_short_press",' + ' "subtype": "button_1" }' + ) + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla/config", data1) + await hass.async_block_till_done() + + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + expected_triggers = [ + { + "platform": "device", + "domain": DOMAIN, + "device_id": device_entry.id, + "discovery_id": "bla", + "type": "button_short_press", + "subtype": "button_1", + }, + ] + triggers = await async_get_device_automations(hass, "trigger", device_entry.id) + assert_lists_same(triggers, expected_triggers) + + +async def test_update_remove_triggers(hass, device_reg, entity_reg, mqtt_mock): + """Test triggers can be updated and removed.""" + config_entry = MockConfigEntry(domain=DOMAIN, data={}) + config_entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, config_entry) + + data1 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + ' "payload": "short_press",' + ' "topic": "foobar/triggers/button1",' + ' "type": "button_short_press",' + ' "subtype": "button_1" }' + ) + data2 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + ' "payload": "short_press",' + ' "topic": "foobar/triggers/button1",' + ' "type": "button_short_press",' + ' "subtype": "button_2" }' + ) + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla/config", data1) + await hass.async_block_till_done() + + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + expected_triggers1 = [ + { + "platform": "device", + "domain": DOMAIN, + "device_id": device_entry.id, + "discovery_id": "bla", + "type": "button_short_press", + "subtype": "button_1", + }, + ] + expected_triggers2 = [dict(expected_triggers1[0])] + expected_triggers2[0]["subtype"] = "button_2" + + triggers = await async_get_device_automations(hass, "trigger", device_entry.id) + assert_lists_same(triggers, expected_triggers1) + + # Update trigger + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla/config", data2) + await hass.async_block_till_done() + + triggers = await async_get_device_automations(hass, "trigger", device_entry.id) + assert_lists_same(triggers, expected_triggers2) + + # Remove trigger + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla/config", "") + await hass.async_block_till_done() + + triggers = await async_get_device_automations(hass, "trigger", device_entry.id) + assert_lists_same(triggers, []) + + +async def test_if_fires_on_mqtt_message(hass, device_reg, calls, mqtt_mock): + """Test triggers firing.""" + config_entry = MockConfigEntry(domain=DOMAIN, data={}) + config_entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, config_entry) + + data1 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + ' "payload": "short_press",' + ' "topic": "foobar/triggers/button1",' + ' "type": "button_short_press",' + ' "subtype": "button_1" }' + ) + data2 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + ' "payload": "long_press",' + ' "topic": "foobar/triggers/button1",' + ' "type": "button_long_press",' + ' "subtype": "button_2" }' + ) + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", data1) + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla2/config", data2) + await hass.async_block_till_done() + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: [ + { + "trigger": { + "platform": "device", + "domain": DOMAIN, + "device_id": device_entry.id, + "discovery_id": "bla1", + "type": "button_short_press", + "subtype": "button_1", + }, + "action": { + "service": "test.automation", + "data_template": {"some": ("short_press")}, + }, + }, + { + "trigger": { + "platform": "device", + "domain": DOMAIN, + "device_id": device_entry.id, + "discovery_id": "bla2", + "type": "button_1", + "subtype": "button_long_press", + }, + "action": { + "service": "test.automation", + "data_template": {"some": ("long_press")}, + }, + }, + ] + }, + ) + + # Fake short press. + async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press") + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0].data["some"] == "short_press" + + # Fake long press. + async_fire_mqtt_message(hass, "foobar/triggers/button1", "long_press") + await hass.async_block_till_done() + assert len(calls) == 2 + assert calls[1].data["some"] == "long_press" + + +async def test_if_fires_on_mqtt_message_late_discover( + hass, device_reg, calls, mqtt_mock +): + """Test triggers firing of MQTT device triggers discovered after setup.""" + config_entry = MockConfigEntry(domain=DOMAIN, data={}) + config_entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, config_entry) + + data0 = ( + '{ "device":{"identifiers":["0AFFD2"]},' + ' "state_topic": "foobar/sensor",' + ' "unique_id": "unique" }' + ) + data1 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + ' "payload": "short_press",' + ' "topic": "foobar/triggers/button1",' + ' "type": "button_short_press",' + ' "subtype": "button_1" }' + ) + data2 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + ' "payload": "long_press",' + ' "topic": "foobar/triggers/button1",' + ' "type": "button_long_press",' + ' "subtype": "button_2" }' + ) + async_fire_mqtt_message(hass, "homeassistant/sensor/bla0/config", data0) + await hass.async_block_till_done() + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: [ + { + "trigger": { + "platform": "device", + "domain": DOMAIN, + "device_id": device_entry.id, + "discovery_id": "bla1", + "type": "button_short_press", + "subtype": "button_1", + }, + "action": { + "service": "test.automation", + "data_template": {"some": ("short_press")}, + }, + }, + { + "trigger": { + "platform": "device", + "domain": DOMAIN, + "device_id": device_entry.id, + "discovery_id": "bla2", + "type": "button_1", + "subtype": "button_long_press", + }, + "action": { + "service": "test.automation", + "data_template": {"some": ("long_press")}, + }, + }, + ] + }, + ) + + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", data1) + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla2/config", data2) + await hass.async_block_till_done() + + # Fake short press. + async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press") + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0].data["some"] == "short_press" + + # Fake long press. + async_fire_mqtt_message(hass, "foobar/triggers/button1", "long_press") + await hass.async_block_till_done() + assert len(calls) == 2 + assert calls[1].data["some"] == "long_press" + + +async def test_if_fires_on_mqtt_message_after_update( + hass, device_reg, calls, mqtt_mock +): + """Test triggers firing after update.""" + config_entry = MockConfigEntry(domain=DOMAIN, data={}) + config_entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, config_entry) + + data1 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + ' "topic": "foobar/triggers/button1",' + ' "type": "button_short_press",' + ' "subtype": "button_1" }' + ) + data2 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + ' "topic": "foobar/triggers/buttonOne",' + ' "type": "button_long_press",' + ' "subtype": "button_2" }' + ) + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", data1) + await hass.async_block_till_done() + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: [ + { + "trigger": { + "platform": "device", + "domain": DOMAIN, + "device_id": device_entry.id, + "discovery_id": "bla1", + "type": "button_short_press", + "subtype": "button_1", + }, + "action": { + "service": "test.automation", + "data_template": {"some": ("short_press")}, + }, + }, + ] + }, + ) + + # Fake short press. + async_fire_mqtt_message(hass, "foobar/triggers/button1", "") + await hass.async_block_till_done() + assert len(calls) == 1 + + # Update the trigger + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", data2) + await hass.async_block_till_done() + + async_fire_mqtt_message(hass, "foobar/triggers/button1", "") + await hass.async_block_till_done() + assert len(calls) == 1 + + async_fire_mqtt_message(hass, "foobar/triggers/buttonOne", "") + await hass.async_block_till_done() + assert len(calls) == 2 + + +async def test_not_fires_on_mqtt_message_after_remove( + hass, device_reg, calls, mqtt_mock +): + """Test triggers not firing after removal.""" + config_entry = MockConfigEntry(domain=DOMAIN, data={}) + config_entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, config_entry) + + data1 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + ' "topic": "foobar/triggers/button1",' + ' "type": "button_short_press",' + ' "subtype": "button_1" }' + ) + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", data1) + await hass.async_block_till_done() + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: [ + { + "trigger": { + "platform": "device", + "domain": DOMAIN, + "device_id": device_entry.id, + "discovery_id": "bla1", + "type": "button_short_press", + "subtype": "button_1", + }, + "action": { + "service": "test.automation", + "data_template": {"some": ("short_press")}, + }, + }, + ] + }, + ) + + # Fake short press. + async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press") + await hass.async_block_till_done() + assert len(calls) == 1 + + # Remove the trigger + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", "") + await hass.async_block_till_done() + + async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press") + await hass.async_block_till_done() + assert len(calls) == 1 + + # Rediscover the trigger + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", data1) + await hass.async_block_till_done() + + async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press") + await hass.async_block_till_done() + assert len(calls) == 2 + + +async def test_attach_remove(hass, device_reg, mqtt_mock): + """Test attach and removal of trigger.""" + config_entry = MockConfigEntry(domain=DOMAIN, data={}) + config_entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, config_entry) + + data1 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + ' "payload": "short_press",' + ' "topic": "foobar/triggers/button1",' + ' "type": "button_short_press",' + ' "subtype": "button_1" }' + ) + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", data1) + await hass.async_block_till_done() + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + + calls = [] + + def callback(trigger): + calls.append(trigger["trigger"]["payload"]) + + remove = await async_attach_trigger( + hass, + { + "platform": "device", + "domain": DOMAIN, + "device_id": device_entry.id, + "discovery_id": "bla1", + "type": "button_short_press", + "subtype": "button_1", + }, + callback, + None, + ) + + # Fake short press. + async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press") + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0] == "short_press" + + # Remove the trigger + remove() + await hass.async_block_till_done() + + # Verify the triggers are no longer active + async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press") + await hass.async_block_till_done() + assert len(calls) == 1 + + +async def test_attach_remove_late(hass, device_reg, mqtt_mock): + """Test attach and removal of trigger .""" + config_entry = MockConfigEntry(domain=DOMAIN, data={}) + config_entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, config_entry) + + data0 = ( + '{ "device":{"identifiers":["0AFFD2"]},' + ' "state_topic": "foobar/sensor",' + ' "unique_id": "unique" }' + ) + data1 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + ' "payload": "short_press",' + ' "topic": "foobar/triggers/button1",' + ' "type": "button_short_press",' + ' "subtype": "button_1" }' + ) + async_fire_mqtt_message(hass, "homeassistant/sensor/bla0/config", data0) + await hass.async_block_till_done() + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + + calls = [] + + def callback(trigger): + calls.append(trigger["trigger"]["payload"]) + + remove = await async_attach_trigger( + hass, + { + "platform": "device", + "domain": DOMAIN, + "device_id": device_entry.id, + "discovery_id": "bla1", + "type": "button_short_press", + "subtype": "button_1", + }, + callback, + None, + ) + + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", data1) + await hass.async_block_till_done() + + # Fake short press. + async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press") + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0] == "short_press" + + # Remove the trigger + remove() + await hass.async_block_till_done() + + # Verify the triggers are no longer active + async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press") + await hass.async_block_till_done() + assert len(calls) == 1 + + +async def test_attach_remove_late2(hass, device_reg, mqtt_mock): + """Test attach and removal of trigger .""" + config_entry = MockConfigEntry(domain=DOMAIN, data={}) + config_entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, config_entry) + + data0 = ( + '{ "device":{"identifiers":["0AFFD2"]},' + ' "state_topic": "foobar/sensor",' + ' "unique_id": "unique" }' + ) + data1 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + ' "payload": "short_press",' + ' "topic": "foobar/triggers/button1",' + ' "type": "button_short_press",' + ' "subtype": "button_1" }' + ) + async_fire_mqtt_message(hass, "homeassistant/sensor/bla0/config", data0) + await hass.async_block_till_done() + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + + calls = [] + + def callback(trigger): + calls.append(trigger["trigger"]["payload"]) + + remove = await async_attach_trigger( + hass, + { + "platform": "device", + "domain": DOMAIN, + "device_id": device_entry.id, + "discovery_id": "bla1", + "type": "button_short_press", + "subtype": "button_1", + }, + callback, + None, + ) + + # Remove the trigger + remove() + await hass.async_block_till_done() + + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", data1) + await hass.async_block_till_done() + + # Verify the triggers are no longer active + async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press") + await hass.async_block_till_done() + assert len(calls) == 0 + + +async def test_entity_device_info_with_identifier(hass, mqtt_mock): + """Test MQTT device registry integration.""" + entry = MockConfigEntry(domain=DOMAIN) + entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, entry) + registry = await hass.helpers.device_registry.async_get_registry() + + data = json.dumps( + { + "automation_type": "trigger", + "topic": "test-topic", + "type": "foo", + "subtype": "bar", + "device": { + "identifiers": ["helloworld"], + "connections": [["mac", "02:5b:26:a8:dc:12"]], + "manufacturer": "Whatever", + "name": "Beer", + "model": "Glass", + "sw_version": "0.1-beta", + }, + } + ) + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla/config", data) + await hass.async_block_till_done() + + device = registry.async_get_device({("mqtt", "helloworld")}, set()) + assert device is not None + assert device.identifiers == {("mqtt", "helloworld")} + assert device.connections == {("mac", "02:5b:26:a8:dc:12")} + assert device.manufacturer == "Whatever" + assert device.name == "Beer" + assert device.model == "Glass" + assert device.sw_version == "0.1-beta" + + +async def test_entity_device_info_update(hass, mqtt_mock): + """Test device registry update.""" + entry = MockConfigEntry(domain=DOMAIN) + entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, entry) + registry = await hass.helpers.device_registry.async_get_registry() + + config = { + "automation_type": "trigger", + "topic": "test-topic", + "type": "foo", + "subtype": "bar", + "device": { + "identifiers": ["helloworld"], + "connections": [["mac", "02:5b:26:a8:dc:12"]], + "manufacturer": "Whatever", + "name": "Beer", + "model": "Glass", + "sw_version": "0.1-beta", + }, + } + + data = json.dumps(config) + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla/config", data) + await hass.async_block_till_done() + + device = registry.async_get_device({("mqtt", "helloworld")}, set()) + assert device is not None + assert device.name == "Beer" + + config["device"]["name"] = "Milk" + data = json.dumps(config) + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla/config", data) + await hass.async_block_till_done() + + device = registry.async_get_device({("mqtt", "helloworld")}, set()) + assert device is not None + assert device.name == "Milk" diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index e8da9b53a5e..e09b4d786a6 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -250,7 +250,7 @@ async def test_discovery_expansion(hass, mqtt_mock, caplog): ABBREVIATIONS_WHITE_LIST = [ - # MQTT client/server settings + # MQTT client/server/trigger settings "CONF_BIRTH_MESSAGE", "CONF_BROKER", "CONF_CERTIFICATE", @@ -258,6 +258,7 @@ ABBREVIATIONS_WHITE_LIST = [ "CONF_CLIENT_ID", "CONF_CLIENT_KEY", "CONF_DISCOVERY", + "CONF_DISCOVERY_ID", "CONF_DISCOVERY_PREFIX", "CONF_EMBEDDED", "CONF_KEEPALIVE",