mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 01:08:12 +00:00
Wait for broker to ACK MQTT operations (#39051)
* Wait for broker to ACK MQTT operations * Deduplicate new code * Fix tests * Improve test * Don't hold PAHO lock when waiting for ACK * Fix tests * Add constant for ACK timeout
This commit is contained in:
parent
4a7c181e91
commit
ee043d8614
@ -126,6 +126,8 @@ CONNECTION_SUCCESS = "connection_success"
|
||||
CONNECTION_FAILED = "connection_failed"
|
||||
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"
|
||||
|
||||
TIMEOUT_ACK = 1
|
||||
|
||||
|
||||
def validate_device_has_at_least_one_identifier(value: ConfigType) -> ConfigType:
|
||||
"""Validate that a device info entry has at least one identifying value."""
|
||||
@ -624,6 +626,8 @@ class MQTT:
|
||||
self._mqttc: mqtt.Client = None
|
||||
self._paho_lock = asyncio.Lock()
|
||||
|
||||
self._pending_operations = {}
|
||||
|
||||
self.init_client()
|
||||
self.config_entry.add_update_listener(self.async_config_entry_updated)
|
||||
|
||||
@ -707,6 +711,9 @@ class MQTT:
|
||||
self._mqttc.on_connect = self._mqtt_on_connect
|
||||
self._mqttc.on_disconnect = self._mqtt_on_disconnect
|
||||
self._mqttc.on_message = self._mqtt_on_message
|
||||
self._mqttc.on_publish = self._mqtt_on_callback
|
||||
self._mqttc.on_subscribe = self._mqtt_on_callback
|
||||
self._mqttc.on_unsubscribe = self._mqtt_on_callback
|
||||
|
||||
if (
|
||||
CONF_WILL_MESSAGE in self.conf
|
||||
@ -729,10 +736,17 @@ class MQTT:
|
||||
) -> None:
|
||||
"""Publish a MQTT message."""
|
||||
async with self._paho_lock:
|
||||
_LOGGER.debug("Transmitting message on %s: %s", topic, payload)
|
||||
await self.hass.async_add_executor_job(
|
||||
msg_info = await self.hass.async_add_executor_job(
|
||||
self._mqttc.publish, topic, payload, qos, retain
|
||||
)
|
||||
_LOGGER.debug(
|
||||
"Transmitting message on %s: '%s', mid: %s",
|
||||
topic,
|
||||
payload,
|
||||
msg_info.mid,
|
||||
)
|
||||
_raise_on_error(msg_info.rc)
|
||||
await self._wait_for_mid(msg_info.mid)
|
||||
|
||||
async def async_connect(self) -> str:
|
||||
"""Connect to the host. Does not process messages yet."""
|
||||
@ -810,24 +824,25 @@ class MQTT:
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
_LOGGER.debug("Unsubscribing from %s", topic)
|
||||
async with self._paho_lock:
|
||||
result: int = None
|
||||
result, _ = await self.hass.async_add_executor_job(
|
||||
result, mid = await self.hass.async_add_executor_job(
|
||||
self._mqttc.unsubscribe, topic
|
||||
)
|
||||
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
|
||||
_raise_on_error(result)
|
||||
await self._wait_for_mid(mid)
|
||||
|
||||
async def _async_perform_subscription(self, topic: str, qos: int) -> None:
|
||||
"""Perform a paho-mqtt subscription."""
|
||||
_LOGGER.debug("Subscribing to %s", topic)
|
||||
|
||||
async with self._paho_lock:
|
||||
result: int = None
|
||||
result, _ = await self.hass.async_add_executor_job(
|
||||
result, mid = await self.hass.async_add_executor_job(
|
||||
self._mqttc.subscribe, topic, qos
|
||||
)
|
||||
_LOGGER.debug("Subscribing to %s, mid: %s", topic, mid)
|
||||
_raise_on_error(result)
|
||||
await self._wait_for_mid(mid)
|
||||
|
||||
def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None:
|
||||
"""On connect callback.
|
||||
@ -919,6 +934,16 @@ class MQTT:
|
||||
),
|
||||
)
|
||||
|
||||
def _mqtt_on_callback(self, _mqttc, _userdata, mid, _granted_qos=None) -> None:
|
||||
"""Publish / Subscribe / Unsubscribe callback."""
|
||||
self.hass.add_job(self._mqtt_handle_mid, mid)
|
||||
|
||||
async def _mqtt_handle_mid(self, mid) -> None:
|
||||
if mid in self._pending_operations:
|
||||
self._pending_operations[mid].set()
|
||||
else:
|
||||
_LOGGER.warning("Unknown mid %d", mid)
|
||||
|
||||
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
|
||||
"""Disconnected callback."""
|
||||
self.connected = False
|
||||
@ -930,6 +955,16 @@ class MQTT:
|
||||
result_code,
|
||||
)
|
||||
|
||||
async def _wait_for_mid(self, mid):
|
||||
"""Wait for ACK from broker."""
|
||||
self._pending_operations[mid] = asyncio.Event()
|
||||
try:
|
||||
await asyncio.wait_for(self._pending_operations[mid].wait(), TIMEOUT_ACK)
|
||||
except asyncio.TimeoutError:
|
||||
_LOGGER.error("Timed out waiting for mid %s", mid)
|
||||
finally:
|
||||
del self._pending_operations[mid]
|
||||
|
||||
|
||||
def _raise_on_error(result_code: int) -> None:
|
||||
"""Raise error if error result."""
|
||||
|
@ -575,29 +575,37 @@ async def test_subscribe_special_characters(hass, mqtt_mock, calls, record_calls
|
||||
assert calls[0][0].payload == payload
|
||||
|
||||
|
||||
async def test_retained_message_on_subscribe_received(
|
||||
hass, mqtt_client_mock, mqtt_mock
|
||||
):
|
||||
"""Test every subscriber receives retained message on subscribe."""
|
||||
async def test_subscribe_same_topic(hass, mqtt_client_mock, mqtt_mock):
|
||||
"""
|
||||
Test subscring to same topic twice and simulate retained messages.
|
||||
|
||||
def side_effect(*args):
|
||||
async_fire_mqtt_message(hass, "test/state", "online")
|
||||
return 0, 0
|
||||
|
||||
mqtt_client_mock.subscribe.side_effect = side_effect
|
||||
When subscribing to the same topic again, SUBSCRIBE must be sent to the broker again
|
||||
for it to resend any retained messages.
|
||||
"""
|
||||
|
||||
# Fake that the client is connected
|
||||
mqtt_mock().connected = True
|
||||
|
||||
calls_a = MagicMock()
|
||||
await mqtt.async_subscribe(hass, "test/state", calls_a)
|
||||
async_fire_mqtt_message(
|
||||
hass, "test/state", "online"
|
||||
) # Simulate a (retained) message
|
||||
await hass.async_block_till_done()
|
||||
assert calls_a.called
|
||||
mqtt_client_mock.subscribe.assert_called()
|
||||
calls_a.reset_mock()
|
||||
mqtt_client_mock.reset_mock()
|
||||
|
||||
calls_b = MagicMock()
|
||||
await mqtt.async_subscribe(hass, "test/state", calls_b)
|
||||
async_fire_mqtt_message(
|
||||
hass, "test/state", "online"
|
||||
) # Simulate a (retained) message
|
||||
await hass.async_block_till_done()
|
||||
assert calls_a.called
|
||||
assert calls_b.called
|
||||
mqtt_client_mock.subscribe.assert_called()
|
||||
|
||||
|
||||
async def test_not_calling_unsubscribe_with_active_subscribers(
|
||||
@ -639,13 +647,6 @@ async def test_restore_all_active_subscriptions_on_reconnect(
|
||||
# Fake that the client is connected
|
||||
mqtt_mock().connected = True
|
||||
|
||||
mqtt_client_mock.subscribe.side_effect = (
|
||||
(0, 1),
|
||||
(0, 2),
|
||||
(0, 3),
|
||||
(0, 4),
|
||||
)
|
||||
|
||||
unsub = await mqtt.async_subscribe(hass, "test/state", None, qos=2)
|
||||
await mqtt.async_subscribe(hass, "test/state", None)
|
||||
await mqtt.async_subscribe(hass, "test/state", None, qos=1)
|
||||
@ -757,32 +758,28 @@ async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass):
|
||||
)
|
||||
async def test_custom_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
||||
"""Test sending birth message."""
|
||||
calls = []
|
||||
mqtt_client_mock.publish.side_effect = lambda *args: calls.append(args)
|
||||
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
||||
await hass.async_block_till_done()
|
||||
assert calls[-1] == ("birth", "birth", 0, False)
|
||||
mqtt_client_mock.publish.assert_called_with("birth", "birth", 0, False)
|
||||
|
||||
|
||||
async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
||||
"""Test sending birth message."""
|
||||
calls = []
|
||||
mqtt_client_mock.publish.side_effect = lambda *args: calls.append(args)
|
||||
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
||||
await hass.async_block_till_done()
|
||||
assert calls[-1] == ("homeassistant/status", "online", 0, False)
|
||||
mqtt_client_mock.publish.assert_called_with(
|
||||
"homeassistant/status", "online", 0, False
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mqtt_config", [{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_BIRTH_MESSAGE: {}}],
|
||||
)
|
||||
async def test_no_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
||||
"""Test sending birth message."""
|
||||
calls = []
|
||||
mqtt_client_mock.publish.side_effect = lambda *args: calls.append(args)
|
||||
"""Test disabling birth message."""
|
||||
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
||||
await hass.async_block_till_done()
|
||||
assert not calls
|
||||
mqtt_client_mock.publish.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -335,15 +335,39 @@ def mqtt_config():
|
||||
def mqtt_client_mock(hass):
|
||||
"""Fixture to mock MQTT client."""
|
||||
|
||||
@ha.callback
|
||||
def _async_fire_mqtt_message(topic, payload, qos, retain):
|
||||
async_fire_mqtt_message(hass, topic, payload, qos, retain)
|
||||
mid = 0
|
||||
|
||||
def get_mid():
|
||||
nonlocal mid
|
||||
mid += 1
|
||||
return mid
|
||||
|
||||
class FakeInfo:
|
||||
def __init__(self, mid):
|
||||
self.mid = mid
|
||||
self.rc = 0
|
||||
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
|
||||
@ha.callback
|
||||
def _async_fire_mqtt_message(topic, payload, qos, retain):
|
||||
async_fire_mqtt_message(hass, topic, payload, qos, retain)
|
||||
mid = get_mid()
|
||||
mock_client.on_publish(0, 0, mid)
|
||||
return FakeInfo(mid)
|
||||
|
||||
def _subscribe(topic, qos=0):
|
||||
mock_client.on_subscribe(0, 0, mid)
|
||||
return (0, mid)
|
||||
|
||||
def _unsubscribe(topic):
|
||||
mock_client.on_unsubscribe(0, 0, mid)
|
||||
return (0, mid)
|
||||
|
||||
mock_client = mock_client.return_value
|
||||
mock_client.connect.return_value = 0
|
||||
mock_client.subscribe.return_value = (0, 0)
|
||||
mock_client.unsubscribe.return_value = (0, 0)
|
||||
mock_client.subscribe.side_effect = _subscribe
|
||||
mock_client.unsubscribe.side_effect = _unsubscribe
|
||||
mock_client.publish.side_effect = _async_fire_mqtt_message
|
||||
yield mock_client
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user