Merge pull request #653 from balloob/mqtt-disconnect

Reconnect when disconnected from MQTT
This commit is contained in:
Paulus Schoutsen 2015-11-22 16:08:30 -08:00
commit 7acb3dffe4
2 changed files with 176 additions and 66 deletions

View File

@ -6,10 +6,12 @@ MQTT component, using paho-mqtt.
For more details about this component, please refer to the documentation at For more details about this component, please refer to the documentation at
https://home-assistant.io/components/mqtt/ https://home-assistant.io/components/mqtt/
""" """
import json
import logging import logging
import os import os
import socket import socket
import json import time
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
import homeassistant.util as util import homeassistant.util as util
@ -45,6 +47,8 @@ ATTR_TOPIC = 'topic'
ATTR_PAYLOAD = 'payload' ATTR_PAYLOAD = 'payload'
ATTR_QOS = 'qos' ATTR_QOS = 'qos'
MAX_RECONNECT_WAIT = 300 # seconds
def publish(hass, topic, payload, qos=None): def publish(hass, topic, payload, qos=None):
""" Send an MQTT message. """ """ Send an MQTT message. """
@ -66,8 +70,6 @@ def subscribe(hass, topic, callback, qos=DEFAULT_QOS):
event.data[ATTR_QOS]) event.data[ATTR_QOS])
hass.bus.listen(EVENT_MQTT_MESSAGE_RECEIVED, mqtt_topic_subscriber) hass.bus.listen(EVENT_MQTT_MESSAGE_RECEIVED, mqtt_topic_subscriber)
if topic not in MQTT_CLIENT.topics:
MQTT_CLIENT.subscribe(topic, qos) MQTT_CLIENT.subscribe(topic, qos)
@ -156,42 +158,42 @@ class FmtParser(object):
# This is based on one of the paho-mqtt examples: # This is based on one of the paho-mqtt examples:
# http://git.eclipse.org/c/paho/org.eclipse.paho.mqtt.python.git/tree/examples/sub-class.py # http://git.eclipse.org/c/paho/org.eclipse.paho.mqtt.python.git/tree/examples/sub-class.py
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
class MQTT(object): # pragma: no cover class MQTT(object):
""" Implements messaging service for MQTT. """ """ Implements messaging service for MQTT. """
def __init__(self, hass, broker, port, client_id, keepalive, username, def __init__(self, hass, broker, port, client_id, keepalive, username,
password, certificate): password, certificate):
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
self.hass = hass self.userdata = {
self._progress = {} 'hass': hass,
self.topics = {} 'topics': {},
'progress': {},
}
if client_id is None: if client_id is None:
self._mqttc = mqtt.Client() self._mqttc = mqtt.Client()
else: else:
self._mqttc = mqtt.Client(client_id) self._mqttc = mqtt.Client(client_id)
self._mqttc.user_data_set(self.userdata)
if username is not None: if username is not None:
self._mqttc.username_pw_set(username, password) self._mqttc.username_pw_set(username, password)
if certificate is not None: if certificate is not None:
self._mqttc.tls_set(certificate) self._mqttc.tls_set(certificate)
self._mqttc.on_subscribe = self._mqtt_on_subscribe self._mqttc.on_subscribe = _mqtt_on_subscribe
self._mqttc.on_unsubscribe = self._mqtt_on_unsubscribe self._mqttc.on_unsubscribe = _mqtt_on_unsubscribe
self._mqttc.on_connect = self._mqtt_on_connect self._mqttc.on_connect = _mqtt_on_connect
self._mqttc.on_message = self._mqtt_on_message self._mqttc.on_disconnect = _mqtt_on_disconnect
self._mqttc.on_message = _mqtt_on_message
self._mqttc.connect(broker, port, keepalive) self._mqttc.connect(broker, port, keepalive)
def publish(self, topic, payload, qos): def publish(self, topic, payload, qos):
""" Publish a MQTT message. """ """ Publish a MQTT message. """
self._mqttc.publish(topic, payload, qos) self._mqttc.publish(topic, payload, qos)
def unsubscribe(self, topic):
""" Unsubscribe from topic. """
result, mid = self._mqttc.unsubscribe(topic)
_raise_on_error(result)
self._progress[mid] = topic
def start(self): def start(self):
""" Run the MQTT client. """ """ Run the MQTT client. """
self._mqttc.loop_start() self._mqttc.loop_start()
@ -202,14 +204,30 @@ class MQTT(object): # pragma: no cover
def subscribe(self, topic, qos): def subscribe(self, topic, qos):
""" Subscribe to a topic. """ """ Subscribe to a topic. """
if topic in self.topics: if topic in self.userdata['topics']:
return return
result, mid = self._mqttc.subscribe(topic, qos) result, mid = self._mqttc.subscribe(topic, qos)
_raise_on_error(result) _raise_on_error(result)
self._progress[mid] = topic self.userdata['progress'][mid] = topic
self.topics[topic] = None self.userdata['topics'][topic] = None
def _mqtt_on_connect(self, mqttc, obj, flags, result_code): def unsubscribe(self, topic):
""" Unsubscribe from topic. """
result, mid = self._mqttc.unsubscribe(topic)
_raise_on_error(result)
self.userdata['progress'][mid] = topic
def _mqtt_on_message(mqttc, userdata, msg):
""" Message callback """
userdata['hass'].bus.fire(EVENT_MQTT_MESSAGE_RECEIVED, {
ATTR_TOPIC: msg.topic,
ATTR_QOS: msg.qos,
ATTR_PAYLOAD: msg.payload.decode('utf-8'),
})
def _mqtt_on_connect(mqttc, userdata, flags, result_code):
""" On connect, resubscribe to all topics we were subscribed to. """ """ On connect, resubscribe to all topics we were subscribed to. """
if result_code != 0: if result_code != 0:
_LOGGER.error('Unable to connect to the MQTT broker: %s', { _LOGGER.error('Unable to connect to the MQTT broker: %s', {
@ -218,42 +236,64 @@ class MQTT(object): # pragma: no cover
3: 'Server unavailable', 3: 'Server unavailable',
4: 'Bad username or password', 4: 'Bad username or password',
5: 'Not authorised' 5: 'Not authorised'
}.get(result_code)) }.get(result_code, 'Unknown reason'))
self._mqttc.disconnect() mqttc.disconnect()
return return
old_topics = self.topics old_topics = userdata['topics']
self._progress = {}
self.topics = {} userdata['topics'] = {}
userdata['progress'] = {}
for topic, qos in old_topics.items(): for topic, qos in old_topics.items():
# qos is None if we were in process of subscribing # qos is None if we were in process of subscribing
if qos is not None: if qos is not None:
self._mqttc.subscribe(topic, qos) mqttc.subscribe(topic, qos)
def _mqtt_on_subscribe(self, mqttc, obj, mid, granted_qos):
""" Called when subscribe succesfull. """ def _mqtt_on_subscribe(mqttc, userdata, mid, granted_qos):
topic = self._progress.pop(mid, None) """ Called when subscribe successfull. """
topic = userdata['progress'].pop(mid, None)
if topic is None: if topic is None:
return return
self.topics[topic] = granted_qos userdata['topics'][topic] = granted_qos
def _mqtt_on_unsubscribe(self, mqttc, obj, mid, granted_qos):
""" Called when subscribe succesfull. """ def _mqtt_on_unsubscribe(mqttc, userdata, mid, granted_qos):
topic = self._progress.pop(mid, None) """ Called when subscribe successfull. """
topic = userdata['progress'].pop(mid, None)
if topic is None: if topic is None:
return return
self.topics.pop(topic, None) userdata['topics'].pop(topic, None)
def _mqtt_on_message(self, mqttc, obj, msg):
""" Message callback """
self.hass.bus.fire(EVENT_MQTT_MESSAGE_RECEIVED, {
ATTR_TOPIC: msg.topic,
ATTR_QOS: msg.qos,
ATTR_PAYLOAD: msg.payload.decode('utf-8'),
})
def _raise_on_error(result): # pragma: no cover def _mqtt_on_disconnect(mqttc, userdata, result_code):
""" Called when being disconnected. """
# When disconnected because of calling disconnect()
if result_code == 0:
return
tries = 0
wait_time = 0
while True:
try:
if mqttc.reconnect() == 0:
_LOGGER.info('Successfully reconnected to the MQTT server')
break
except socket.error:
pass
wait_time = min(2**tries, MAX_RECONNECT_WAIT)
_LOGGER.warning(
'Disconnected from MQTT (%s). Trying to reconnect in %ss',
result_code, wait_time)
# It is ok to sleep here as we are in the MQTT thread.
time.sleep(wait_time)
tries += 1
def _raise_on_error(result):
""" Raise error if error result. """ """ Raise error if error result. """
if result != 0: if result != 0:
raise HomeAssistantError('Error talking to MQTT: {}'.format(result)) raise HomeAssistantError('Error talking to MQTT: {}'.format(result))

