Move mqtt from eventbus to dispatcher / add unsub for dispatcher (#6206)

* Move mqtt from eventbus to dispatcher / add unsub for dispatcher

* Fix lint

* Fix test

* Fix lint v2

* fix dispatcher_send
This commit is contained in:
Pascal Vizeli 2017-02-25 02:11:50 +01:00 committed by Paulus Schoutsen
parent d6818c7015
commit 81ca978413
7 changed files with 105 additions and 91 deletions

View File

@ -17,6 +17,8 @@ from homeassistant.bootstrap import async_prepare_setup_platform
from homeassistant.config import load_yaml_config_file
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import template, config_validation as cv
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, dispatcher_send)
from homeassistant.util.async import (
run_coroutine_threadsafe, run_callback_threadsafe)
from homeassistant.const import (
@ -31,7 +33,7 @@ DOMAIN = 'mqtt'
DATA_MQTT = 'mqtt'
SERVICE_PUBLISH = 'publish'
EVENT_MQTT_MESSAGE_RECEIVED = 'mqtt_message_received'
SIGNAL_MQTT_MESSAGE_RECEIVED = 'mqtt_message_received'
REQUIREMENTS = ['paho-mqtt==1.2']
@ -195,16 +197,15 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None):
def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS):
"""Subscribe to an MQTT topic."""
@callback
def async_mqtt_topic_subscriber(event):
def async_mqtt_topic_subscriber(dp_topic, dp_payload, dp_qos):
"""Match subscribed MQTT topic."""
if not _match_topic(topic, event.data[ATTR_TOPIC]):
if not _match_topic(topic, dp_topic):
return
hass.async_run_job(msg_callback, event.data[ATTR_TOPIC],
event.data[ATTR_PAYLOAD], event.data[ATTR_QOS])
hass.async_run_job(msg_callback, dp_topic, dp_payload, dp_qos)
async_remove = hass.bus.async_listen(
EVENT_MQTT_MESSAGE_RECEIVED, async_mqtt_topic_subscriber)
async_remove = async_dispatcher_connect(
hass, SIGNAL_MQTT_MESSAGE_RECEIVED, async_mqtt_topic_subscriber)
yield from hass.data[DATA_MQTT].async_subscribe(topic, qos)
return async_remove
@ -551,13 +552,11 @@ class MQTT(object):
"MQTT topic: %s, Payload: %s", msg.topic,
msg.payload)
else:
_LOGGER.debug("Received message on %s: %s",
msg.topic, payload)
self.hass.bus.fire(EVENT_MQTT_MESSAGE_RECEIVED, {
ATTR_TOPIC: msg.topic,
ATTR_QOS: msg.qos,
ATTR_PAYLOAD: payload,
})
_LOGGER.info("Received message on %s: %s", msg.topic, payload)
dispatcher_send(
self.hass, SIGNAL_MQTT_MESSAGE_RECEIVED, msg.topic, payload,
msg.qos
)
def _mqtt_on_unsubscribe(self, _mqttc, _userdata, mid, granted_qos):
"""Unsubscribe successful callback."""

View File

