Compare commits

...

1 Commits

Author SHA1 Message Date
jbouwh
79738cfa0d Allow to wait for MQTT subscription 2025-09-25 18:54:38 +00:00
3 changed files with 103 additions and 1 deletions

View File

@@ -38,7 +38,10 @@ from homeassistant.core import (
get_hassjob_callable_job_type,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.importlib import async_import_module
from homeassistant.helpers.start import async_at_started
from homeassistant.helpers.typing import ConfigType
@@ -71,6 +74,7 @@ from .const import (
DEFAULT_WS_PATH,
DOMAIN,
MQTT_CONNECTION_STATE,
MQTT_PROCESSED_SUBSCRIPTIONS,
PROTOCOL_5,
PROTOCOL_31,
TRANSPORT_WEBSOCKETS,
@@ -109,6 +113,7 @@ INITIAL_SUBSCRIBE_COOLDOWN = 0.5
SUBSCRIBE_COOLDOWN = 0.1
UNSUBSCRIBE_COOLDOWN = 0.1
TIMEOUT_ACK = 10
SUBSCRIBE_TIMEOUT = 10
RECONNECT_INTERVAL_SECONDS = 10
MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1
@@ -191,11 +196,47 @@ async def async_subscribe(
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
qos: int = DEFAULT_QOS,
encoding: str | None = DEFAULT_ENCODING,
wait: bool = False,
) -> CALLBACK_TYPE:
"""Subscribe to an MQTT topic.
Call the return value to unsubscribe.
"""
subscription_complete: asyncio.Future[None]
async def _sync_mqtt_subscribe(subscriptions: list[tuple[str, int]]) -> None:
if (topic, qos) not in subscriptions:
return
subscription_complete.set_result(None)
def _async_timeout_subscribe() -> None:
if not subscription_complete.done():
subscription_complete.set_exception(TimeoutError)
if (
wait
and DATA_MQTT in hass.data
and not hass.data[DATA_MQTT].client._matching_subscriptions(topic) # noqa: SLF001
):
subscription_complete = hass.loop.create_future()
dispatcher = async_dispatcher_connect(
hass, MQTT_PROCESSED_SUBSCRIPTIONS, _sync_mqtt_subscribe
)
subscribe_callback = async_subscribe_internal(
hass, topic, msg_callback, qos, encoding
)
try:
hass.loop.call_later(SUBSCRIBE_TIMEOUT, _async_timeout_subscribe)
await subscription_complete
except TimeoutError as exc:
raise HomeAssistantError(
translation_domain=DOMAIN,
translation_key="subscribe_timeout",
) from exc
finally:
dispatcher()
return subscribe_callback
return async_subscribe_internal(hass, topic, msg_callback, qos, encoding)
@@ -963,6 +1004,7 @@ class MQTT:
self._last_subscribe = time.monotonic()
await self._async_wait_for_mid_or_raise(mid, result)
async_dispatcher_send(self.hass, MQTT_PROCESSED_SUBSCRIPTIONS, chunk_list)
async def _async_perform_unsubscribes(self) -> None:
"""Perform pending MQTT client unsubscribes."""

View File

@@ -370,6 +370,7 @@ DOMAIN = "mqtt"
LOGGER = logging.getLogger(__package__)
MQTT_CONNECTION_STATE = "mqtt_connection_state"
MQTT_PROCESSED_SUBSCRIPTIONS = "mqtt_processed_subscriptions"
PAYLOAD_EMPTY_JSON = "{}"
PAYLOAD_NONE = "None"

View File

@@ -282,6 +282,65 @@ async def test_subscribe_topic(
unsub()
async def test_subscribe_topic_and_wait(
hass: HomeAssistant,
mock_debouncer: asyncio.Event,
setup_with_birth_msg_client_mock: MqttMockPahoClient,
recorded_calls: list[ReceiveMessage],
record_calls: MessageCallbackType,
) -> None:
"""Test the subscription of a topic."""
await mock_debouncer.wait()
mock_debouncer.clear()
unsub_no_wait = await mqtt.async_subscribe(hass, "other-test-topic/#", record_calls)
unsub_wait = await mqtt.async_subscribe(hass, "test-topic", record_calls, wait=True)
async_fire_mqtt_message(hass, "test-topic", "test-payload")
async_fire_mqtt_message(hass, "other-test-topic/test", "other-test-payload")
await hass.async_block_till_done()
assert len(recorded_calls) == 2
assert recorded_calls[0].topic == "test-topic"
assert recorded_calls[0].payload == "test-payload"
assert recorded_calls[1].topic == "other-test-topic/test"
assert recorded_calls[1].payload == "other-test-payload"
unsub_no_wait()
unsub_wait()
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(recorded_calls) == 2
# Cannot unsubscribe twice
with pytest.raises(HomeAssistantError):
unsub_no_wait()
with pytest.raises(HomeAssistantError):
unsub_wait()
async def test_subscribe_topic_and_wait_timeout(
hass: HomeAssistant,
mock_debouncer: asyncio.Event,
setup_with_birth_msg_client_mock: MqttMockPahoClient,
recorded_calls: list[ReceiveMessage],
record_calls: MessageCallbackType,
) -> None:
"""Test the subscription of a topic."""
await mock_debouncer.wait()
mock_debouncer.clear()
with (
patch("homeassistant.components.mqtt.client.SUBSCRIBE_TIMEOUT", 0),
pytest.raises(HomeAssistantError) as exc,
):
await mqtt.async_subscribe(hass, "test-topic", record_calls, wait=True)
assert exc.value.translation_domain == mqtt.DOMAIN
assert exc.value.translation_key == "subscribe_timeout"
@pytest.mark.usefixtures("mqtt_mock_entry")
async def test_subscribe_topic_not_initialize(
hass: HomeAssistant, record_calls: MessageCallbackType