View File

@ -4,6 +4,7 @@ tests.test_component_mqtt
Tests MQTT component. Tests MQTT component.
""" """
from collections import namedtuple
import unittest import unittest
from unittest import mock from unittest import mock
import socket import socket
@ -17,8 +18,8 @@ from tests.common import (
get_test_home_assistant, mock_mqtt_component, fire_mqtt_message) get_test_home_assistant, mock_mqtt_component, fire_mqtt_message)
class TestDemo(unittest.TestCase): class TestMQTT(unittest.TestCase):
""" Test the demo module. """ """ Test the MQTT module. """
def setUp(self): # pylint: disable=invalid-name def setUp(self): # pylint: disable=invalid-name
self.hass = get_test_home_assistant(1) self.hass = get_test_home_assistant(1)
@ -136,3 +137,72 @@ class TestDemo(unittest.TestCase):
self.hass.pool.block_till_done() self.hass.pool.block_till_done()
self.assertEqual(0, len(self.calls)) self.assertEqual(0, len(self.calls))
class TestMQTTCallbacks(unittest.TestCase):
""" Test the MQTT callbacks. """
def setUp(self): # pylint: disable=invalid-name
self.hass = get_test_home_assistant(1)
mock_mqtt_component(self.hass)
self.calls = []
def tearDown(self): # pylint: disable=invalid-name
""" Stop down stuff we started. """
self.hass.stop()
def test_receiving_mqtt_message_fires_hass_event(self):
calls = []
def record(event):
calls.append(event)
self.hass.bus.listen_once(mqtt.EVENT_MQTT_MESSAGE_RECEIVED, record)
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
message = MQTTMessage('test_topic', 1, 'Hello World!'.encode('utf-8'))
mqtt._mqtt_on_message(None, {'hass': self.hass}, message)
self.hass.pool.block_till_done()
self.assertEqual(1, len(calls))
last_event = calls[0]
self.assertEqual('Hello World!', last_event.data['payload'])
self.assertEqual(message.topic, last_event.data['topic'])
self.assertEqual(message.qos, last_event.data['qos'])
def test_mqtt_failed_connection_results_in_disconnect(self):
for result_code in range(1, 6):
mqttc = mock.MagicMock()
mqtt._mqtt_on_connect(mqttc, {'topics': {}}, 0, result_code)
self.assertTrue(mqttc.disconnect.called)
def test_mqtt_subscribes_topics_on_connect(self):
prev_topics = {
'topic/test': 1,
'home/sensor': 2,
'still/pending': None
}
mqttc = mock.MagicMock()
mqtt._mqtt_on_connect(mqttc, {'topics': prev_topics}, 0, 0)
self.assertFalse(mqttc.disconnect.called)
expected = [(topic, qos) for topic, qos in prev_topics.items()
if qos is not None]
self.assertEqual(expected, [call[1] for call
in mqttc.subscribe.mock_calls])
def test_mqtt_disconnect_tries_no_reconnect_on_stop(self):
mqttc = mock.MagicMock()
mqtt._mqtt_on_disconnect(mqttc, {}, 0)
self.assertFalse(mqttc.reconnect.called)
@mock.patch('homeassistant.components.mqtt.time.sleep')
def test_mqtt_disconnect_tries_reconnect(self, mock_sleep):
mqttc = mock.MagicMock()
mqttc.reconnect.side_effect = [1, 1, 1, 0]
mqtt._mqtt_on_disconnect(mqttc, {}, 1)
self.assertTrue(mqttc.reconnect.called)
self.assertEqual(4, len(mqttc.reconnect.mock_calls))
self.assertEqual([1, 2, 4],
[call[1][0] for call in mock_sleep.mock_calls])