@ -19,7 +19,6 @@ from homeassistant.const import (
from homeassistant.core import EventOrigin, State
import homeassistant.helpers.config_validation as cv
from homeassistant.remote import JSONEncoder
from .mqtt import EVENT_MQTT_MESSAGE_RECEIVED
DOMAIN = "mqtt_eventstream"
DEPENDENCIES = ['mqtt']
@ -54,15 +53,6 @@ def async_setup(hass, config):
if event.event_type == EVENT_TIME_CHANGED:
return
# MQTT fires a bus event for every incoming message, also messages from
# eventstream. Disable publishing these messages to other HA instances
# and possibly creating an infinite loop if these instances publish
# back to this one.
if all([not conf.get(CONF_PUBLISH_EVENTSTREAM_RECEIVED),
event.event_type == EVENT_MQTT_MESSAGE_RECEIVED,
event.data.get('topic') == sub_topic]):
return
# Filter out the events that were triggered by publishing
# to the MQTT topic, or you will end up in an infinite loop.
if event.event_type == EVENT_CALL_SERVICE:

View File

@ -1,13 +1,24 @@
"""Helpers for hass dispatcher & internal component / platform."""
import logging
from homeassistant.core import callback
from homeassistant.util.async import run_callback_threadsafe
_LOGGER = logging.getLogger(__name__)
DATA_DISPATCHER = 'dispatcher'
def dispatcher_connect(hass, signal, target):
"""Connect a callable function to a singal."""
hass.add_job(async_dispatcher_connect, hass, signal, target)
async_unsub = run_callback_threadsafe(
hass.loop, async_dispatcher_connect, hass, signal, target).result()
def remove_dispatcher():
"""Remove signal listener."""
run_callback_threadsafe(hass.loop, async_unsub).result()
return remove_dispatcher
@callback
@ -24,6 +35,19 @@ def async_dispatcher_connect(hass, signal, target):
hass.data[DATA_DISPATCHER][signal].append(target)
@callback
def async_remove_dispatcher():
"""Remove signal listener."""
try:
hass.data[DATA_DISPATCHER][signal].remove(target)
except (KeyError, ValueError):
# KeyError is key target listener did not exist
# ValueError if listener did not exist within signal
_LOGGER.warning(
"Unable to remove unknown dispatcher %s", target)
return async_remove_dispatcher
def dispatcher_send(hass, signal, *args):
"""Send signal and data."""

View File

@ -14,6 +14,7 @@ from aiohttp import web
from homeassistant import core as ha, loader
from homeassistant.bootstrap import (
setup_component, async_prepare_setup_component)
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.entity import ToggleEntity
from homeassistant.helpers.restore_state import DATA_RESTORE_CACHE
from homeassistant.util.unit_system import METRIC_SYSTEM
@ -158,11 +159,8 @@ def mock_service(hass, domain, service):
@ha.callback
def async_fire_mqtt_message(hass, topic, payload, qos=0):
"""Fire the MQTT message."""
hass.bus.async_fire(mqtt.EVENT_MQTT_MESSAGE_RECEIVED, {
mqtt.ATTR_TOPIC: topic,
mqtt.ATTR_PAYLOAD: payload,
mqtt.ATTR_QOS: qos,
})
async_dispatcher_send(
hass, mqtt.SIGNAL_MQTT_MESSAGE_RECEIVED, topic, payload, qos)
def fire_mqtt_message(hass, topic, payload, qos=0):

View File

@ -13,6 +13,7 @@ import homeassistant.components.mqtt as mqtt
from homeassistant.const import (
EVENT_CALL_SERVICE, ATTR_DOMAIN, ATTR_SERVICE, EVENT_HOMEASSISTANT_START,
EVENT_HOMEASSISTANT_STOP)
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from tests.common import (
get_test_home_assistant, mock_mqtt_component, fire_mqtt_message, mock_coro)
@ -237,11 +238,17 @@ class TestMQTTCallbacks(unittest.TestCase):
calls = []
@callback
def record(event):
def record(topic, payload, qos):
"""Helper to record calls."""
calls.append(event)
data = {
'topic': topic,
'payload': payload,
'qos': qos,
}
calls.append(data)
self.hass.bus.listen_once(mqtt.EVENT_MQTT_MESSAGE_RECEIVED, record)
async_dispatcher_connect(
self.hass, mqtt.SIGNAL_MQTT_MESSAGE_RECEIVED, record)
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
message = MQTTMessage('test_topic', 1, 'Hello World!'.encode('utf-8'))
@ -252,9 +259,9 @@ class TestMQTTCallbacks(unittest.TestCase):
self.assertEqual(1, len(calls))
last_event = calls[0]
self.assertEqual('Hello World!', last_event.data['payload'])
self.assertEqual(message.topic, last_event.data['topic'])
self.assertEqual(message.qos, last_event.data['qos'])
self.assertEqual('Hello World!', last_event['payload'])
self.assertEqual(message.topic, last_event['topic'])
self.assertEqual(message.qos, last_event['qos'])
def test_mqtt_failed_connection_results_in_disconnect(self):
"""Test if connection failure leads to disconnect."""
@ -300,13 +307,20 @@ class TestMQTTCallbacks(unittest.TestCase):
calls = []
@callback
def record(event):
def record(topic, payload, qos):
"""Helper to record calls."""
calls.append(event)
data = {
'topic': topic,
'payload': payload,
'qos': qos,
}
calls.append(data)
async_dispatcher_connect(
self.hass, mqtt.SIGNAL_MQTT_MESSAGE_RECEIVED, record)
payload = 0x9a
topic = 'test_topic'
self.hass.bus.listen_once(mqtt.EVENT_MQTT_MESSAGE_RECEIVED, record)
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
message = MQTTMessage(topic, 1, payload)
with self.assertLogs(level='ERROR') as test_handle:

View File

@ -1,11 +1,9 @@
"""The tests for the MQTT eventstream component."""
from collections import namedtuple
import json
from unittest.mock import ANY, patch
from homeassistant.bootstrap import setup_component
import homeassistant.components.mqtt_eventstream as eventstream
import homeassistant.components.mqtt as mqtt
from homeassistant.const import EVENT_STATE_CHANGED
from homeassistant.core import State, callback
from homeassistant.remote import JSONEncoder
@ -146,45 +144,3 @@ class TestMqttEventStream(object):
self.hass.block_till_done()
assert 1 == len(calls)
@patch('homeassistant.components.mqtt.async_publish')
def test_mqtt_received_event(self, mock_pub):
"""Don't filter events from the mqtt component about received message.
Mqtt component sends an event if a message is received. Also
messages that originate from an incoming eventstream.
Broadcasting these messages result in an infinite loop if two HA
instances are crossconfigured for the same mqtt topics.
"""
SUB_TOPIC = 'from_slaves'
assert self.add_eventstream(
pub_topic='bar',
sub_topic=SUB_TOPIC)
self.hass.block_till_done()
# Reset the mock because it will have already gotten calls for the
# mqtt_eventstream state change on initialization, etc.
mock_pub.reset_mock()
# Use MQTT component message handler to simulate firing message
# received event.
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
message = MQTTMessage(
SUB_TOPIC, 1, '{"test": "Hello World!"}'.encode('utf-8'))
mqtt.MQTT._mqtt_on_message(self, None, {'hass': self.hass}, message)
self.hass.block_till_done()
# 'normal' incoming mqtt messages should be broadcasted
assert mock_pub.call_count == 0
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
message = MQTTMessage(
'test_topic', 1, '{"test": "Hello World!"}'.encode('utf-8'))
mqtt.MQTT._mqtt_on_message(self, None, {'hass': self.hass}, message)
self.hass.block_till_done()
# but event from the event stream not
assert mock_pub.call_count == 1

View File

@ -28,8 +28,6 @@ class TestHelpersDispatcher(object):
calls.append(data)
dispatcher_connect(self.hass, 'test', test_funct)
self.hass.block_till_done()
dispatcher_send(self.hass, 'test', 3)
self.hass.block_till_done()
@ -40,6 +38,47 @@ class TestHelpersDispatcher(object):
assert calls == [3, 'bla']
def test_simple_function_unsub(self):
"""Test simple function (executor) and unsub."""
calls1 = []
calls2 = []
def test_funct1(data):
"""Test function."""
calls1.append(data)
def test_funct2(data):
"""Test function."""
calls2.append(data)
dispatcher_connect(self.hass, 'test1', test_funct1)
unsub = dispatcher_connect(self.hass, 'test2', test_funct2)
dispatcher_send(self.hass, 'test1', 3)
dispatcher_send(self.hass, 'test2', 4)
self.hass.block_till_done()
assert calls1 == [3]
assert calls2 == [4]
unsub()
dispatcher_send(self.hass, 'test1', 5)
dispatcher_send(self.hass, 'test2', 6)
self.hass.block_till_done()
assert calls1 == [3, 5]
assert calls2 == [4]
# check don't kill the flow
unsub()
dispatcher_send(self.hass, 'test1', 7)
dispatcher_send(self.hass, 'test2', 8)
self.hass.block_till_done()
assert calls1 == [3, 5, 7]
assert calls2 == [4]
def test_simple_callback(self):
"""Test simple callback (async)."""
calls = []
@ -50,8 +89,6 @@ class TestHelpersDispatcher(object):
calls.append(data)
dispatcher_connect(self.hass, 'test', test_funct)
self.hass.block_till_done()
dispatcher_send(self.hass, 'test', 3)
self.hass.block_till_done()
@ -72,8 +109,6 @@ class TestHelpersDispatcher(object):
calls.append(data)
dispatcher_connect(self.hass, 'test', test_funct)
self.hass.block_till_done()
dispatcher_send(self.hass, 'test', 3)
self.hass.block_till_done()
@ -95,8 +130,6 @@ class TestHelpersDispatcher(object):
calls.append(data3)
dispatcher_connect(self.hass, 'test', test_funct)
self.hass.block_till_done()
dispatcher_send(self.hass, 'test', 3, 2, 'bla')
self.hass.block_till_done()