mirror of
https://github.com/home-assistant/core.git
synced 2025-07-18 18:57:06 +00:00
Wait for broker to ACK MQTT operations (#39051)
* Wait for broker to ACK MQTT operations * Deduplicate new code * Fix tests * Improve test * Don't hold PAHO lock when waiting for ACK * Fix tests * Add constant for ACK timeout
This commit is contained in:
parent
4a7c181e91
commit
ee043d8614
@ -126,6 +126,8 @@ CONNECTION_SUCCESS = "connection_success"
|
|||||||
CONNECTION_FAILED = "connection_failed"
|
CONNECTION_FAILED = "connection_failed"
|
||||||
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"
|
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"
|
||||||
|
|
||||||
|
TIMEOUT_ACK = 1
|
||||||
|
|
||||||
|
|
||||||
def validate_device_has_at_least_one_identifier(value: ConfigType) -> ConfigType:
|
def validate_device_has_at_least_one_identifier(value: ConfigType) -> ConfigType:
|
||||||
"""Validate that a device info entry has at least one identifying value."""
|
"""Validate that a device info entry has at least one identifying value."""
|
||||||
@ -624,6 +626,8 @@ class MQTT:
|
|||||||
self._mqttc: mqtt.Client = None
|
self._mqttc: mqtt.Client = None
|
||||||
self._paho_lock = asyncio.Lock()
|
self._paho_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
self._pending_operations = {}
|
||||||
|
|
||||||
self.init_client()
|
self.init_client()
|
||||||
self.config_entry.add_update_listener(self.async_config_entry_updated)
|
self.config_entry.add_update_listener(self.async_config_entry_updated)
|
||||||
|
|
||||||
@ -707,6 +711,9 @@ class MQTT:
|
|||||||
self._mqttc.on_connect = self._mqtt_on_connect
|
self._mqttc.on_connect = self._mqtt_on_connect
|
||||||
self._mqttc.on_disconnect = self._mqtt_on_disconnect
|
self._mqttc.on_disconnect = self._mqtt_on_disconnect
|
||||||
self._mqttc.on_message = self._mqtt_on_message
|
self._mqttc.on_message = self._mqtt_on_message
|
||||||
|
self._mqttc.on_publish = self._mqtt_on_callback
|
||||||
|
self._mqttc.on_subscribe = self._mqtt_on_callback
|
||||||
|
self._mqttc.on_unsubscribe = self._mqtt_on_callback
|
||||||
|
|
||||||
if (
|
if (
|
||||||
CONF_WILL_MESSAGE in self.conf
|
CONF_WILL_MESSAGE in self.conf
|
||||||
@ -729,10 +736,17 @@ class MQTT:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Publish a MQTT message."""
|
"""Publish a MQTT message."""
|
||||||
async with self._paho_lock:
|
async with self._paho_lock:
|
||||||
_LOGGER.debug("Transmitting message on %s: %s", topic, payload)
|
msg_info = await self.hass.async_add_executor_job(
|
||||||
await self.hass.async_add_executor_job(
|
|
||||||
self._mqttc.publish, topic, payload, qos, retain
|
self._mqttc.publish, topic, payload, qos, retain
|
||||||
)
|
)
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Transmitting message on %s: '%s', mid: %s",
|
||||||
|
topic,
|
||||||
|
payload,
|
||||||
|
msg_info.mid,
|
||||||
|
)
|
||||||
|
_raise_on_error(msg_info.rc)
|
||||||
|
await self._wait_for_mid(msg_info.mid)
|
||||||
|
|
||||||
async def async_connect(self) -> str:
|
async def async_connect(self) -> str:
|
||||||
"""Connect to the host. Does not process messages yet."""
|
"""Connect to the host. Does not process messages yet."""
|
||||||
@ -810,24 +824,25 @@ class MQTT:
|
|||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
_LOGGER.debug("Unsubscribing from %s", topic)
|
|
||||||
async with self._paho_lock:
|
async with self._paho_lock:
|
||||||
result: int = None
|
result: int = None
|
||||||
result, _ = await self.hass.async_add_executor_job(
|
result, mid = await self.hass.async_add_executor_job(
|
||||||
self._mqttc.unsubscribe, topic
|
self._mqttc.unsubscribe, topic
|
||||||
)
|
)
|
||||||
|
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
|
||||||
_raise_on_error(result)
|
_raise_on_error(result)
|
||||||
|
await self._wait_for_mid(mid)
|
||||||
|
|
||||||
async def _async_perform_subscription(self, topic: str, qos: int) -> None:
|
async def _async_perform_subscription(self, topic: str, qos: int) -> None:
|
||||||
"""Perform a paho-mqtt subscription."""
|
"""Perform a paho-mqtt subscription."""
|
||||||
_LOGGER.debug("Subscribing to %s", topic)
|
|
||||||
|
|
||||||
async with self._paho_lock:
|
async with self._paho_lock:
|
||||||
result: int = None
|
result: int = None
|
||||||
result, _ = await self.hass.async_add_executor_job(
|
result, mid = await self.hass.async_add_executor_job(
|
||||||
self._mqttc.subscribe, topic, qos
|
self._mqttc.subscribe, topic, qos
|
||||||
)
|
)
|
||||||
|
_LOGGER.debug("Subscribing to %s, mid: %s", topic, mid)
|
||||||
_raise_on_error(result)
|
_raise_on_error(result)
|
||||||
|
await self._wait_for_mid(mid)
|
||||||
|
|
||||||
def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None:
|
def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None:
|
||||||
"""On connect callback.
|
"""On connect callback.
|
||||||
@ -919,6 +934,16 @@ class MQTT:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _mqtt_on_callback(self, _mqttc, _userdata, mid, _granted_qos=None) -> None:
|
||||||
|
"""Publish / Subscribe / Unsubscribe callback."""
|
||||||
|
self.hass.add_job(self._mqtt_handle_mid, mid)
|
||||||
|
|
||||||
|
async def _mqtt_handle_mid(self, mid) -> None:
|
||||||
|
if mid in self._pending_operations:
|
||||||
|
self._pending_operations[mid].set()
|
||||||
|
else:
|
||||||
|
_LOGGER.warning("Unknown mid %d", mid)
|
||||||
|
|
||||||
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
|
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
|
||||||
"""Disconnected callback."""
|
"""Disconnected callback."""
|
||||||
self.connected = False
|
self.connected = False
|
||||||
@ -930,6 +955,16 @@ class MQTT:
|
|||||||
result_code,
|
result_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _wait_for_mid(self, mid):
|
||||||
|
"""Wait for ACK from broker."""
|
||||||
|
self._pending_operations[mid] = asyncio.Event()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._pending_operations[mid].wait(), TIMEOUT_ACK)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
_LOGGER.error("Timed out waiting for mid %s", mid)
|
||||||
|
finally:
|
||||||
|
del self._pending_operations[mid]
|
||||||
|
|
||||||
|
|
||||||
def _raise_on_error(result_code: int) -> None:
|
def _raise_on_error(result_code: int) -> None:
|
||||||
"""Raise error if error result."""
|
"""Raise error if error result."""
|
||||||
|
@ -575,29 +575,37 @@ async def test_subscribe_special_characters(hass, mqtt_mock, calls, record_calls
|
|||||||
assert calls[0][0].payload == payload
|
assert calls[0][0].payload == payload
|
||||||
|
|
||||||
|
|
||||||
async def test_retained_message_on_subscribe_received(
|
async def test_subscribe_same_topic(hass, mqtt_client_mock, mqtt_mock):
|
||||||
hass, mqtt_client_mock, mqtt_mock
|
"""
|
||||||
):
|
Test subscring to same topic twice and simulate retained messages.
|
||||||
"""Test every subscriber receives retained message on subscribe."""
|
|
||||||
|
|
||||||
def side_effect(*args):
|
When subscribing to the same topic again, SUBSCRIBE must be sent to the broker again
|
||||||
async_fire_mqtt_message(hass, "test/state", "online")
|
for it to resend any retained messages.
|
||||||
return 0, 0
|
"""
|
||||||
|
|
||||||
mqtt_client_mock.subscribe.side_effect = side_effect
|
|
||||||
|
|
||||||
# Fake that the client is connected
|
# Fake that the client is connected
|
||||||
mqtt_mock().connected = True
|
mqtt_mock().connected = True
|
||||||
|
|
||||||
calls_a = MagicMock()
|
calls_a = MagicMock()
|
||||||
await mqtt.async_subscribe(hass, "test/state", calls_a)
|
await mqtt.async_subscribe(hass, "test/state", calls_a)
|
||||||
|
async_fire_mqtt_message(
|
||||||
|
hass, "test/state", "online"
|
||||||
|
) # Simulate a (retained) message
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert calls_a.called
|
assert calls_a.called
|
||||||
|
mqtt_client_mock.subscribe.assert_called()
|
||||||
|
calls_a.reset_mock()
|
||||||
|
mqtt_client_mock.reset_mock()
|
||||||
|
|
||||||
calls_b = MagicMock()
|
calls_b = MagicMock()
|
||||||
await mqtt.async_subscribe(hass, "test/state", calls_b)
|
await mqtt.async_subscribe(hass, "test/state", calls_b)
|
||||||
|
async_fire_mqtt_message(
|
||||||
|
hass, "test/state", "online"
|
||||||
|
) # Simulate a (retained) message
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
assert calls_a.called
|
||||||
assert calls_b.called
|
assert calls_b.called
|
||||||
|
mqtt_client_mock.subscribe.assert_called()
|
||||||
|
|
||||||
|
|
||||||
async def test_not_calling_unsubscribe_with_active_subscribers(
|
async def test_not_calling_unsubscribe_with_active_subscribers(
|
||||||
@ -639,13 +647,6 @@ async def test_restore_all_active_subscriptions_on_reconnect(
|
|||||||
# Fake that the client is connected
|
# Fake that the client is connected
|
||||||
mqtt_mock().connected = True
|
mqtt_mock().connected = True
|
||||||
|
|
||||||
mqtt_client_mock.subscribe.side_effect = (
|
|
||||||
(0, 1),
|
|
||||||
(0, 2),
|
|
||||||
(0, 3),
|
|
||||||
(0, 4),
|
|
||||||
)
|
|
||||||
|
|
||||||
unsub = await mqtt.async_subscribe(hass, "test/state", None, qos=2)
|
unsub = await mqtt.async_subscribe(hass, "test/state", None, qos=2)
|
||||||
await mqtt.async_subscribe(hass, "test/state", None)
|
await mqtt.async_subscribe(hass, "test/state", None)
|
||||||
await mqtt.async_subscribe(hass, "test/state", None, qos=1)
|
await mqtt.async_subscribe(hass, "test/state", None, qos=1)
|
||||||
@ -757,32 +758,28 @@ async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass):
|
|||||||
)
|
)
|
||||||
async def test_custom_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
async def test_custom_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
||||||
"""Test sending birth message."""
|
"""Test sending birth message."""
|
||||||
calls = []
|
|
||||||
mqtt_client_mock.publish.side_effect = lambda *args: calls.append(args)
|
|
||||||
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert calls[-1] == ("birth", "birth", 0, False)
|
mqtt_client_mock.publish.assert_called_with("birth", "birth", 0, False)
|
||||||
|
|
||||||
|
|
||||||
async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
||||||
"""Test sending birth message."""
|
"""Test sending birth message."""
|
||||||
calls = []
|
|
||||||
mqtt_client_mock.publish.side_effect = lambda *args: calls.append(args)
|
|
||||||
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert calls[-1] == ("homeassistant/status", "online", 0, False)
|
mqtt_client_mock.publish.assert_called_with(
|
||||||
|
"homeassistant/status", "online", 0, False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"mqtt_config", [{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_BIRTH_MESSAGE: {}}],
|
"mqtt_config", [{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_BIRTH_MESSAGE: {}}],
|
||||||
)
|
)
|
||||||
async def test_no_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
async def test_no_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
||||||
"""Test sending birth message."""
|
"""Test disabling birth message."""
|
||||||
calls = []
|
|
||||||
mqtt_client_mock.publish.side_effect = lambda *args: calls.append(args)
|
|
||||||
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert not calls
|
mqtt_client_mock.publish.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -335,15 +335,39 @@ def mqtt_config():
|
|||||||
def mqtt_client_mock(hass):
|
def mqtt_client_mock(hass):
|
||||||
"""Fixture to mock MQTT client."""
|
"""Fixture to mock MQTT client."""
|
||||||
|
|
||||||
|
mid = 0
|
||||||
|
|
||||||
|
def get_mid():
|
||||||
|
nonlocal mid
|
||||||
|
mid += 1
|
||||||
|
return mid
|
||||||
|
|
||||||
|
class FakeInfo:
|
||||||
|
def __init__(self, mid):
|
||||||
|
self.mid = mid
|
||||||
|
self.rc = 0
|
||||||
|
|
||||||
|
with patch("paho.mqtt.client.Client") as mock_client:
|
||||||
|
|
||||||
@ha.callback
|
@ha.callback
|
||||||
def _async_fire_mqtt_message(topic, payload, qos, retain):
|
def _async_fire_mqtt_message(topic, payload, qos, retain):
|
||||||
async_fire_mqtt_message(hass, topic, payload, qos, retain)
|
async_fire_mqtt_message(hass, topic, payload, qos, retain)
|
||||||
|
mid = get_mid()
|
||||||
|
mock_client.on_publish(0, 0, mid)
|
||||||
|
return FakeInfo(mid)
|
||||||
|
|
||||||
|
def _subscribe(topic, qos=0):
|
||||||
|
mock_client.on_subscribe(0, 0, mid)
|
||||||
|
return (0, mid)
|
||||||
|
|
||||||
|
def _unsubscribe(topic):
|
||||||
|
mock_client.on_unsubscribe(0, 0, mid)
|
||||||
|
return (0, mid)
|
||||||
|
|
||||||
with patch("paho.mqtt.client.Client") as mock_client:
|
|
||||||
mock_client = mock_client.return_value
|
mock_client = mock_client.return_value
|
||||||
mock_client.connect.return_value = 0
|
mock_client.connect.return_value = 0
|
||||||
mock_client.subscribe.return_value = (0, 0)
|
mock_client.subscribe.side_effect = _subscribe
|
||||||
mock_client.unsubscribe.return_value = (0, 0)
|
mock_client.unsubscribe.side_effect = _unsubscribe
|
||||||
mock_client.publish.side_effect = _async_fire_mqtt_message
|
mock_client.publish.side_effect = _async_fire_mqtt_message
|
||||||
yield mock_client
|
yield mock_client
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user