mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 12:47:08 +00:00
Move mqtt from eventbus to dispatcher / add unsub for dispatcher (#6206)
* Move mqtt from eventbus to dispatcher / add unsub for dispatcher * Fix lint * Fix test * Fix lint v2 * fix dispatcher_send
This commit is contained in:
parent
d6818c7015
commit
81ca978413
@ -17,6 +17,8 @@ from homeassistant.bootstrap import async_prepare_setup_platform
|
||||
from homeassistant.config import load_yaml_config_file
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import template, config_validation as cv
|
||||
from homeassistant.helpers.dispatcher import (
|
||||
async_dispatcher_connect, dispatcher_send)
|
||||
from homeassistant.util.async import (
|
||||
run_coroutine_threadsafe, run_callback_threadsafe)
|
||||
from homeassistant.const import (
|
||||
@ -31,7 +33,7 @@ DOMAIN = 'mqtt'
|
||||
DATA_MQTT = 'mqtt'
|
||||
|
||||
SERVICE_PUBLISH = 'publish'
|
||||
EVENT_MQTT_MESSAGE_RECEIVED = 'mqtt_message_received'
|
||||
SIGNAL_MQTT_MESSAGE_RECEIVED = 'mqtt_message_received'
|
||||
|
||||
REQUIREMENTS = ['paho-mqtt==1.2']
|
||||
|
||||
@ -195,16 +197,15 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None):
|
||||
def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS):
|
||||
"""Subscribe to an MQTT topic."""
|
||||
@callback
|
||||
def async_mqtt_topic_subscriber(event):
|
||||
def async_mqtt_topic_subscriber(dp_topic, dp_payload, dp_qos):
|
||||
"""Match subscribed MQTT topic."""
|
||||
if not _match_topic(topic, event.data[ATTR_TOPIC]):
|
||||
if not _match_topic(topic, dp_topic):
|
||||
return
|
||||
|
||||
hass.async_run_job(msg_callback, event.data[ATTR_TOPIC],
|
||||
event.data[ATTR_PAYLOAD], event.data[ATTR_QOS])
|
||||
hass.async_run_job(msg_callback, dp_topic, dp_payload, dp_qos)
|
||||
|
||||
async_remove = hass.bus.async_listen(
|
||||
EVENT_MQTT_MESSAGE_RECEIVED, async_mqtt_topic_subscriber)
|
||||
async_remove = async_dispatcher_connect(
|
||||
hass, SIGNAL_MQTT_MESSAGE_RECEIVED, async_mqtt_topic_subscriber)
|
||||
|
||||
yield from hass.data[DATA_MQTT].async_subscribe(topic, qos)
|
||||
return async_remove
|
||||
@ -551,13 +552,11 @@ class MQTT(object):
|
||||
"MQTT topic: %s, Payload: %s", msg.topic,
|
||||
msg.payload)
|
||||
else:
|
||||
_LOGGER.debug("Received message on %s: %s",
|
||||
msg.topic, payload)
|
||||
self.hass.bus.fire(EVENT_MQTT_MESSAGE_RECEIVED, {
|
||||
ATTR_TOPIC: msg.topic,
|
||||
ATTR_QOS: msg.qos,
|
||||
ATTR_PAYLOAD: payload,
|
||||
})
|
||||
_LOGGER.info("Received message on %s: %s", msg.topic, payload)
|
||||
dispatcher_send(
|
||||
self.hass, SIGNAL_MQTT_MESSAGE_RECEIVED, msg.topic, payload,
|
||||
msg.qos
|
||||
)
|
||||
|
||||
def _mqtt_on_unsubscribe(self, _mqttc, _userdata, mid, granted_qos):
|
||||
"""Unsubscribe successful callback."""
|
||||
|
@ -19,7 +19,6 @@ from homeassistant.const import (
|
||||
from homeassistant.core import EventOrigin, State
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.remote import JSONEncoder
|
||||
from .mqtt import EVENT_MQTT_MESSAGE_RECEIVED
|
||||
|
||||
DOMAIN = "mqtt_eventstream"
|
||||
DEPENDENCIES = ['mqtt']
|
||||
@ -54,15 +53,6 @@ def async_setup(hass, config):
|
||||
if event.event_type == EVENT_TIME_CHANGED:
|
||||
return
|
||||
|
||||
# MQTT fires a bus event for every incoming message, also messages from
|
||||
# eventstream. Disable publishing these messages to other HA instances
|
||||
# and possibly creating an infinite loop if these instances publish
|
||||
# back to this one.
|
||||
if all([not conf.get(CONF_PUBLISH_EVENTSTREAM_RECEIVED),
|
||||
event.event_type == EVENT_MQTT_MESSAGE_RECEIVED,
|
||||
event.data.get('topic') == sub_topic]):
|
||||
return
|
||||
|
||||
# Filter out the events that were triggered by publishing
|
||||
# to the MQTT topic, or you will end up in an infinite loop.
|
||||
if event.event_type == EVENT_CALL_SERVICE:
|
||||
|
@ -1,13 +1,24 @@
|
||||
"""Helpers for hass dispatcher & internal component / platform."""
|
||||
import logging
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.util.async import run_callback_threadsafe
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
DATA_DISPATCHER = 'dispatcher'
|
||||
|
||||
|
||||
def dispatcher_connect(hass, signal, target):
|
||||
"""Connect a callable function to a singal."""
|
||||
hass.add_job(async_dispatcher_connect, hass, signal, target)
|
||||
async_unsub = run_callback_threadsafe(
|
||||
hass.loop, async_dispatcher_connect, hass, signal, target).result()
|
||||
|
||||
def remove_dispatcher():
|
||||
"""Remove signal listener."""
|
||||
run_callback_threadsafe(hass.loop, async_unsub).result()
|
||||
|
||||
return remove_dispatcher
|
||||
|
||||
|
||||
@callback
|
||||
@ -24,6 +35,19 @@ def async_dispatcher_connect(hass, signal, target):
|
||||
|
||||
hass.data[DATA_DISPATCHER][signal].append(target)
|
||||
|
||||
@callback
|
||||
def async_remove_dispatcher():
|
||||
"""Remove signal listener."""
|
||||
try:
|
||||
hass.data[DATA_DISPATCHER][signal].remove(target)
|
||||
except (KeyError, ValueError):
|
||||
# KeyError is key target listener did not exist
|
||||
# ValueError if listener did not exist within signal
|
||||
_LOGGER.warning(
|
||||
"Unable to remove unknown dispatcher %s", target)
|
||||
|
||||
return async_remove_dispatcher
|
||||
|
||||
|
||||
def dispatcher_send(hass, signal, *args):
|
||||
"""Send signal and data."""
|
||||
|
@ -14,6 +14,7 @@ from aiohttp import web
|
||||
from homeassistant import core as ha, loader
|
||||
from homeassistant.bootstrap import (
|
||||
setup_component, async_prepare_setup_component)
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.helpers.entity import ToggleEntity
|
||||
from homeassistant.helpers.restore_state import DATA_RESTORE_CACHE
|
||||
from homeassistant.util.unit_system import METRIC_SYSTEM
|
||||
@ -158,11 +159,8 @@ def mock_service(hass, domain, service):
|
||||
@ha.callback
|
||||
def async_fire_mqtt_message(hass, topic, payload, qos=0):
|
||||
"""Fire the MQTT message."""
|
||||
hass.bus.async_fire(mqtt.EVENT_MQTT_MESSAGE_RECEIVED, {
|
||||
mqtt.ATTR_TOPIC: topic,
|
||||
mqtt.ATTR_PAYLOAD: payload,
|
||||
mqtt.ATTR_QOS: qos,
|
||||
})
|
||||
async_dispatcher_send(
|
||||
hass, mqtt.SIGNAL_MQTT_MESSAGE_RECEIVED, topic, payload, qos)
|
||||
|
||||
|
||||
def fire_mqtt_message(hass, topic, payload, qos=0):
|
||||
|
@ -13,6 +13,7 @@ import homeassistant.components.mqtt as mqtt
|
||||
from homeassistant.const import (
|
||||
EVENT_CALL_SERVICE, ATTR_DOMAIN, ATTR_SERVICE, EVENT_HOMEASSISTANT_START,
|
||||
EVENT_HOMEASSISTANT_STOP)
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
|
||||
from tests.common import (
|
||||
get_test_home_assistant, mock_mqtt_component, fire_mqtt_message, mock_coro)
|
||||
@ -237,11 +238,17 @@ class TestMQTTCallbacks(unittest.TestCase):
|
||||
calls = []
|
||||
|
||||
@callback
|
||||
def record(event):
|
||||
def record(topic, payload, qos):
|
||||
"""Helper to record calls."""
|
||||
calls.append(event)
|
||||
data = {
|
||||
'topic': topic,
|
||||
'payload': payload,
|
||||
'qos': qos,
|
||||
}
|
||||
calls.append(data)
|
||||
|
||||
self.hass.bus.listen_once(mqtt.EVENT_MQTT_MESSAGE_RECEIVED, record)
|
||||
async_dispatcher_connect(
|
||||
self.hass, mqtt.SIGNAL_MQTT_MESSAGE_RECEIVED, record)
|
||||
|
||||
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
|
||||
message = MQTTMessage('test_topic', 1, 'Hello World!'.encode('utf-8'))
|
||||
@ -252,9 +259,9 @@ class TestMQTTCallbacks(unittest.TestCase):
|
||||
|
||||
self.assertEqual(1, len(calls))
|
||||
last_event = calls[0]
|
||||
self.assertEqual('Hello World!', last_event.data['payload'])
|
||||
self.assertEqual(message.topic, last_event.data['topic'])
|
||||
self.assertEqual(message.qos, last_event.data['qos'])
|
||||
self.assertEqual('Hello World!', last_event['payload'])
|
||||
self.assertEqual(message.topic, last_event['topic'])
|
||||
self.assertEqual(message.qos, last_event['qos'])
|
||||
|
||||
def test_mqtt_failed_connection_results_in_disconnect(self):
|
||||
"""Test if connection failure leads to disconnect."""
|
||||
@ -300,13 +307,20 @@ class TestMQTTCallbacks(unittest.TestCase):
|
||||
calls = []
|
||||
|
||||
@callback
|
||||
def record(event):
|
||||
def record(topic, payload, qos):
|
||||
"""Helper to record calls."""
|
||||
calls.append(event)
|
||||
data = {
|
||||
'topic': topic,
|
||||
'payload': payload,
|
||||
'qos': qos,
|
||||
}
|
||||
calls.append(data)
|
||||
|
||||
async_dispatcher_connect(
|
||||
self.hass, mqtt.SIGNAL_MQTT_MESSAGE_RECEIVED, record)
|
||||
|
||||
payload = 0x9a
|
||||
topic = 'test_topic'
|
||||
self.hass.bus.listen_once(mqtt.EVENT_MQTT_MESSAGE_RECEIVED, record)
|
||||
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
|
||||
message = MQTTMessage(topic, 1, payload)
|
||||
with self.assertLogs(level='ERROR') as test_handle:
|
||||
|
@ -1,11 +1,9 @@
|
||||
"""The tests for the MQTT eventstream component."""
|
||||
from collections import namedtuple
|
||||
import json
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
from homeassistant.bootstrap import setup_component
|
||||
import homeassistant.components.mqtt_eventstream as eventstream
|
||||
import homeassistant.components.mqtt as mqtt
|
||||
from homeassistant.const import EVENT_STATE_CHANGED
|
||||
from homeassistant.core import State, callback
|
||||
from homeassistant.remote import JSONEncoder
|
||||
@ -146,45 +144,3 @@ class TestMqttEventStream(object):
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert 1 == len(calls)
|
||||
|
||||
@patch('homeassistant.components.mqtt.async_publish')
|
||||
def test_mqtt_received_event(self, mock_pub):
|
||||
"""Don't filter events from the mqtt component about received message.
|
||||
|
||||
Mqtt component sends an event if a message is received. Also
|
||||
messages that originate from an incoming eventstream.
|
||||
Broadcasting these messages result in an infinite loop if two HA
|
||||
instances are crossconfigured for the same mqtt topics.
|
||||
|
||||
"""
|
||||
SUB_TOPIC = 'from_slaves'
|
||||
assert self.add_eventstream(
|
||||
pub_topic='bar',
|
||||
sub_topic=SUB_TOPIC)
|
||||
self.hass.block_till_done()
|
||||
|
||||
# Reset the mock because it will have already gotten calls for the
|
||||
# mqtt_eventstream state change on initialization, etc.
|
||||
mock_pub.reset_mock()
|
||||
|
||||
# Use MQTT component message handler to simulate firing message
|
||||
# received event.
|
||||
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
|
||||
message = MQTTMessage(
|
||||
SUB_TOPIC, 1, '{"test": "Hello World!"}'.encode('utf-8'))
|
||||
mqtt.MQTT._mqtt_on_message(self, None, {'hass': self.hass}, message)
|
||||
|
||||
self.hass.block_till_done()
|
||||
|
||||
# 'normal' incoming mqtt messages should be broadcasted
|
||||
assert mock_pub.call_count == 0
|
||||
|
||||
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
|
||||
message = MQTTMessage(
|
||||
'test_topic', 1, '{"test": "Hello World!"}'.encode('utf-8'))
|
||||
mqtt.MQTT._mqtt_on_message(self, None, {'hass': self.hass}, message)
|
||||
|
||||
self.hass.block_till_done()
|
||||
|
||||
# but event from the event stream not
|
||||
assert mock_pub.call_count == 1
|
||||
|
@ -28,8 +28,6 @@ class TestHelpersDispatcher(object):
|
||||
calls.append(data)
|
||||
|
||||
dispatcher_connect(self.hass, 'test', test_funct)
|
||||
self.hass.block_till_done()
|
||||
|
||||
dispatcher_send(self.hass, 'test', 3)
|
||||
self.hass.block_till_done()
|
||||
|
||||
@ -40,6 +38,47 @@ class TestHelpersDispatcher(object):
|
||||
|
||||
assert calls == [3, 'bla']
|
||||
|
||||
def test_simple_function_unsub(self):
|
||||
"""Test simple function (executor) and unsub."""
|
||||
calls1 = []
|
||||
calls2 = []
|
||||
|
||||
def test_funct1(data):
|
||||
"""Test function."""
|
||||
calls1.append(data)
|
||||
|
||||
def test_funct2(data):
|
||||
"""Test function."""
|
||||
calls2.append(data)
|
||||
|
||||
dispatcher_connect(self.hass, 'test1', test_funct1)
|
||||
unsub = dispatcher_connect(self.hass, 'test2', test_funct2)
|
||||
dispatcher_send(self.hass, 'test1', 3)
|
||||
dispatcher_send(self.hass, 'test2', 4)
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert calls1 == [3]
|
||||
assert calls2 == [4]
|
||||
|
||||
unsub()
|
||||
|
||||
dispatcher_send(self.hass, 'test1', 5)
|
||||
dispatcher_send(self.hass, 'test2', 6)
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert calls1 == [3, 5]
|
||||
assert calls2 == [4]
|
||||
|
||||
# check don't kill the flow
|
||||
unsub()
|
||||
|
||||
dispatcher_send(self.hass, 'test1', 7)
|
||||
dispatcher_send(self.hass, 'test2', 8)
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert calls1 == [3, 5, 7]
|
||||
assert calls2 == [4]
|
||||
|
||||
def test_simple_callback(self):
|
||||
"""Test simple callback (async)."""
|
||||
calls = []
|
||||
@ -50,8 +89,6 @@ class TestHelpersDispatcher(object):
|
||||
calls.append(data)
|
||||
|
||||
dispatcher_connect(self.hass, 'test', test_funct)
|
||||
self.hass.block_till_done()
|
||||
|
||||
dispatcher_send(self.hass, 'test', 3)
|
||||
self.hass.block_till_done()
|
||||
|
||||
@ -72,8 +109,6 @@ class TestHelpersDispatcher(object):
|
||||
calls.append(data)
|
||||
|
||||
dispatcher_connect(self.hass, 'test', test_funct)
|
||||
self.hass.block_till_done()
|
||||
|
||||
dispatcher_send(self.hass, 'test', 3)
|
||||
self.hass.block_till_done()
|
||||
|
||||
@ -95,8 +130,6 @@ class TestHelpersDispatcher(object):
|
||||
calls.append(data3)
|
||||
|
||||
dispatcher_connect(self.hass, 'test', test_funct)
|
||||
self.hass.block_till_done()
|
||||
|
||||
dispatcher_send(self.hass, 'test', 3, 2, 'bla')
|
||||
self.hass.block_till_done()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user