Remove discovered MQTT Switch device when discovery topic is cleared (#16605)

* Remove discovered device when discovery topic is cleared

* Move entity removal away from mqtt discovery

* Move discovery update to mixin class

* Add testcase

* Review comments
This commit is contained in:
emontnemery 2018-09-24 10:11:49 +02:00 committed by Fabian Affolter
parent a5cb4e6c2b
commit 5ee4718e24
4 changed files with 121 additions and 40 deletions

View File

@ -92,6 +92,7 @@ ATTR_PAYLOAD = 'payload'
ATTR_PAYLOAD_TEMPLATE = 'payload_template'
ATTR_QOS = CONF_QOS
ATTR_RETAIN = CONF_RETAIN
ATTR_DISCOVERY_HASH = 'discovery_hash'
MAX_RECONNECT_WAIT = 300 # seconds
@ -833,3 +834,36 @@ class MqttAvailability(Entity):
def available(self) -> bool:
"""Return if the device is available."""
return self._available
class MqttDiscoveryUpdate(Entity):
"""Mixin used to handle updated discovery message."""
def __init__(self, discovery_hash) -> None:
"""Initialize the discovery update mixin."""
self._discovery_hash = discovery_hash
self._remove_signal = None
async def async_added_to_hass(self) -> None:
"""Subscribe to discovery updates."""
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.components.mqtt.discovery import (
ALREADY_DISCOVERED, MQTT_DISCOVERY_UPDATED)
@callback
def discovery_callback(payload):
"""Handle discovery update."""
_LOGGER.info("Got update for entity with hash: %s '%s'",
self._discovery_hash, payload)
if not payload:
# Empty payload: Remove component
_LOGGER.info("Removing component: %s", self.entity_id)
self.hass.async_create_task(self.async_remove())
del self.hass.data[ALREADY_DISCOVERED][self._discovery_hash]
self._remove_signal()
if self._discovery_hash:
self._remove_signal = async_dispatcher_connect(
self.hass,
MQTT_DISCOVERY_UPDATED.format(self._discovery_hash),
discovery_callback)

View File

@ -9,9 +9,10 @@ import logging
import re
from homeassistant.components import mqtt
from homeassistant.components.mqtt import CONF_STATE_TOPIC
from homeassistant.components.mqtt import CONF_STATE_TOPIC, ATTR_DISCOVERY_HASH
from homeassistant.const import CONF_PLATFORM
from homeassistant.helpers.discovery import async_load_platform
from homeassistant.helpers.dispatcher import async_dispatcher_send
_LOGGER = logging.getLogger(__name__)
@ -38,6 +39,7 @@ ALLOWED_PLATFORMS = {
}
ALREADY_DISCOVERED = 'mqtt_discovered_components'
MQTT_DISCOVERY_UPDATED = 'mqtt_discovery_updated_{}'
async def async_start(hass, discovery_topic, hass_config):
@ -51,47 +53,53 @@ async def async_start(hass, discovery_topic, hass_config):
_prefix_topic, component, node_id, object_id = match.groups()
try:
payload = json.loads(payload)
except ValueError:
_LOGGER.warning("Unable to parse JSON %s: %s", object_id, payload)
return
if component not in SUPPORTED_COMPONENTS:
_LOGGER.warning("Component %s is not supported", component)
return
payload = dict(payload)
platform = payload.get(CONF_PLATFORM, 'mqtt')
if platform not in ALLOWED_PLATFORMS.get(component, []):
_LOGGER.warning("Platform %s (component %s) is not allowed",
platform, component)
return
payload[CONF_PLATFORM] = platform
if CONF_STATE_TOPIC not in payload:
payload[CONF_STATE_TOPIC] = '{}/{}/{}{}/state'.format(
discovery_topic, component, '%s/' % node_id if node_id else '',
object_id)
if ALREADY_DISCOVERED not in hass.data:
hass.data[ALREADY_DISCOVERED] = set()
# If present, the node_id will be included in the discovered object id
discovery_id = '_'.join((node_id, object_id)) if node_id else object_id
if ALREADY_DISCOVERED not in hass.data:
hass.data[ALREADY_DISCOVERED] = {}
discovery_hash = (component, discovery_id)
if discovery_hash in hass.data[ALREADY_DISCOVERED]:
_LOGGER.info("Component has already been discovered: %s %s",
component, discovery_id)
return
_LOGGER.info(
"Component has already been discovered: %s %s, sending update",
component, discovery_id)
async_dispatcher_send(
hass, MQTT_DISCOVERY_UPDATED.format(discovery_hash), payload)
elif payload:
# Add component
try:
payload = json.loads(payload)
except ValueError:
_LOGGER.warning("Unable to parse JSON %s: '%s'",
object_id, payload)
return
hass.data[ALREADY_DISCOVERED].add(discovery_hash)
payload = dict(payload)
platform = payload.get(CONF_PLATFORM, 'mqtt')
if platform not in ALLOWED_PLATFORMS.get(component, []):
_LOGGER.warning("Platform %s (component %s) is not allowed",
platform, component)
return
_LOGGER.info("Found new component: %s %s", component, discovery_id)
payload[CONF_PLATFORM] = platform
if CONF_STATE_TOPIC not in payload:
payload[CONF_STATE_TOPIC] = '{}/{}/{}{}/state'.format(
discovery_topic, component,
'%s/' % node_id if node_id else '', object_id)
await async_load_platform(
hass, component, platform, payload, hass_config)
hass.data[ALREADY_DISCOVERED][discovery_hash] = None
payload[ATTR_DISCOVERY_HASH] = discovery_hash
_LOGGER.info("Found new component: %s %s", component, discovery_id)
await async_load_platform(
hass, component, platform, payload, hass_config)
await mqtt.async_subscribe(
hass, discovery_topic + '/#', async_device_message_received, 0)

View File

@ -11,9 +11,10 @@ import voluptuous as vol
from homeassistant.core import callback
from homeassistant.components.mqtt import (
CONF_STATE_TOPIC, CONF_COMMAND_TOPIC, CONF_AVAILABILITY_TOPIC,
CONF_PAYLOAD_AVAILABLE, CONF_PAYLOAD_NOT_AVAILABLE, CONF_QOS, CONF_RETAIN,
MqttAvailability)
ATTR_DISCOVERY_HASH, CONF_STATE_TOPIC, CONF_COMMAND_TOPIC,
CONF_AVAILABILITY_TOPIC, CONF_PAYLOAD_AVAILABLE,
CONF_PAYLOAD_NOT_AVAILABLE, CONF_QOS, CONF_RETAIN, MqttAvailability,
MqttDiscoveryUpdate)
from homeassistant.components.switch import SwitchDevice
from homeassistant.const import (
CONF_NAME, CONF_OPTIMISTIC, CONF_VALUE_TEMPLATE, CONF_PAYLOAD_OFF,
@ -56,7 +57,11 @@ async def async_setup_platform(hass, config, async_add_entities,
if value_template is not None:
value_template.hass = hass
async_add_entities([MqttSwitch(
discovery_hash = None
if discovery_info is not None and ATTR_DISCOVERY_HASH in discovery_info:
discovery_hash = discovery_info[ATTR_DISCOVERY_HASH]
newswitch = MqttSwitch(
config.get(CONF_NAME),
config.get(CONF_ICON),
config.get(CONF_STATE_TOPIC),
@ -73,10 +78,13 @@ async def async_setup_platform(hass, config, async_add_entities,
config.get(CONF_PAYLOAD_NOT_AVAILABLE),
config.get(CONF_UNIQUE_ID),
value_template,
)])
discovery_hash,
)
async_add_entities([newswitch])
class MqttSwitch(MqttAvailability, SwitchDevice):
class MqttSwitch(MqttAvailability, MqttDiscoveryUpdate, SwitchDevice):
"""Representation of a switch that can be toggled using MQTT."""
def __init__(self, name, icon,
@ -84,10 +92,11 @@ class MqttSwitch(MqttAvailability, SwitchDevice):
qos, retain, payload_on, payload_off, state_on,
state_off, optimistic, payload_available,
payload_not_available, unique_id: Optional[str],
value_template):
value_template, discovery_hash):
"""Initialize the MQTT switch."""
super().__init__(availability_topic, qos, payload_available,
payload_not_available)
MqttAvailability.__init__(self, availability_topic, qos,
payload_available, payload_not_available)
MqttDiscoveryUpdate.__init__(self, discovery_hash)
self._state = False
self._name = name
self._icon = icon
@ -102,10 +111,12 @@ class MqttSwitch(MqttAvailability, SwitchDevice):
self._optimistic = optimistic
self._template = value_template
self._unique_id = unique_id
self._discovery_hash = discovery_hash
async def async_added_to_hass(self):
"""Subscribe to MQTT events."""
await super().async_added_to_hass()
await MqttAvailability.async_added_to_hass(self)
await MqttDiscoveryUpdate.async_added_to_hass(self)
@callback
def state_message_received(topic, payload, qos):

View File

@ -181,3 +181,31 @@ def test_non_duplicate_discovery(hass, mqtt_mock, caplog):
assert state_duplicate is None
assert 'Component has already been discovered: ' \
'binary_sensor bla' in caplog.text
@asyncio.coroutine
def test_discovery_removal(hass, mqtt_mock, caplog):
"""Test expansion of abbreviated discovery payload."""
yield from async_start(hass, 'homeassistant', {})
data = (
'{ "name": "Beer",'
' "status_topic": "test_topic",'
' "command_topic": "test_topic" }'
)
async_fire_mqtt_message(hass, 'homeassistant/switch/bla/config',
data)
yield from hass.async_block_till_done()
state = hass.states.get('switch.beer')
assert state is not None
assert state.name == 'Beer'
async_fire_mqtt_message(hass, 'homeassistant/switch/bla/config',
'')
yield from hass.async_block_till_done()
yield from hass.async_block_till_done()
state = hass.states.get('switch.beer')
assert state is None