Compare commits

...

3 Commits

Author SHA1 Message Date
jbouwh
a784ec6454 Callback directly if the subscription is ready and the client is connected. 2025-10-29 11:25:34 +00:00
jbouwh
8c8b1df11f Use a callback handler instead of awaiting 2025-10-28 21:44:24 +00:00
jbouwh
aabcff9653 Allow to wait for MQTT subscription 2025-10-08 17:29:36 +00:00
4 changed files with 173 additions and 4 deletions

View File

@@ -40,6 +40,7 @@ from homeassistant.util.async_ import create_eager_task
from . import debug_info, discovery
from .client import (
MQTT,
async_on_subscribe_done,
async_publish,
async_subscribe,
async_subscribe_internal,
@@ -163,6 +164,7 @@ __all__ = [
"async_create_certificate_temp_files",
"async_forward_entry_setup_and_setup_discovery",
"async_migrate_entry",
"async_on_subscribe_done",
"async_prepare_subscribe_topics",
"async_publish",
"async_remove_config_entry_device",

View File

@@ -38,7 +38,10 @@ from homeassistant.core import (
get_hassjob_callable_job_type,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.importlib import async_import_module
from homeassistant.helpers.start import async_at_started
from homeassistant.helpers.typing import ConfigType
@@ -71,6 +74,7 @@ from .const import (
DEFAULT_WS_PATH,
DOMAIN,
MQTT_CONNECTION_STATE,
MQTT_PROCESSED_SUBSCRIPTIONS,
PROTOCOL_5,
PROTOCOL_31,
TRANSPORT_WEBSOCKETS,
@@ -109,6 +113,7 @@ INITIAL_SUBSCRIBE_COOLDOWN = 0.5
SUBSCRIBE_COOLDOWN = 0.1
UNSUBSCRIBE_COOLDOWN = 0.1
TIMEOUT_ACK = 10
SUBSCRIBE_TIMEOUT = 10
RECONNECT_INTERVAL_SECONDS = 10
MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1
@@ -184,6 +189,38 @@ async def async_publish(
)
@callback
def async_on_subscribe_done(
hass: HomeAssistant,
topic: str,
qos: int,
on_subscribe_status: CALLBACK_TYPE,
) -> CALLBACK_TYPE:
"""Call on_subscribe_done when the matched subscription was completed.
If a subscription is already present the callback will call
on_subscribe_status directly.
Call the returned callback to stop and cleanup status monitoring.
"""
async def _sync_mqtt_subscribe(subscriptions: list[tuple[str, int]]) -> None:
if (topic, qos) not in subscriptions:
return
hass.loop.call_soon(on_subscribe_status)
mqtt_data = hass.data[DATA_MQTT]
if (
mqtt_data.client.connected
and mqtt_data.client.is_active_subscription(topic)
and not mqtt_data.client.is_pending_subscription(topic)
):
hass.loop.call_soon(on_subscribe_status)
return async_dispatcher_connect(
hass, MQTT_PROCESSED_SUBSCRIPTIONS, _sync_mqtt_subscribe
)
@bind_hass
async def async_subscribe(
hass: HomeAssistant,
@@ -191,12 +228,32 @@ async def async_subscribe(
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
qos: int = DEFAULT_QOS,
encoding: str | None = DEFAULT_ENCODING,
on_subscribe: CALLBACK_TYPE | None = None,
) -> CALLBACK_TYPE:
"""Subscribe to an MQTT topic.
If the on_subcribe callback hook is set, it will be called once
when the subscription has been completed.
Call the return value to unsubscribe.
"""
return async_subscribe_internal(hass, topic, msg_callback, qos, encoding)
handler: CALLBACK_TYPE | None = None
def _on_subscribe_done() -> None:
"""Call once when the subscription was completed."""
if TYPE_CHECKING:
assert on_subscribe is not None and handler is not None
handler()
on_subscribe()
subscription_handler = async_subscribe_internal(
hass, topic, msg_callback, qos, encoding
)
if on_subscribe is not None:
handler = async_on_subscribe_done(hass, topic, qos, _on_subscribe_done)
return subscription_handler
@callback
@@ -640,12 +697,16 @@ class MQTT:
if fileno > -1:
self.loop.remove_writer(sock)
def _is_active_subscription(self, topic: str) -> bool:
def is_active_subscription(self, topic: str) -> bool:
"""Check if a topic has an active subscription."""
return topic in self._simple_subscriptions or any(
other.topic == topic for other in self._wildcard_subscriptions
)
def is_pending_subscription(self, topic: str) -> bool:
"""Check if a topic has a pending subscription."""
return topic in self._pending_subscriptions
async def async_publish(
self, topic: str, payload: PublishPayloadType, qos: int, retain: bool
) -> None:
@@ -899,7 +960,7 @@ class MQTT:
@callback
def _async_unsubscribe(self, topic: str) -> None:
"""Unsubscribe from a topic."""
if self._is_active_subscription(topic):
if self.is_active_subscription(topic):
if self._max_qos[topic] == 0:
return
subs = self._matching_subscriptions(topic)
@@ -963,6 +1024,7 @@ class MQTT:
self._last_subscribe = time.monotonic()
await self._async_wait_for_mid_or_raise(mid, result)
async_dispatcher_send(self.hass, MQTT_PROCESSED_SUBSCRIPTIONS, chunk_list)
async def _async_perform_unsubscribes(self) -> None:
"""Perform pending MQTT client unsubscribes."""

View File

@@ -373,6 +373,7 @@ DOMAIN = "mqtt"
LOGGER = logging.getLogger(__package__)
MQTT_CONNECTION_STATE = "mqtt_connection_state"
MQTT_PROCESSED_SUBSCRIPTIONS = "mqtt_processed_subscriptions"
PAYLOAD_EMPTY_JSON = "{}"
PAYLOAD_NONE = "None"

View File

@@ -282,6 +282,100 @@ async def test_subscribe_topic(
unsub()
async def test_status_subscription_done(
hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
mqtt_mock_entry: MqttMockHAClientGenerator,
recorded_calls: list[ReceiveMessage],
record_calls: MessageCallbackType,
) -> None:
"""Test the on subscription status."""
await mqtt_mock_entry()
on_status = asyncio.Event()
on_status_calls: list[bool] = []
def _on_subscribe_status() -> None:
on_status.set()
on_status_calls.append(True)
subscribe_callback = await mqtt.async_subscribe(
hass, "test-topic", record_calls, qos=0
)
handler = mqtt.async_on_subscribe_done(
hass, "test-topic", 0, on_subscribe_status=_on_subscribe_status
)
await on_status.wait()
assert ("test-topic", 0) in help_all_subscribe_calls(mqtt_client_mock)
await mqtt.async_publish(hass, "test-topic", "beer ready", 0)
handler()
assert len(recorded_calls) == 1
assert recorded_calls[0].topic == "test-topic"
assert recorded_calls[0].payload == "beer ready"
assert recorded_calls[0].qos == 0
# Test as we have an existing subscription, test we get a callback
recorded_calls.clear()
on_status.clear()
handler = mqtt.async_on_subscribe_done(
hass, "test-topic", 0, on_subscribe_status=_on_subscribe_status
)
assert len(on_status_calls) == 1
await on_status.wait()
assert len(on_status_calls) == 2
# cleanup
handler()
subscribe_callback()
async def test_subscribe_topic_with_subscribe_done(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
recorded_calls: list[ReceiveMessage],
record_calls: MessageCallbackType,
) -> None:
"""Test the subscription of a topic."""
await mqtt_mock_entry()
on_status = asyncio.Event()
def _on_subscribe() -> None:
hass.async_create_task(mqtt.async_publish(hass, "test-topic", "beer ready", 0))
on_status.set()
# Start a first subscription
unsub1 = await mqtt.async_subscribe(
hass, "test-topic", record_calls, on_subscribe=_on_subscribe
)
await on_status.wait()
await hass.async_block_till_done()
assert len(recorded_calls) == 1
assert recorded_calls[0].topic == "test-topic"
assert recorded_calls[0].payload == "beer ready"
assert recorded_calls[0].qos == 0
recorded_calls.clear()
# Start a second subscription to the same topic
on_status.clear()
unsub2 = await mqtt.async_subscribe(
hass, "test-topic", record_calls, on_subscribe=_on_subscribe
)
await on_status.wait()
await hass.async_block_till_done()
assert len(recorded_calls) == 2
assert recorded_calls[0].topic == "test-topic"
assert recorded_calls[0].payload == "beer ready"
assert recorded_calls[0].qos == 0
assert recorded_calls[1].topic == "test-topic"
assert recorded_calls[1].payload == "beer ready"
assert recorded_calls[1].qos == 0
unsub1()
unsub2()
@pytest.mark.usefixtures("mqtt_mock_entry")
async def test_subscribe_topic_not_initialize(
hass: HomeAssistant, record_calls: MessageCallbackType
@@ -292,6 +386,16 @@ async def test_subscribe_topic_not_initialize(
):
await mqtt.async_subscribe(hass, "test-topic", record_calls)
def _on_subscribe_callback() -> None:
pass
with pytest.raises(
HomeAssistantError, match=r".*make sure MQTT is set up correctly"
):
await mqtt.async_subscribe(
hass, "test-topic", record_calls, on_subscribe=_on_subscribe_callback
)
async def test_subscribe_mqtt_config_entry_disabled(
hass: HomeAssistant, mqtt_mock: MqttMockHAClient, record_calls: MessageCallbackType