Wait for broker to ACK MQTT operations (#39051)

* Wait for broker to ACK MQTT operations

* Deduplicate new code

* Fix tests

* Improve test

* Don't hold PAHO lock when waiting for ACK

* Fix tests

* Add constant for ACK timeout
This commit is contained in:
Erik Montnemery 2020-08-21 17:00:13 +02:00 committed by GitHub
parent 4a7c181e91
commit ee043d8614
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 94 additions and 38 deletions

View File

@ -126,6 +126,8 @@ CONNECTION_SUCCESS = "connection_success"
CONNECTION_FAILED = "connection_failed"
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"
TIMEOUT_ACK = 1
def validate_device_has_at_least_one_identifier(value: ConfigType) -> ConfigType:
"""Validate that a device info entry has at least one identifying value."""
@ -624,6 +626,8 @@ class MQTT:
self._mqttc: mqtt.Client = None
self._paho_lock = asyncio.Lock()
self._pending_operations = {}
self.init_client()
self.config_entry.add_update_listener(self.async_config_entry_updated)
@ -707,6 +711,9 @@ class MQTT:
self._mqttc.on_connect = self._mqtt_on_connect
self._mqttc.on_disconnect = self._mqtt_on_disconnect
self._mqttc.on_message = self._mqtt_on_message
self._mqttc.on_publish = self._mqtt_on_callback
self._mqttc.on_subscribe = self._mqtt_on_callback
self._mqttc.on_unsubscribe = self._mqtt_on_callback
if (
CONF_WILL_MESSAGE in self.conf
@ -729,10 +736,17 @@ class MQTT:
) -> None:
"""Publish a MQTT message."""
async with self._paho_lock:
_LOGGER.debug("Transmitting message on %s: %s", topic, payload)
await self.hass.async_add_executor_job(
msg_info = await self.hass.async_add_executor_job(
self._mqttc.publish, topic, payload, qos, retain
)
_LOGGER.debug(
"Transmitting message on %s: '%s', mid: %s",
topic,
payload,
msg_info.mid,
)
_raise_on_error(msg_info.rc)
await self._wait_for_mid(msg_info.mid)
async def async_connect(self) -> str:
"""Connect to the host. Does not process messages yet."""
@ -810,24 +824,25 @@ class MQTT:
This method is a coroutine.
"""
_LOGGER.debug("Unsubscribing from %s", topic)
async with self._paho_lock:
result: int = None
result, _ = await self.hass.async_add_executor_job(
result, mid = await self.hass.async_add_executor_job(
self._mqttc.unsubscribe, topic
)
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
_raise_on_error(result)
await self._wait_for_mid(mid)
async def _async_perform_subscription(self, topic: str, qos: int) -> None:
"""Perform a paho-mqtt subscription."""
_LOGGER.debug("Subscribing to %s", topic)
async with self._paho_lock:
result: int = None
result, _ = await self.hass.async_add_executor_job(
result, mid = await self.hass.async_add_executor_job(
self._mqttc.subscribe, topic, qos
)
_LOGGER.debug("Subscribing to %s, mid: %s", topic, mid)
_raise_on_error(result)
await self._wait_for_mid(mid)
def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None:
"""On connect callback.
@ -919,6 +934,16 @@ class MQTT:
),
)
def _mqtt_on_callback(self, _mqttc, _userdata, mid, _granted_qos=None) -> None:
"""Publish / Subscribe / Unsubscribe callback."""
self.hass.add_job(self._mqtt_handle_mid, mid)
async def _mqtt_handle_mid(self, mid) -> None:
if mid in self._pending_operations:
self._pending_operations[mid].set()
else:
_LOGGER.warning("Unknown mid %d", mid)
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
"""Disconnected callback."""
self.connected = False
@ -930,6 +955,16 @@ class MQTT:
result_code,
)
async def _wait_for_mid(self, mid):
"""Wait for ACK from broker."""
self._pending_operations[mid] = asyncio.Event()
try:
await asyncio.wait_for(self._pending_operations[mid].wait(), TIMEOUT_ACK)
except asyncio.TimeoutError:
_LOGGER.error("Timed out waiting for mid %s", mid)
finally:
del self._pending_operations[mid]
def _raise_on_error(result_code: int) -> None:
"""Raise error if error result."""

View File

