diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 702e8a139f4..87258d17f99 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -126,6 +126,8 @@ CONNECTION_SUCCESS = "connection_success" CONNECTION_FAILED = "connection_failed" CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable" +TIMEOUT_ACK = 1 + def validate_device_has_at_least_one_identifier(value: ConfigType) -> ConfigType: """Validate that a device info entry has at least one identifying value.""" @@ -624,6 +626,8 @@ class MQTT: self._mqttc: mqtt.Client = None self._paho_lock = asyncio.Lock() + self._pending_operations = {} + self.init_client() self.config_entry.add_update_listener(self.async_config_entry_updated) @@ -707,6 +711,9 @@ class MQTT: self._mqttc.on_connect = self._mqtt_on_connect self._mqttc.on_disconnect = self._mqtt_on_disconnect self._mqttc.on_message = self._mqtt_on_message + self._mqttc.on_publish = self._mqtt_on_callback + self._mqttc.on_subscribe = self._mqtt_on_callback + self._mqttc.on_unsubscribe = self._mqtt_on_callback if ( CONF_WILL_MESSAGE in self.conf @@ -729,10 +736,17 @@ class MQTT: ) -> None: """Publish a MQTT message.""" async with self._paho_lock: - _LOGGER.debug("Transmitting message on %s: %s", topic, payload) - await self.hass.async_add_executor_job( + msg_info = await self.hass.async_add_executor_job( self._mqttc.publish, topic, payload, qos, retain ) + _LOGGER.debug( + "Transmitting message on %s: '%s', mid: %s", + topic, + payload, + msg_info.mid, + ) + _raise_on_error(msg_info.rc) + await self._wait_for_mid(msg_info.mid) async def async_connect(self) -> str: """Connect to the host. Does not process messages yet.""" @@ -810,24 +824,25 @@ class MQTT: This method is a coroutine. """ - _LOGGER.debug("Unsubscribing from %s", topic) async with self._paho_lock: result: int = None - result, _ = await self.hass.async_add_executor_job( + result, mid = await self.hass.async_add_executor_job( self._mqttc.unsubscribe, topic ) + _LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid) _raise_on_error(result) + await self._wait_for_mid(mid) async def _async_perform_subscription(self, topic: str, qos: int) -> None: """Perform a paho-mqtt subscription.""" - _LOGGER.debug("Subscribing to %s", topic) - async with self._paho_lock: result: int = None - result, _ = await self.hass.async_add_executor_job( + result, mid = await self.hass.async_add_executor_job( self._mqttc.subscribe, topic, qos ) + _LOGGER.debug("Subscribing to %s, mid: %s", topic, mid) _raise_on_error(result) + await self._wait_for_mid(mid) def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None: """On connect callback. @@ -919,6 +934,16 @@ class MQTT: ), ) + def _mqtt_on_callback(self, _mqttc, _userdata, mid, _granted_qos=None) -> None: + """Publish / Subscribe / Unsubscribe callback.""" + self.hass.add_job(self._mqtt_handle_mid, mid) + + async def _mqtt_handle_mid(self, mid) -> None: + if mid in self._pending_operations: + self._pending_operations[mid].set() + else: + _LOGGER.warning("Unknown mid %d", mid) + def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None: """Disconnected callback.""" self.connected = False @@ -930,6 +955,16 @@ class MQTT: result_code, ) + async def _wait_for_mid(self, mid): + """Wait for ACK from broker.""" + self._pending_operations[mid] = asyncio.Event() + try: + await asyncio.wait_for(self._pending_operations[mid].wait(), TIMEOUT_ACK) + except asyncio.TimeoutError: + _LOGGER.error("Timed out waiting for mid %s", mid) + finally: + del self._pending_operations[mid] + def _raise_on_error(result_code: int) -> None: """Raise error if error result.""" diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 15d92b9a311..dea0852d580 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -575,29 +575,37 @@ async def test_subscribe_special_characters(hass, mqtt_mock, calls, record_calls assert calls[0][0].payload == payload -async def test_retained_message_on_subscribe_received( - hass, mqtt_client_mock, mqtt_mock -): - """Test every subscriber receives retained message on subscribe.""" +async def test_subscribe_same_topic(hass, mqtt_client_mock, mqtt_mock): + """ + Test subscring to same topic twice and simulate retained messages. - def side_effect(*args): - async_fire_mqtt_message(hass, "test/state", "online") - return 0, 0 - - mqtt_client_mock.subscribe.side_effect = side_effect + When subscribing to the same topic again, SUBSCRIBE must be sent to the broker again + for it to resend any retained messages. + """ # Fake that the client is connected mqtt_mock().connected = True calls_a = MagicMock() await mqtt.async_subscribe(hass, "test/state", calls_a) + async_fire_mqtt_message( + hass, "test/state", "online" + ) # Simulate a (retained) message await hass.async_block_till_done() assert calls_a.called + mqtt_client_mock.subscribe.assert_called() + calls_a.reset_mock() + mqtt_client_mock.reset_mock() calls_b = MagicMock() await mqtt.async_subscribe(hass, "test/state", calls_b) + async_fire_mqtt_message( + hass, "test/state", "online" + ) # Simulate a (retained) message await hass.async_block_till_done() + assert calls_a.called assert calls_b.called + mqtt_client_mock.subscribe.assert_called() async def test_not_calling_unsubscribe_with_active_subscribers( @@ -639,13 +647,6 @@ async def test_restore_all_active_subscriptions_on_reconnect( # Fake that the client is connected mqtt_mock().connected = True - mqtt_client_mock.subscribe.side_effect = ( - (0, 1), - (0, 2), - (0, 3), - (0, 4), - ) - unsub = await mqtt.async_subscribe(hass, "test/state", None, qos=2) await mqtt.async_subscribe(hass, "test/state", None) await mqtt.async_subscribe(hass, "test/state", None, qos=1) @@ -757,32 +758,28 @@ async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass): ) async def test_custom_birth_message(hass, mqtt_client_mock, mqtt_mock): """Test sending birth message.""" - calls = [] - mqtt_client_mock.publish.side_effect = lambda *args: calls.append(args) mqtt_mock._mqtt_on_connect(None, None, 0, 0) await hass.async_block_till_done() - assert calls[-1] == ("birth", "birth", 0, False) + mqtt_client_mock.publish.assert_called_with("birth", "birth", 0, False) async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock): """Test sending birth message.""" - calls = [] - mqtt_client_mock.publish.side_effect = lambda *args: calls.append(args) mqtt_mock._mqtt_on_connect(None, None, 0, 0) await hass.async_block_till_done() - assert calls[-1] == ("homeassistant/status", "online", 0, False) + mqtt_client_mock.publish.assert_called_with( + "homeassistant/status", "online", 0, False + ) @pytest.mark.parametrize( "mqtt_config", [{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_BIRTH_MESSAGE: {}}], ) async def test_no_birth_message(hass, mqtt_client_mock, mqtt_mock): - """Test sending birth message.""" - calls = [] - mqtt_client_mock.publish.side_effect = lambda *args: calls.append(args) + """Test disabling birth message.""" mqtt_mock._mqtt_on_connect(None, None, 0, 0) await hass.async_block_till_done() - assert not calls + mqtt_client_mock.publish.assert_not_called() @pytest.mark.parametrize( diff --git a/tests/conftest.py b/tests/conftest.py index ba02ae5de2f..9008359e539 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -335,15 +335,39 @@ def mqtt_config(): def mqtt_client_mock(hass): """Fixture to mock MQTT client.""" - @ha.callback - def _async_fire_mqtt_message(topic, payload, qos, retain): - async_fire_mqtt_message(hass, topic, payload, qos, retain) + mid = 0 + + def get_mid(): + nonlocal mid + mid += 1 + return mid + + class FakeInfo: + def __init__(self, mid): + self.mid = mid + self.rc = 0 with patch("paho.mqtt.client.Client") as mock_client: + + @ha.callback + def _async_fire_mqtt_message(topic, payload, qos, retain): + async_fire_mqtt_message(hass, topic, payload, qos, retain) + mid = get_mid() + mock_client.on_publish(0, 0, mid) + return FakeInfo(mid) + + def _subscribe(topic, qos=0): + mock_client.on_subscribe(0, 0, mid) + return (0, mid) + + def _unsubscribe(topic): + mock_client.on_unsubscribe(0, 0, mid) + return (0, mid) + mock_client = mock_client.return_value mock_client.connect.return_value = 0 - mock_client.subscribe.return_value = (0, 0) - mock_client.unsubscribe.return_value = (0, 0) + mock_client.subscribe.side_effect = _subscribe + mock_client.unsubscribe.side_effect = _unsubscribe mock_client.publish.side_effect = _async_fire_mqtt_message yield mock_client