From 8fbb5858743dc4f356bdff3f0126899aad9c3aa8 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 17 Jan 2016 21:39:25 -0800 Subject: [PATCH] Fix MQTT reconnecting --- homeassistant/components/mqtt/__init__.py | 193 +++++++++++----------- tests/components/test_mqtt.py | 74 ++++++--- 2 files changed, 150 insertions(+), 117 deletions(-) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 86dce3d511b..c26f03a24f5 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -51,7 +51,7 @@ MAX_RECONNECT_WAIT = 300 # seconds def publish(hass, topic, payload, qos=None, retain=None): - """ Send an MQTT message. """ + """Publish message to an MQTT topic.""" data = { ATTR_TOPIC: topic, ATTR_PAYLOAD: payload, @@ -66,9 +66,9 @@ def publish(hass, topic, payload, qos=None, retain=None): def subscribe(hass, topic, callback, qos=DEFAULT_QOS): - """ Subscribe to a topic. """ + """Subscribe to an MQTT topic.""" def mqtt_topic_subscriber(event): - """ Match subscribed MQTT topic. """ + """Match subscribed MQTT topic.""" if _match_topic(topic, event.data[ATTR_TOPIC]): callback(event.data[ATTR_TOPIC], event.data[ATTR_PAYLOAD], event.data[ATTR_QOS]) @@ -78,8 +78,7 @@ def subscribe(hass, topic, callback, qos=DEFAULT_QOS): def setup(hass, config): - """ Get the MQTT protocol service. """ - + """Start the MQTT protocol service.""" if not validate_config(config, {DOMAIN: ['broker']}, _LOGGER): return False @@ -110,16 +109,16 @@ def setup(hass, config): return False def stop_mqtt(event): - """ Stop MQTT component. """ + """Stop MQTT component.""" MQTT_CLIENT.stop() def start_mqtt(event): - """ Launch MQTT component when Home Assistant starts up. """ + """Launch MQTT component when Home Assistant starts up.""" MQTT_CLIENT.start() hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, stop_mqtt) def publish_service(call): - """ Handle MQTT publish service calls. """ + """Handle MQTT publish service calls.""" msg_topic = call.data.get(ATTR_TOPIC) payload = call.data.get(ATTR_PAYLOAD) qos = call.data.get(ATTR_QOS, DEFAULT_QOS) @@ -137,148 +136,156 @@ def setup(hass, config): # pylint: disable=too-many-arguments class MQTT(object): - """ Implements messaging service for MQTT. """ + """Home Assistant MQTT client.""" + def __init__(self, hass, broker, port, client_id, keepalive, username, password, certificate): + """Initialize Home Assistant MQTT client.""" import paho.mqtt.client as mqtt - self.userdata = { - 'hass': hass, - 'topics': {}, - 'progress': {}, - } + self.hass = hass + self.topics = {} + self.progress = {} if client_id is None: self._mqttc = mqtt.Client(protocol=mqtt.MQTTv311) else: self._mqttc = mqtt.Client(client_id, protocol=mqtt.MQTTv311) - self._mqttc.user_data_set(self.userdata) - if username is not None: self._mqttc.username_pw_set(username, password) if certificate is not None: self._mqttc.tls_set(certificate) - self._mqttc.on_subscribe = _mqtt_on_subscribe - self._mqttc.on_unsubscribe = _mqtt_on_unsubscribe - self._mqttc.on_connect = _mqtt_on_connect - self._mqttc.on_disconnect = _mqtt_on_disconnect - self._mqttc.on_message = _mqtt_on_message + self._mqttc.on_subscribe = self._mqtt_on_subscribe + self._mqttc.on_unsubscribe = self._mqtt_on_unsubscribe + self._mqttc.on_connect = self._mqtt_on_connect + self._mqttc.on_disconnect = self._mqtt_on_disconnect + self._mqttc.on_message = self._mqtt_on_message self._mqttc.connect(broker, port, keepalive) def publish(self, topic, payload, qos, retain): - """ Publish a MQTT message. """ + """Publish a MQTT message.""" self._mqttc.publish(topic, payload, qos, retain) def start(self): - """ Run the MQTT client. """ + """Run the MQTT client.""" self._mqttc.loop_start() def stop(self): - """ Stop the MQTT client. """ + """Stop the MQTT client.""" + self._mqttc.disconnect() self._mqttc.loop_stop() def subscribe(self, topic, qos): - """ Subscribe to a topic. """ - if topic in self.userdata['topics']: + """Subscribe to a topic.""" + assert isinstance(topic, str) + + if topic in self.topics: return result, mid = self._mqttc.subscribe(topic, qos) _raise_on_error(result) - self.userdata['progress'][mid] = topic - self.userdata['topics'][topic] = None + self.progress[mid] = topic + self.topics[topic] = None def unsubscribe(self, topic): - """ Unsubscribe from topic. """ + """Unsubscribe from topic.""" result, mid = self._mqttc.unsubscribe(topic) _raise_on_error(result) - self.userdata['progress'][mid] = topic + self.progress[mid] = topic + def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code): + """On connect callback. -def _mqtt_on_message(mqttc, userdata, msg): - """ Message callback """ - userdata['hass'].bus.fire(EVENT_MQTT_MESSAGE_RECEIVED, { - ATTR_TOPIC: msg.topic, - ATTR_QOS: msg.qos, - ATTR_PAYLOAD: msg.payload.decode('utf-8'), - }) + Resubscribe to all topics we were subscribed to. + """ + if result_code != 0: + _LOGGER.error('Unable to connect to the MQTT broker: %s', { + 1: 'Incorrect protocol version', + 2: 'Invalid client identifier', + 3: 'Server unavailable', + 4: 'Bad username or password', + 5: 'Not authorised' + }.get(result_code, 'Unknown reason')) + self._mqttc.disconnect() + return + old_topics = self.topics -def _mqtt_on_connect(mqttc, userdata, flags, result_code): - """ On connect, resubscribe to all topics we were subscribed to. """ - if result_code != 0: - _LOGGER.error('Unable to connect to the MQTT broker: %s', { - 1: 'Incorrect protocol version', - 2: 'Invalid client identifier', - 3: 'Server unavailable', - 4: 'Bad username or password', - 5: 'Not authorised' - }.get(result_code, 'Unknown reason')) - mqttc.disconnect() - return + self.topics = {key: value for key, value in self.topics.items() + if value is None} - old_topics = userdata['topics'] + for topic, qos in old_topics.items(): + # qos is None if we were in process of subscribing + if qos is not None: + self.subscribe(topic, qos) - userdata['topics'] = {} - userdata['progress'] = {} + def _mqtt_on_subscribe(self, _mqttc, _userdata, mid, granted_qos): + """Subscribe successful callback.""" + topic = self.progress.pop(mid, None) + if topic is None: + return + self.topics[topic] = granted_qos[0] - for topic, qos in old_topics.items(): - # qos is None if we were in process of subscribing - if qos is not None: - mqttc.subscribe(topic, qos) + def _mqtt_on_message(self, _mqttc, _userdata, msg): + """Message received callback.""" + self.hass.bus.fire(EVENT_MQTT_MESSAGE_RECEIVED, { + ATTR_TOPIC: msg.topic, + ATTR_QOS: msg.qos, + ATTR_PAYLOAD: msg.payload.decode('utf-8'), + }) + def _mqtt_on_unsubscribe(self, _mqttc, _userdata, mid, granted_qos): + """Unsubscribe successful callback.""" + topic = self.progress.pop(mid, None) + if topic is None: + return + self.topics.pop(topic, None) -def _mqtt_on_subscribe(mqttc, userdata, mid, granted_qos): - """ Called when subscribe successful. """ - topic = userdata['progress'].pop(mid, None) - if topic is None: - return - userdata['topics'][topic] = granted_qos + def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code): + """Disconnected callback.""" + self.progress = {} + self.topics = {key: value for key, value in self.topics.items() + if value is not None} + # Remove None values from topic list + for key in list(self.topics): + if self.topics[key] is None: + self.topics.pop(key) -def _mqtt_on_unsubscribe(mqttc, userdata, mid, granted_qos): - """ Called when subscribe successful. """ - topic = userdata['progress'].pop(mid, None) - if topic is None: - return - userdata['topics'].pop(topic, None) + # When disconnected because of calling disconnect() + if result_code == 0: + return + tries = 0 + wait_time = 0 -def _mqtt_on_disconnect(mqttc, userdata, result_code): - """ Called when being disconnected. """ - # When disconnected because of calling disconnect() - if result_code == 0: - return + while True: + try: + if self._mqttc.reconnect() == 0: + _LOGGER.info('Successfully reconnected to the MQTT server') + break + except socket.error: + pass - tries = 0 - wait_time = 0 - - while True: - try: - if mqttc.reconnect() == 0: - _LOGGER.info('Successfully reconnected to the MQTT server') - break - except socket.error: - pass - - wait_time = min(2**tries, MAX_RECONNECT_WAIT) - _LOGGER.warning( - 'Disconnected from MQTT (%s). Trying to reconnect in %ss', - result_code, wait_time) - # It is ok to sleep here as we are in the MQTT thread. - time.sleep(wait_time) - tries += 1 + wait_time = min(2**tries, MAX_RECONNECT_WAIT) + _LOGGER.warning( + 'Disconnected from MQTT (%s). Trying to reconnect in %ss', + result_code, wait_time) + # It is ok to sleep here as we are in the MQTT thread. + time.sleep(wait_time) + tries += 1 def _raise_on_error(result): - """ Raise error if error result. """ + """Raise error if error result.""" if result != 0: raise HomeAssistantError('Error talking to MQTT: {}'.format(result)) def _match_topic(subscription, topic): - """ Returns if topic matches subscription. """ + """Test if topic matches subscription.""" if subscription.endswith('#'): return (subscription[:-2] == topic or topic.startswith(subscription[:-1])) diff --git a/tests/components/test_mqtt.py b/tests/components/test_mqtt.py index 47a5ac7b4e1..40e473a3572 100644 --- a/tests/components/test_mqtt.py +++ b/tests/components/test_mqtt.py @@ -144,8 +144,15 @@ class TestMQTTCallbacks(unittest.TestCase): def setUp(self): # pylint: disable=invalid-name self.hass = get_test_home_assistant(1) - mock_mqtt_component(self.hass) - self.calls = [] + # mock_mqtt_component(self.hass) + + with mock.patch('paho.mqtt.client.Client'): + mqtt.setup(self.hass, { + mqtt.DOMAIN: { + mqtt.CONF_BROKER: 'mock-broker', + } + }) + self.hass.config.components.append(mqtt.DOMAIN) def tearDown(self): # pylint: disable=invalid-name """ Stop down stuff we started. """ @@ -162,7 +169,7 @@ class TestMQTTCallbacks(unittest.TestCase): MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload']) message = MQTTMessage('test_topic', 1, 'Hello World!'.encode('utf-8')) - mqtt._mqtt_on_message(None, {'hass': self.hass}, message) + mqtt.MQTT_CLIENT._mqtt_on_message(None, {'hass': self.hass}, message) self.hass.pool.block_till_done() self.assertEqual(1, len(calls)) @@ -173,36 +180,55 @@ class TestMQTTCallbacks(unittest.TestCase): def test_mqtt_failed_connection_results_in_disconnect(self): for result_code in range(1, 6): - mqttc = mock.MagicMock() - mqtt._mqtt_on_connect(mqttc, {'topics': {}}, 0, result_code) - self.assertTrue(mqttc.disconnect.called) + mqtt.MQTT_CLIENT._mqttc = mock.MagicMock() + mqtt.MQTT_CLIENT._mqtt_on_connect(None, {'topics': {}}, 0, + result_code) + self.assertTrue(mqtt.MQTT_CLIENT._mqttc.disconnect.called) def test_mqtt_subscribes_topics_on_connect(self): - prev_topics = { - 'topic/test': 1, - 'home/sensor': 2, - 'still/pending': None - } - mqttc = mock.MagicMock() - mqtt._mqtt_on_connect(mqttc, {'topics': prev_topics}, 0, 0) - self.assertFalse(mqttc.disconnect.called) + from collections import OrderedDict + prev_topics = OrderedDict() + prev_topics['topic/test'] = 1, + prev_topics['home/sensor'] = 2, + prev_topics['still/pending'] = None + + mqtt.MQTT_CLIENT.topics = prev_topics + mqtt.MQTT_CLIENT.progress = {1: 'still/pending'} + # Return values for subscribe calls (rc, mid) + mqtt.MQTT_CLIENT._mqttc.subscribe.side_effect = ((0, 2), (0, 3)) + mqtt.MQTT_CLIENT._mqtt_on_connect(None, None, 0, 0) + self.assertFalse(mqtt.MQTT_CLIENT._mqttc.disconnect.called) expected = [(topic, qos) for topic, qos in prev_topics.items() if qos is not None] - self.assertEqual(expected, [call[1] for call - in mqttc.subscribe.mock_calls]) + self.assertEqual( + expected, + [call[1] for call in mqtt.MQTT_CLIENT._mqttc.subscribe.mock_calls]) + self.assertEqual({ + 1: 'still/pending', + 2: 'topic/test', + 3: 'home/sensor', + }, mqtt.MQTT_CLIENT.progress) def test_mqtt_disconnect_tries_no_reconnect_on_stop(self): - mqttc = mock.MagicMock() - mqtt._mqtt_on_disconnect(mqttc, {}, 0) - self.assertFalse(mqttc.reconnect.called) + mqtt.MQTT_CLIENT._mqtt_on_disconnect(None, None, 0) + self.assertFalse(mqtt.MQTT_CLIENT._mqttc.reconnect.called) @mock.patch('homeassistant.components.mqtt.time.sleep') def test_mqtt_disconnect_tries_reconnect(self, mock_sleep): - mqttc = mock.MagicMock() - mqttc.reconnect.side_effect = [1, 1, 1, 0] - mqtt._mqtt_on_disconnect(mqttc, {}, 1) - self.assertTrue(mqttc.reconnect.called) - self.assertEqual(4, len(mqttc.reconnect.mock_calls)) + mqtt.MQTT_CLIENT.topics = { + 'test/topic': 1, + 'test/progress': None + } + mqtt.MQTT_CLIENT.progress = { + 1: 'test/progress' + } + mqtt.MQTT_CLIENT._mqttc.reconnect.side_effect = [1, 1, 1, 0] + mqtt.MQTT_CLIENT._mqtt_on_disconnect(None, None, 1) + self.assertTrue(mqtt.MQTT_CLIENT._mqttc.reconnect.called) + self.assertEqual(4, len(mqtt.MQTT_CLIENT._mqttc.reconnect.mock_calls)) self.assertEqual([1, 2, 4], [call[1][0] for call in mock_sleep.mock_calls]) + + self.assertEqual({'test/topic': 1}, mqtt.MQTT_CLIENT.topics) + self.assertEqual({}, mqtt.MQTT_CLIENT.progress)