From 5ee4718e24bf0a0e44ab4cd5f930fdb8aa400c53 Mon Sep 17 00:00:00 2001 From: emontnemery Date: Mon, 24 Sep 2018 10:11:49 +0200 Subject: [PATCH] Remove discovered MQTT Switch device when discovery topic is cleared (#16605) * Remove discovered device when discovery topic is cleared * Move entity removal away from mqtt discovery * Move discovery update to mixin class * Add testcase * Review comments --- homeassistant/components/mqtt/__init__.py | 34 +++++++++++ homeassistant/components/mqtt/discovery.py | 68 ++++++++++++---------- homeassistant/components/switch/mqtt.py | 31 ++++++---- tests/components/mqtt/test_discovery.py | 28 +++++++++ 4 files changed, 121 insertions(+), 40 deletions(-) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index bcb0d60902b..abc240a65cb 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -92,6 +92,7 @@ ATTR_PAYLOAD = 'payload' ATTR_PAYLOAD_TEMPLATE = 'payload_template' ATTR_QOS = CONF_QOS ATTR_RETAIN = CONF_RETAIN +ATTR_DISCOVERY_HASH = 'discovery_hash' MAX_RECONNECT_WAIT = 300 # seconds @@ -833,3 +834,36 @@ class MqttAvailability(Entity): def available(self) -> bool: """Return if the device is available.""" return self._available + + +class MqttDiscoveryUpdate(Entity): + """Mixin used to handle updated discovery message.""" + + def __init__(self, discovery_hash) -> None: + """Initialize the discovery update mixin.""" + self._discovery_hash = discovery_hash + self._remove_signal = None + + async def async_added_to_hass(self) -> None: + """Subscribe to discovery updates.""" + from homeassistant.helpers.dispatcher import async_dispatcher_connect + from homeassistant.components.mqtt.discovery import ( + ALREADY_DISCOVERED, MQTT_DISCOVERY_UPDATED) + + @callback + def discovery_callback(payload): + """Handle discovery update.""" + _LOGGER.info("Got update for entity with hash: %s '%s'", + self._discovery_hash, payload) + if not payload: + # Empty payload: Remove component + _LOGGER.info("Removing component: %s", self.entity_id) + self.hass.async_create_task(self.async_remove()) + del self.hass.data[ALREADY_DISCOVERED][self._discovery_hash] + self._remove_signal() + + if self._discovery_hash: + self._remove_signal = async_dispatcher_connect( + self.hass, + MQTT_DISCOVERY_UPDATED.format(self._discovery_hash), + discovery_callback) diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index 689515f64c8..f42c1ed58e9 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -9,9 +9,10 @@ import logging import re from homeassistant.components import mqtt -from homeassistant.components.mqtt import CONF_STATE_TOPIC +from homeassistant.components.mqtt import CONF_STATE_TOPIC, ATTR_DISCOVERY_HASH from homeassistant.const import CONF_PLATFORM from homeassistant.helpers.discovery import async_load_platform +from homeassistant.helpers.dispatcher import async_dispatcher_send _LOGGER = logging.getLogger(__name__) @@ -38,6 +39,7 @@ ALLOWED_PLATFORMS = { } ALREADY_DISCOVERED = 'mqtt_discovered_components' +MQTT_DISCOVERY_UPDATED = 'mqtt_discovery_updated_{}' async def async_start(hass, discovery_topic, hass_config): @@ -51,47 +53,53 @@ async def async_start(hass, discovery_topic, hass_config): _prefix_topic, component, node_id, object_id = match.groups() - try: - payload = json.loads(payload) - except ValueError: - _LOGGER.warning("Unable to parse JSON %s: %s", object_id, payload) - return - if component not in SUPPORTED_COMPONENTS: _LOGGER.warning("Component %s is not supported", component) return - payload = dict(payload) - platform = payload.get(CONF_PLATFORM, 'mqtt') - if platform not in ALLOWED_PLATFORMS.get(component, []): - _LOGGER.warning("Platform %s (component %s) is not allowed", - platform, component) - return - - payload[CONF_PLATFORM] = platform - if CONF_STATE_TOPIC not in payload: - payload[CONF_STATE_TOPIC] = '{}/{}/{}{}/state'.format( - discovery_topic, component, '%s/' % node_id if node_id else '', - object_id) - - if ALREADY_DISCOVERED not in hass.data: - hass.data[ALREADY_DISCOVERED] = set() - # If present, the node_id will be included in the discovered object id discovery_id = '_'.join((node_id, object_id)) if node_id else object_id + if ALREADY_DISCOVERED not in hass.data: + hass.data[ALREADY_DISCOVERED] = {} + discovery_hash = (component, discovery_id) + if discovery_hash in hass.data[ALREADY_DISCOVERED]: - _LOGGER.info("Component has already been discovered: %s %s", - component, discovery_id) - return + _LOGGER.info( + "Component has already been discovered: %s %s, sending update", + component, discovery_id) + async_dispatcher_send( + hass, MQTT_DISCOVERY_UPDATED.format(discovery_hash), payload) + elif payload: + # Add component + try: + payload = json.loads(payload) + except ValueError: + _LOGGER.warning("Unable to parse JSON %s: '%s'", + object_id, payload) + return - hass.data[ALREADY_DISCOVERED].add(discovery_hash) + payload = dict(payload) + platform = payload.get(CONF_PLATFORM, 'mqtt') + if platform not in ALLOWED_PLATFORMS.get(component, []): + _LOGGER.warning("Platform %s (component %s) is not allowed", + platform, component) + return - _LOGGER.info("Found new component: %s %s", component, discovery_id) + payload[CONF_PLATFORM] = platform + if CONF_STATE_TOPIC not in payload: + payload[CONF_STATE_TOPIC] = '{}/{}/{}{}/state'.format( + discovery_topic, component, + '%s/' % node_id if node_id else '', object_id) - await async_load_platform( - hass, component, platform, payload, hass_config) + hass.data[ALREADY_DISCOVERED][discovery_hash] = None + payload[ATTR_DISCOVERY_HASH] = discovery_hash + + _LOGGER.info("Found new component: %s %s", component, discovery_id) + + await async_load_platform( + hass, component, platform, payload, hass_config) await mqtt.async_subscribe( hass, discovery_topic + '/#', async_device_message_received, 0) diff --git a/homeassistant/components/switch/mqtt.py b/homeassistant/components/switch/mqtt.py index f6075d5e49f..b79f8f12b87 100644 --- a/homeassistant/components/switch/mqtt.py +++ b/homeassistant/components/switch/mqtt.py @@ -11,9 +11,10 @@ import voluptuous as vol from homeassistant.core import callback from homeassistant.components.mqtt import ( - CONF_STATE_TOPIC, CONF_COMMAND_TOPIC, CONF_AVAILABILITY_TOPIC, - CONF_PAYLOAD_AVAILABLE, CONF_PAYLOAD_NOT_AVAILABLE, CONF_QOS, CONF_RETAIN, - MqttAvailability) + ATTR_DISCOVERY_HASH, CONF_STATE_TOPIC, CONF_COMMAND_TOPIC, + CONF_AVAILABILITY_TOPIC, CONF_PAYLOAD_AVAILABLE, + CONF_PAYLOAD_NOT_AVAILABLE, CONF_QOS, CONF_RETAIN, MqttAvailability, + MqttDiscoveryUpdate) from homeassistant.components.switch import SwitchDevice from homeassistant.const import ( CONF_NAME, CONF_OPTIMISTIC, CONF_VALUE_TEMPLATE, CONF_PAYLOAD_OFF, @@ -56,7 +57,11 @@ async def async_setup_platform(hass, config, async_add_entities, if value_template is not None: value_template.hass = hass - async_add_entities([MqttSwitch( + discovery_hash = None + if discovery_info is not None and ATTR_DISCOVERY_HASH in discovery_info: + discovery_hash = discovery_info[ATTR_DISCOVERY_HASH] + + newswitch = MqttSwitch( config.get(CONF_NAME), config.get(CONF_ICON), config.get(CONF_STATE_TOPIC), @@ -73,10 +78,13 @@ async def async_setup_platform(hass, config, async_add_entities, config.get(CONF_PAYLOAD_NOT_AVAILABLE), config.get(CONF_UNIQUE_ID), value_template, - )]) + discovery_hash, + ) + + async_add_entities([newswitch]) -class MqttSwitch(MqttAvailability, SwitchDevice): +class MqttSwitch(MqttAvailability, MqttDiscoveryUpdate, SwitchDevice): """Representation of a switch that can be toggled using MQTT.""" def __init__(self, name, icon, @@ -84,10 +92,11 @@ class MqttSwitch(MqttAvailability, SwitchDevice): qos, retain, payload_on, payload_off, state_on, state_off, optimistic, payload_available, payload_not_available, unique_id: Optional[str], - value_template): + value_template, discovery_hash): """Initialize the MQTT switch.""" - super().__init__(availability_topic, qos, payload_available, - payload_not_available) + MqttAvailability.__init__(self, availability_topic, qos, + payload_available, payload_not_available) + MqttDiscoveryUpdate.__init__(self, discovery_hash) self._state = False self._name = name self._icon = icon @@ -102,10 +111,12 @@ class MqttSwitch(MqttAvailability, SwitchDevice): self._optimistic = optimistic self._template = value_template self._unique_id = unique_id + self._discovery_hash = discovery_hash async def async_added_to_hass(self): """Subscribe to MQTT events.""" - await super().async_added_to_hass() + await MqttAvailability.async_added_to_hass(self) + await MqttDiscoveryUpdate.async_added_to_hass(self) @callback def state_message_received(topic, payload, qos): diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index 9e0ef14a3fa..6de277eb48d 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -181,3 +181,31 @@ def test_non_duplicate_discovery(hass, mqtt_mock, caplog): assert state_duplicate is None assert 'Component has already been discovered: ' \ 'binary_sensor bla' in caplog.text + + +@asyncio.coroutine +def test_discovery_removal(hass, mqtt_mock, caplog): + """Test expansion of abbreviated discovery payload.""" + yield from async_start(hass, 'homeassistant', {}) + + data = ( + '{ "name": "Beer",' + ' "status_topic": "test_topic",' + ' "command_topic": "test_topic" }' + ) + + async_fire_mqtt_message(hass, 'homeassistant/switch/bla/config', + data) + yield from hass.async_block_till_done() + + state = hass.states.get('switch.beer') + assert state is not None + assert state.name == 'Beer' + + async_fire_mqtt_message(hass, 'homeassistant/switch/bla/config', + '') + yield from hass.async_block_till_done() + yield from hass.async_block_till_done() + + state = hass.states.get('switch.beer') + assert state is None