diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 9ec5169c729..7f4ff030d36 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -6,10 +6,12 @@ MQTT component, using paho-mqtt. For more details about this component, please refer to the documentation at https://home-assistant.io/components/mqtt/ """ +import json import logging import os import socket -import json +import time + from homeassistant.exceptions import HomeAssistantError import homeassistant.util as util @@ -45,6 +47,8 @@ ATTR_TOPIC = 'topic' ATTR_PAYLOAD = 'payload' ATTR_QOS = 'qos' +MAX_RECONNECT_WAIT = 300 # seconds + def publish(hass, topic, payload, qos=None): """ Send an MQTT message. """ @@ -66,9 +70,7 @@ def subscribe(hass, topic, callback, qos=DEFAULT_QOS): event.data[ATTR_QOS]) hass.bus.listen(EVENT_MQTT_MESSAGE_RECEIVED, mqtt_topic_subscriber) - - if topic not in MQTT_CLIENT.topics: - MQTT_CLIENT.subscribe(topic, qos) + MQTT_CLIENT.subscribe(topic, qos) def setup(hass, config): @@ -156,42 +158,42 @@ class FmtParser(object): # This is based on one of the paho-mqtt examples: # http://git.eclipse.org/c/paho/org.eclipse.paho.mqtt.python.git/tree/examples/sub-class.py # pylint: disable=too-many-arguments -class MQTT(object): # pragma: no cover +class MQTT(object): """ Implements messaging service for MQTT. """ def __init__(self, hass, broker, port, client_id, keepalive, username, password, certificate): import paho.mqtt.client as mqtt - self.hass = hass - self._progress = {} - self.topics = {} + self.userdata = { + 'hass': hass, + 'topics': {}, + 'progress': {}, + } if client_id is None: self._mqttc = mqtt.Client() else: self._mqttc = mqtt.Client(client_id) + self._mqttc.user_data_set(self.userdata) + if username is not None: self._mqttc.username_pw_set(username, password) if certificate is not None: self._mqttc.tls_set(certificate) - self._mqttc.on_subscribe = self._mqtt_on_subscribe - self._mqttc.on_unsubscribe = self._mqtt_on_unsubscribe - self._mqttc.on_connect = self._mqtt_on_connect - self._mqttc.on_message = self._mqtt_on_message + self._mqttc.on_subscribe = _mqtt_on_subscribe + self._mqttc.on_unsubscribe = _mqtt_on_unsubscribe + self._mqttc.on_connect = _mqtt_on_connect + self._mqttc.on_disconnect = _mqtt_on_disconnect + self._mqttc.on_message = _mqtt_on_message + self._mqttc.connect(broker, port, keepalive) def publish(self, topic, payload, qos): """ Publish a MQTT message. """ self._mqttc.publish(topic, payload, qos) - def unsubscribe(self, topic): - """ Unsubscribe from topic. """ - result, mid = self._mqttc.unsubscribe(topic) - _raise_on_error(result) - self._progress[mid] = topic - def start(self): """ Run the MQTT client. """ self._mqttc.loop_start() @@ -202,58 +204,96 @@ class MQTT(object): # pragma: no cover def subscribe(self, topic, qos): """ Subscribe to a topic. """ - if topic in self.topics: + if topic in self.userdata['topics']: return result, mid = self._mqttc.subscribe(topic, qos) _raise_on_error(result) - self._progress[mid] = topic - self.topics[topic] = None + self.userdata['progress'][mid] = topic + self.userdata['topics'][topic] = None - def _mqtt_on_connect(self, mqttc, obj, flags, result_code): - """ On connect, resubscribe to all topics we were subscribed to. """ - if result_code != 0: - _LOGGER.error('Unable to connect to the MQTT broker: %s', { - 1: 'Incorrect protocol version', - 2: 'Invalid client identifier', - 3: 'Server unavailable', - 4: 'Bad username or password', - 5: 'Not authorised' - }.get(result_code)) - self._mqttc.disconnect() - return - - old_topics = self.topics - self._progress = {} - self.topics = {} - for topic, qos in old_topics.items(): - # qos is None if we were in process of subscribing - if qos is not None: - self._mqttc.subscribe(topic, qos) - - def _mqtt_on_subscribe(self, mqttc, obj, mid, granted_qos): - """ Called when subscribe succesfull. """ - topic = self._progress.pop(mid, None) - if topic is None: - return - self.topics[topic] = granted_qos - - def _mqtt_on_unsubscribe(self, mqttc, obj, mid, granted_qos): - """ Called when subscribe succesfull. """ - topic = self._progress.pop(mid, None) - if topic is None: - return - self.topics.pop(topic, None) - - def _mqtt_on_message(self, mqttc, obj, msg): - """ Message callback """ - self.hass.bus.fire(EVENT_MQTT_MESSAGE_RECEIVED, { - ATTR_TOPIC: msg.topic, - ATTR_QOS: msg.qos, - ATTR_PAYLOAD: msg.payload.decode('utf-8'), - }) + def unsubscribe(self, topic): + """ Unsubscribe from topic. """ + result, mid = self._mqttc.unsubscribe(topic) + _raise_on_error(result) + self.userdata['progress'][mid] = topic -def _raise_on_error(result): # pragma: no cover +def _mqtt_on_message(mqttc, userdata, msg): + """ Message callback """ + userdata['hass'].bus.fire(EVENT_MQTT_MESSAGE_RECEIVED, { + ATTR_TOPIC: msg.topic, + ATTR_QOS: msg.qos, + ATTR_PAYLOAD: msg.payload.decode('utf-8'), + }) + + +def _mqtt_on_connect(mqttc, userdata, flags, result_code): + """ On connect, resubscribe to all topics we were subscribed to. """ + if result_code != 0: + _LOGGER.error('Unable to connect to the MQTT broker: %s', { + 1: 'Incorrect protocol version', + 2: 'Invalid client identifier', + 3: 'Server unavailable', + 4: 'Bad username or password', + 5: 'Not authorised' + }.get(result_code, 'Unknown reason')) + mqttc.disconnect() + return + + old_topics = userdata['topics'] + + userdata['topics'] = {} + userdata['progress'] = {} + + for topic, qos in old_topics.items(): + # qos is None if we were in process of subscribing + if qos is not None: + mqttc.subscribe(topic, qos) + + +def _mqtt_on_subscribe(mqttc, userdata, mid, granted_qos): + """ Called when subscribe successfull. """ + topic = userdata['progress'].pop(mid, None) + if topic is None: + return + userdata['topics'][topic] = granted_qos + + +def _mqtt_on_unsubscribe(mqttc, userdata, mid, granted_qos): + """ Called when subscribe successfull. """ + topic = userdata['progress'].pop(mid, None) + if topic is None: + return + userdata['topics'].pop(topic, None) + + +def _mqtt_on_disconnect(mqttc, userdata, result_code): + """ Called when being disconnected. """ + # When disconnected because of calling disconnect() + if result_code == 0: + return + + tries = 0 + wait_time = 0 + + while True: + try: + if mqttc.reconnect() == 0: + _LOGGER.info('Successfully reconnected to the MQTT server') + break + except socket.error: + pass + + wait_time = min(2**tries, MAX_RECONNECT_WAIT) + _LOGGER.warning( + 'Disconnected from MQTT (%s). Trying to reconnect in %ss', + result_code, wait_time) + # It is ok to sleep here as we are in the MQTT thread. + time.sleep(wait_time) + tries += 1 + + +def _raise_on_error(result): """ Raise error if error result. """ if result != 0: raise HomeAssistantError('Error talking to MQTT: {}'.format(result)) diff --git a/tests/components/test_mqtt.py b/tests/components/test_mqtt.py index 4c3dbb1d20a..47a5ac7b4e1 100644 --- a/tests/components/test_mqtt.py +++ b/tests/components/test_mqtt.py @@ -4,6 +4,7 @@ tests.test_component_mqtt Tests MQTT component. """ +from collections import namedtuple import unittest from unittest import mock import socket @@ -17,8 +18,8 @@ from tests.common import ( get_test_home_assistant, mock_mqtt_component, fire_mqtt_message) -class TestDemo(unittest.TestCase): - """ Test the demo module. """ +class TestMQTT(unittest.TestCase): + """ Test the MQTT module. """ def setUp(self): # pylint: disable=invalid-name self.hass = get_test_home_assistant(1) @@ -136,3 +137,72 @@ class TestDemo(unittest.TestCase): self.hass.pool.block_till_done() self.assertEqual(0, len(self.calls)) + + +class TestMQTTCallbacks(unittest.TestCase): + """ Test the MQTT callbacks. """ + + def setUp(self): # pylint: disable=invalid-name + self.hass = get_test_home_assistant(1) + mock_mqtt_component(self.hass) + self.calls = [] + + def tearDown(self): # pylint: disable=invalid-name + """ Stop down stuff we started. """ + self.hass.stop() + + def test_receiving_mqtt_message_fires_hass_event(self): + calls = [] + + def record(event): + calls.append(event) + + self.hass.bus.listen_once(mqtt.EVENT_MQTT_MESSAGE_RECEIVED, record) + + MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload']) + message = MQTTMessage('test_topic', 1, 'Hello World!'.encode('utf-8')) + + mqtt._mqtt_on_message(None, {'hass': self.hass}, message) + self.hass.pool.block_till_done() + + self.assertEqual(1, len(calls)) + last_event = calls[0] + self.assertEqual('Hello World!', last_event.data['payload']) + self.assertEqual(message.topic, last_event.data['topic']) + self.assertEqual(message.qos, last_event.data['qos']) + + def test_mqtt_failed_connection_results_in_disconnect(self): + for result_code in range(1, 6): + mqttc = mock.MagicMock() + mqtt._mqtt_on_connect(mqttc, {'topics': {}}, 0, result_code) + self.assertTrue(mqttc.disconnect.called) + + def test_mqtt_subscribes_topics_on_connect(self): + prev_topics = { + 'topic/test': 1, + 'home/sensor': 2, + 'still/pending': None + } + mqttc = mock.MagicMock() + mqtt._mqtt_on_connect(mqttc, {'topics': prev_topics}, 0, 0) + self.assertFalse(mqttc.disconnect.called) + + expected = [(topic, qos) for topic, qos in prev_topics.items() + if qos is not None] + self.assertEqual(expected, [call[1] for call + in mqttc.subscribe.mock_calls]) + + def test_mqtt_disconnect_tries_no_reconnect_on_stop(self): + mqttc = mock.MagicMock() + mqtt._mqtt_on_disconnect(mqttc, {}, 0) + self.assertFalse(mqttc.reconnect.called) + + @mock.patch('homeassistant.components.mqtt.time.sleep') + def test_mqtt_disconnect_tries_reconnect(self, mock_sleep): + mqttc = mock.MagicMock() + mqttc.reconnect.side_effect = [1, 1, 1, 0] + mqtt._mqtt_on_disconnect(mqttc, {}, 1) + self.assertTrue(mqttc.reconnect.called) + self.assertEqual(4, len(mqttc.reconnect.mock_calls)) + self.assertEqual([1, 2, 4], + [call[1][0] for call in mock_sleep.mock_calls])