@ -575,29 +575,37 @@ async def test_subscribe_special_characters(hass, mqtt_mock, calls, record_calls
assert calls[0][0].payload == payload
async def test_retained_message_on_subscribe_received(
hass, mqtt_client_mock, mqtt_mock
):
"""Test every subscriber receives retained message on subscribe."""
async def test_subscribe_same_topic(hass, mqtt_client_mock, mqtt_mock):
"""
Test subscring to same topic twice and simulate retained messages.
def side_effect(*args):
async_fire_mqtt_message(hass, "test/state", "online")
return 0, 0
mqtt_client_mock.subscribe.side_effect = side_effect
When subscribing to the same topic again, SUBSCRIBE must be sent to the broker again
for it to resend any retained messages.
"""
# Fake that the client is connected
mqtt_mock().connected = True
calls_a = MagicMock()
await mqtt.async_subscribe(hass, "test/state", calls_a)
async_fire_mqtt_message(
hass, "test/state", "online"
) # Simulate a (retained) message
await hass.async_block_till_done()
assert calls_a.called
mqtt_client_mock.subscribe.assert_called()
calls_a.reset_mock()
mqtt_client_mock.reset_mock()
calls_b = MagicMock()
await mqtt.async_subscribe(hass, "test/state", calls_b)
async_fire_mqtt_message(
hass, "test/state", "online"
) # Simulate a (retained) message
await hass.async_block_till_done()
assert calls_a.called
assert calls_b.called
mqtt_client_mock.subscribe.assert_called()
async def test_not_calling_unsubscribe_with_active_subscribers(
@ -639,13 +647,6 @@ async def test_restore_all_active_subscriptions_on_reconnect(
# Fake that the client is connected
mqtt_mock().connected = True
mqtt_client_mock.subscribe.side_effect = (
(0, 1),
(0, 2),
(0, 3),
(0, 4),
)
unsub = await mqtt.async_subscribe(hass, "test/state", None, qos=2)
await mqtt.async_subscribe(hass, "test/state", None)
await mqtt.async_subscribe(hass, "test/state", None, qos=1)
@ -757,32 +758,28 @@ async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass):
)
async def test_custom_birth_message(hass, mqtt_client_mock, mqtt_mock):
"""Test sending birth message."""
calls = []
mqtt_client_mock.publish.side_effect = lambda *args: calls.append(args)
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
await hass.async_block_till_done()
assert calls[-1] == ("birth", "birth", 0, False)
mqtt_client_mock.publish.assert_called_with("birth", "birth", 0, False)
async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock):
"""Test sending birth message."""
calls = []
mqtt_client_mock.publish.side_effect = lambda *args: calls.append(args)
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
await hass.async_block_till_done()
assert calls[-1] == ("homeassistant/status", "online", 0, False)
mqtt_client_mock.publish.assert_called_with(
"homeassistant/status", "online", 0, False
)
@pytest.mark.parametrize(
"mqtt_config", [{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_BIRTH_MESSAGE: {}}],
)
async def test_no_birth_message(hass, mqtt_client_mock, mqtt_mock):
"""Test sending birth message."""
calls = []
mqtt_client_mock.publish.side_effect = lambda *args: calls.append(args)
"""Test disabling birth message."""
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
await hass.async_block_till_done()
assert not calls
mqtt_client_mock.publish.assert_not_called()
@pytest.mark.parametrize(

View File

@ -335,15 +335,39 @@ def mqtt_config():
def mqtt_client_mock(hass):
"""Fixture to mock MQTT client."""
@ha.callback
def _async_fire_mqtt_message(topic, payload, qos, retain):
async_fire_mqtt_message(hass, topic, payload, qos, retain)
mid = 0
def get_mid():
nonlocal mid
mid += 1
return mid
class FakeInfo:
def __init__(self, mid):
self.mid = mid
self.rc = 0
with patch("paho.mqtt.client.Client") as mock_client:
@ha.callback
def _async_fire_mqtt_message(topic, payload, qos, retain):
async_fire_mqtt_message(hass, topic, payload, qos, retain)
mid = get_mid()
mock_client.on_publish(0, 0, mid)
return FakeInfo(mid)
def _subscribe(topic, qos=0):
mock_client.on_subscribe(0, 0, mid)
return (0, mid)
def _unsubscribe(topic):
mock_client.on_unsubscribe(0, 0, mid)
return (0, mid)
mock_client = mock_client.return_value
mock_client.connect.return_value = 0
mock_client.subscribe.return_value = (0, 0)
mock_client.unsubscribe.return_value = (0, 0)
mock_client.subscribe.side_effect = _subscribe
mock_client.unsubscribe.side_effect = _unsubscribe
mock_client.publish.side_effect = _async_fire_mqtt_message
yield mock_client