From 81ca978413186d84b80fa9087967aefe82908c65 Mon Sep 17 00:00:00 2001 From: Pascal Vizeli Date: Sat, 25 Feb 2017 02:11:50 +0100 Subject: [PATCH] Move mqtt from eventbus to dispatcher / add unsub for dispatcher (#6206) * Move mqtt from eventbus to dispatcher / add unsub for dispatcher * Fix lint * Fix test * Fix lint v2 * fix dispatcher_send --- homeassistant/components/mqtt/__init__.py | 27 ++++++----- homeassistant/components/mqtt_eventstream.py | 10 ---- homeassistant/helpers/dispatcher.py | 26 ++++++++++- tests/common.py | 8 ++-- tests/components/mqtt/test_init.py | 32 +++++++++---- tests/components/test_mqtt_eventstream.py | 44 ------------------ tests/helpers/test_dispatcher.py | 49 ++++++++++++++++---- 7 files changed, 105 insertions(+), 91 deletions(-) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 57ea0351168..78311623258 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -17,6 +17,8 @@ from homeassistant.bootstrap import async_prepare_setup_platform from homeassistant.config import load_yaml_config_file from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import template, config_validation as cv +from homeassistant.helpers.dispatcher import ( + async_dispatcher_connect, dispatcher_send) from homeassistant.util.async import ( run_coroutine_threadsafe, run_callback_threadsafe) from homeassistant.const import ( @@ -31,7 +33,7 @@ DOMAIN = 'mqtt' DATA_MQTT = 'mqtt' SERVICE_PUBLISH = 'publish' -EVENT_MQTT_MESSAGE_RECEIVED = 'mqtt_message_received' +SIGNAL_MQTT_MESSAGE_RECEIVED = 'mqtt_message_received' REQUIREMENTS = ['paho-mqtt==1.2'] @@ -195,16 +197,15 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None): def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS): """Subscribe to an MQTT topic.""" @callback - def async_mqtt_topic_subscriber(event): + def async_mqtt_topic_subscriber(dp_topic, dp_payload, dp_qos): """Match subscribed MQTT topic.""" - if not _match_topic(topic, event.data[ATTR_TOPIC]): + if not _match_topic(topic, dp_topic): return - hass.async_run_job(msg_callback, event.data[ATTR_TOPIC], - event.data[ATTR_PAYLOAD], event.data[ATTR_QOS]) + hass.async_run_job(msg_callback, dp_topic, dp_payload, dp_qos) - async_remove = hass.bus.async_listen( - EVENT_MQTT_MESSAGE_RECEIVED, async_mqtt_topic_subscriber) + async_remove = async_dispatcher_connect( + hass, SIGNAL_MQTT_MESSAGE_RECEIVED, async_mqtt_topic_subscriber) yield from hass.data[DATA_MQTT].async_subscribe(topic, qos) return async_remove @@ -551,13 +552,11 @@ class MQTT(object): "MQTT topic: %s, Payload: %s", msg.topic, msg.payload) else: - _LOGGER.debug("Received message on %s: %s", - msg.topic, payload) - self.hass.bus.fire(EVENT_MQTT_MESSAGE_RECEIVED, { - ATTR_TOPIC: msg.topic, - ATTR_QOS: msg.qos, - ATTR_PAYLOAD: payload, - }) + _LOGGER.info("Received message on %s: %s", msg.topic, payload) + dispatcher_send( + self.hass, SIGNAL_MQTT_MESSAGE_RECEIVED, msg.topic, payload, + msg.qos + ) def _mqtt_on_unsubscribe(self, _mqttc, _userdata, mid, granted_qos): """Unsubscribe successful callback.""" diff --git a/homeassistant/components/mqtt_eventstream.py b/homeassistant/components/mqtt_eventstream.py index c4a4b7bc4ab..bd149b6397d 100644 --- a/homeassistant/components/mqtt_eventstream.py +++ b/homeassistant/components/mqtt_eventstream.py @@ -19,7 +19,6 @@ from homeassistant.const import ( from homeassistant.core import EventOrigin, State import homeassistant.helpers.config_validation as cv from homeassistant.remote import JSONEncoder -from .mqtt import EVENT_MQTT_MESSAGE_RECEIVED DOMAIN = "mqtt_eventstream" DEPENDENCIES = ['mqtt'] @@ -54,15 +53,6 @@ def async_setup(hass, config): if event.event_type == EVENT_TIME_CHANGED: return - # MQTT fires a bus event for every incoming message, also messages from - # eventstream. Disable publishing these messages to other HA instances - # and possibly creating an infinite loop if these instances publish - # back to this one. - if all([not conf.get(CONF_PUBLISH_EVENTSTREAM_RECEIVED), - event.event_type == EVENT_MQTT_MESSAGE_RECEIVED, - event.data.get('topic') == sub_topic]): - return - # Filter out the events that were triggered by publishing # to the MQTT topic, or you will end up in an infinite loop. if event.event_type == EVENT_CALL_SERVICE: diff --git a/homeassistant/helpers/dispatcher.py b/homeassistant/helpers/dispatcher.py index 324d4ccc621..3a1d7d075aa 100644 --- a/homeassistant/helpers/dispatcher.py +++ b/homeassistant/helpers/dispatcher.py @@ -1,13 +1,24 @@ """Helpers for hass dispatcher & internal component / platform.""" +import logging from homeassistant.core import callback +from homeassistant.util.async import run_callback_threadsafe + +_LOGGER = logging.getLogger(__name__) DATA_DISPATCHER = 'dispatcher' def dispatcher_connect(hass, signal, target): """Connect a callable function to a singal.""" - hass.add_job(async_dispatcher_connect, hass, signal, target) + async_unsub = run_callback_threadsafe( + hass.loop, async_dispatcher_connect, hass, signal, target).result() + + def remove_dispatcher(): + """Remove signal listener.""" + run_callback_threadsafe(hass.loop, async_unsub).result() + + return remove_dispatcher @callback @@ -24,6 +35,19 @@ def async_dispatcher_connect(hass, signal, target): hass.data[DATA_DISPATCHER][signal].append(target) + @callback + def async_remove_dispatcher(): + """Remove signal listener.""" + try: + hass.data[DATA_DISPATCHER][signal].remove(target) + except (KeyError, ValueError): + # KeyError is key target listener did not exist + # ValueError if listener did not exist within signal + _LOGGER.warning( + "Unable to remove unknown dispatcher %s", target) + + return async_remove_dispatcher + def dispatcher_send(hass, signal, *args): """Send signal and data.""" diff --git a/tests/common.py b/tests/common.py index 762531752ca..82623dd0e2d 100644 --- a/tests/common.py +++ b/tests/common.py @@ -14,6 +14,7 @@ from aiohttp import web from homeassistant import core as ha, loader from homeassistant.bootstrap import ( setup_component, async_prepare_setup_component) +from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.restore_state import DATA_RESTORE_CACHE from homeassistant.util.unit_system import METRIC_SYSTEM @@ -158,11 +159,8 @@ def mock_service(hass, domain, service): @ha.callback def async_fire_mqtt_message(hass, topic, payload, qos=0): """Fire the MQTT message.""" - hass.bus.async_fire(mqtt.EVENT_MQTT_MESSAGE_RECEIVED, { - mqtt.ATTR_TOPIC: topic, - mqtt.ATTR_PAYLOAD: payload, - mqtt.ATTR_QOS: qos, - }) + async_dispatcher_send( + hass, mqtt.SIGNAL_MQTT_MESSAGE_RECEIVED, topic, payload, qos) def fire_mqtt_message(hass, topic, payload, qos=0): diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 18510dd2ff3..255d5f6a96c 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -13,6 +13,7 @@ import homeassistant.components.mqtt as mqtt from homeassistant.const import ( EVENT_CALL_SERVICE, ATTR_DOMAIN, ATTR_SERVICE, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP) +from homeassistant.helpers.dispatcher import async_dispatcher_connect from tests.common import ( get_test_home_assistant, mock_mqtt_component, fire_mqtt_message, mock_coro) @@ -237,11 +238,17 @@ class TestMQTTCallbacks(unittest.TestCase): calls = [] @callback - def record(event): + def record(topic, payload, qos): """Helper to record calls.""" - calls.append(event) + data = { + 'topic': topic, + 'payload': payload, + 'qos': qos, + } + calls.append(data) - self.hass.bus.listen_once(mqtt.EVENT_MQTT_MESSAGE_RECEIVED, record) + async_dispatcher_connect( + self.hass, mqtt.SIGNAL_MQTT_MESSAGE_RECEIVED, record) MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload']) message = MQTTMessage('test_topic', 1, 'Hello World!'.encode('utf-8')) @@ -252,9 +259,9 @@ class TestMQTTCallbacks(unittest.TestCase): self.assertEqual(1, len(calls)) last_event = calls[0] - self.assertEqual('Hello World!', last_event.data['payload']) - self.assertEqual(message.topic, last_event.data['topic']) - self.assertEqual(message.qos, last_event.data['qos']) + self.assertEqual('Hello World!', last_event['payload']) + self.assertEqual(message.topic, last_event['topic']) + self.assertEqual(message.qos, last_event['qos']) def test_mqtt_failed_connection_results_in_disconnect(self): """Test if connection failure leads to disconnect.""" @@ -300,13 +307,20 @@ class TestMQTTCallbacks(unittest.TestCase): calls = [] @callback - def record(event): + def record(topic, payload, qos): """Helper to record calls.""" - calls.append(event) + data = { + 'topic': topic, + 'payload': payload, + 'qos': qos, + } + calls.append(data) + + async_dispatcher_connect( + self.hass, mqtt.SIGNAL_MQTT_MESSAGE_RECEIVED, record) payload = 0x9a topic = 'test_topic' - self.hass.bus.listen_once(mqtt.EVENT_MQTT_MESSAGE_RECEIVED, record) MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload']) message = MQTTMessage(topic, 1, payload) with self.assertLogs(level='ERROR') as test_handle: diff --git a/tests/components/test_mqtt_eventstream.py b/tests/components/test_mqtt_eventstream.py index c4e7f7fd673..dd08904a8e1 100644 --- a/tests/components/test_mqtt_eventstream.py +++ b/tests/components/test_mqtt_eventstream.py @@ -1,11 +1,9 @@ """The tests for the MQTT eventstream component.""" -from collections import namedtuple import json from unittest.mock import ANY, patch from homeassistant.bootstrap import setup_component import homeassistant.components.mqtt_eventstream as eventstream -import homeassistant.components.mqtt as mqtt from homeassistant.const import EVENT_STATE_CHANGED from homeassistant.core import State, callback from homeassistant.remote import JSONEncoder @@ -146,45 +144,3 @@ class TestMqttEventStream(object): self.hass.block_till_done() assert 1 == len(calls) - - @patch('homeassistant.components.mqtt.async_publish') - def test_mqtt_received_event(self, mock_pub): - """Don't filter events from the mqtt component about received message. - - Mqtt component sends an event if a message is received. Also - messages that originate from an incoming eventstream. - Broadcasting these messages result in an infinite loop if two HA - instances are crossconfigured for the same mqtt topics. - - """ - SUB_TOPIC = 'from_slaves' - assert self.add_eventstream( - pub_topic='bar', - sub_topic=SUB_TOPIC) - self.hass.block_till_done() - - # Reset the mock because it will have already gotten calls for the - # mqtt_eventstream state change on initialization, etc. - mock_pub.reset_mock() - - # Use MQTT component message handler to simulate firing message - # received event. - MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload']) - message = MQTTMessage( - SUB_TOPIC, 1, '{"test": "Hello World!"}'.encode('utf-8')) - mqtt.MQTT._mqtt_on_message(self, None, {'hass': self.hass}, message) - - self.hass.block_till_done() - - # 'normal' incoming mqtt messages should be broadcasted - assert mock_pub.call_count == 0 - - MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload']) - message = MQTTMessage( - 'test_topic', 1, '{"test": "Hello World!"}'.encode('utf-8')) - mqtt.MQTT._mqtt_on_message(self, None, {'hass': self.hass}, message) - - self.hass.block_till_done() - - # but event from the event stream not - assert mock_pub.call_count == 1 diff --git a/tests/helpers/test_dispatcher.py b/tests/helpers/test_dispatcher.py index fbac0689ff1..066e7386c6e 100644 --- a/tests/helpers/test_dispatcher.py +++ b/tests/helpers/test_dispatcher.py @@ -28,8 +28,6 @@ class TestHelpersDispatcher(object): calls.append(data) dispatcher_connect(self.hass, 'test', test_funct) - self.hass.block_till_done() - dispatcher_send(self.hass, 'test', 3) self.hass.block_till_done() @@ -40,6 +38,47 @@ class TestHelpersDispatcher(object): assert calls == [3, 'bla'] + def test_simple_function_unsub(self): + """Test simple function (executor) and unsub.""" + calls1 = [] + calls2 = [] + + def test_funct1(data): + """Test function.""" + calls1.append(data) + + def test_funct2(data): + """Test function.""" + calls2.append(data) + + dispatcher_connect(self.hass, 'test1', test_funct1) + unsub = dispatcher_connect(self.hass, 'test2', test_funct2) + dispatcher_send(self.hass, 'test1', 3) + dispatcher_send(self.hass, 'test2', 4) + self.hass.block_till_done() + + assert calls1 == [3] + assert calls2 == [4] + + unsub() + + dispatcher_send(self.hass, 'test1', 5) + dispatcher_send(self.hass, 'test2', 6) + self.hass.block_till_done() + + assert calls1 == [3, 5] + assert calls2 == [4] + + # check don't kill the flow + unsub() + + dispatcher_send(self.hass, 'test1', 7) + dispatcher_send(self.hass, 'test2', 8) + self.hass.block_till_done() + + assert calls1 == [3, 5, 7] + assert calls2 == [4] + def test_simple_callback(self): """Test simple callback (async).""" calls = [] @@ -50,8 +89,6 @@ class TestHelpersDispatcher(object): calls.append(data) dispatcher_connect(self.hass, 'test', test_funct) - self.hass.block_till_done() - dispatcher_send(self.hass, 'test', 3) self.hass.block_till_done() @@ -72,8 +109,6 @@ class TestHelpersDispatcher(object): calls.append(data) dispatcher_connect(self.hass, 'test', test_funct) - self.hass.block_till_done() - dispatcher_send(self.hass, 'test', 3) self.hass.block_till_done() @@ -95,8 +130,6 @@ class TestHelpersDispatcher(object): calls.append(data3) dispatcher_connect(self.hass, 'test', test_funct) - self.hass.block_till_done() - dispatcher_send(self.hass, 'test', 3, 2, 'bla') self.hass.block_till_done()