Compare commits

...

3 Commits

Author SHA1 Message Date
jbouwh
a784ec6454 Callback directly if the subscription is ready and the client is connected. 2025-10-29 11:25:34 +00:00
jbouwh
8c8b1df11f Use a callback handler instead of awaiting 2025-10-28 21:44:24 +00:00
jbouwh
aabcff9653 Allow to wait for MQTT subscription 2025-10-08 17:29:36 +00:00
4 changed files with 173 additions and 4 deletions

View File

@@ -40,6 +40,7 @@ from homeassistant.util.async_ import create_eager_task
from . import debug_info, discovery from . import debug_info, discovery
from .client import ( from .client import (
MQTT, MQTT,
async_on_subscribe_done,
async_publish, async_publish,
async_subscribe, async_subscribe,
async_subscribe_internal, async_subscribe_internal,
@@ -163,6 +164,7 @@ __all__ = [
"async_create_certificate_temp_files", "async_create_certificate_temp_files",
"async_forward_entry_setup_and_setup_discovery", "async_forward_entry_setup_and_setup_discovery",
"async_migrate_entry", "async_migrate_entry",
"async_on_subscribe_done",
"async_prepare_subscribe_topics", "async_prepare_subscribe_topics",
"async_publish", "async_publish",
"async_remove_config_entry_device", "async_remove_config_entry_device",

View File

@@ -38,7 +38,10 @@ from homeassistant.core import (
get_hassjob_callable_job_type, get_hassjob_callable_job_type,
) )
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.importlib import async_import_module from homeassistant.helpers.importlib import async_import_module
from homeassistant.helpers.start import async_at_started from homeassistant.helpers.start import async_at_started
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@@ -71,6 +74,7 @@ from .const import (
DEFAULT_WS_PATH, DEFAULT_WS_PATH,
DOMAIN, DOMAIN,
MQTT_CONNECTION_STATE, MQTT_CONNECTION_STATE,
MQTT_PROCESSED_SUBSCRIPTIONS,
PROTOCOL_5, PROTOCOL_5,
PROTOCOL_31, PROTOCOL_31,
TRANSPORT_WEBSOCKETS, TRANSPORT_WEBSOCKETS,
@@ -109,6 +113,7 @@ INITIAL_SUBSCRIBE_COOLDOWN = 0.5
SUBSCRIBE_COOLDOWN = 0.1 SUBSCRIBE_COOLDOWN = 0.1
UNSUBSCRIBE_COOLDOWN = 0.1 UNSUBSCRIBE_COOLDOWN = 0.1
TIMEOUT_ACK = 10 TIMEOUT_ACK = 10
SUBSCRIBE_TIMEOUT = 10
RECONNECT_INTERVAL_SECONDS = 10 RECONNECT_INTERVAL_SECONDS = 10
MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1 MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1
@@ -184,6 +189,38 @@ async def async_publish(
) )
@callback
def async_on_subscribe_done(
hass: HomeAssistant,
topic: str,
qos: int,
on_subscribe_status: CALLBACK_TYPE,
) -> CALLBACK_TYPE:
"""Call on_subscribe_done when the matched subscription was completed.
If a subscription is already present the callback will call
on_subscribe_status directly.
Call the returned callback to stop and cleanup status monitoring.
"""
async def _sync_mqtt_subscribe(subscriptions: list[tuple[str, int]]) -> None:
if (topic, qos) not in subscriptions:
return
hass.loop.call_soon(on_subscribe_status)
mqtt_data = hass.data[DATA_MQTT]
if (
mqtt_data.client.connected
and mqtt_data.client.is_active_subscription(topic)
and not mqtt_data.client.is_pending_subscription(topic)
):
hass.loop.call_soon(on_subscribe_status)
return async_dispatcher_connect(
hass, MQTT_PROCESSED_SUBSCRIPTIONS, _sync_mqtt_subscribe
)
@bind_hass @bind_hass
async def async_subscribe( async def async_subscribe(
hass: HomeAssistant, hass: HomeAssistant,
@@ -191,12 +228,32 @@ async def async_subscribe(
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None], msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
qos: int = DEFAULT_QOS, qos: int = DEFAULT_QOS,
encoding: str | None = DEFAULT_ENCODING, encoding: str | None = DEFAULT_ENCODING,
on_subscribe: CALLBACK_TYPE | None = None,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Subscribe to an MQTT topic. """Subscribe to an MQTT topic.
If the on_subcribe callback hook is set, it will be called once
when the subscription has been completed.
Call the return value to unsubscribe. Call the return value to unsubscribe.
""" """
return async_subscribe_internal(hass, topic, msg_callback, qos, encoding) handler: CALLBACK_TYPE | None = None
def _on_subscribe_done() -> None:
"""Call once when the subscription was completed."""
if TYPE_CHECKING:
assert on_subscribe is not None and handler is not None
handler()
on_subscribe()
subscription_handler = async_subscribe_internal(
hass, topic, msg_callback, qos, encoding
)
if on_subscribe is not None:
handler = async_on_subscribe_done(hass, topic, qos, _on_subscribe_done)
return subscription_handler
@callback @callback
@@ -640,12 +697,16 @@ class MQTT:
if fileno > -1: if fileno > -1:
self.loop.remove_writer(sock) self.loop.remove_writer(sock)
def _is_active_subscription(self, topic: str) -> bool: def is_active_subscription(self, topic: str) -> bool:
"""Check if a topic has an active subscription.""" """Check if a topic has an active subscription."""
return topic in self._simple_subscriptions or any( return topic in self._simple_subscriptions or any(
other.topic == topic for other in self._wildcard_subscriptions other.topic == topic for other in self._wildcard_subscriptions
) )
def is_pending_subscription(self, topic: str) -> bool:
"""Check if a topic has a pending subscription."""
return topic in self._pending_subscriptions
async def async_publish( async def async_publish(
self, topic: str, payload: PublishPayloadType, qos: int, retain: bool self, topic: str, payload: PublishPayloadType, qos: int, retain: bool
) -> None: ) -> None:
@@ -899,7 +960,7 @@ class MQTT:
@callback @callback
def _async_unsubscribe(self, topic: str) -> None: def _async_unsubscribe(self, topic: str) -> None:
"""Unsubscribe from a topic.""" """Unsubscribe from a topic."""
if self._is_active_subscription(topic): if self.is_active_subscription(topic):
if self._max_qos[topic] == 0: if self._max_qos[topic] == 0:
return return
subs = self._matching_subscriptions(topic) subs = self._matching_subscriptions(topic)
@@ -963,6 +1024,7 @@ class MQTT:
self._last_subscribe = time.monotonic() self._last_subscribe = time.monotonic()
await self._async_wait_for_mid_or_raise(mid, result) await self._async_wait_for_mid_or_raise(mid, result)
async_dispatcher_send(self.hass, MQTT_PROCESSED_SUBSCRIPTIONS, chunk_list)
async def _async_perform_unsubscribes(self) -> None: async def _async_perform_unsubscribes(self) -> None:
"""Perform pending MQTT client unsubscribes.""" """Perform pending MQTT client unsubscribes."""

View File

@@ -373,6 +373,7 @@ DOMAIN = "mqtt"
LOGGER = logging.getLogger(__package__) LOGGER = logging.getLogger(__package__)
MQTT_CONNECTION_STATE = "mqtt_connection_state" MQTT_CONNECTION_STATE = "mqtt_connection_state"
MQTT_PROCESSED_SUBSCRIPTIONS = "mqtt_processed_subscriptions"
PAYLOAD_EMPTY_JSON = "{}" PAYLOAD_EMPTY_JSON = "{}"
PAYLOAD_NONE = "None" PAYLOAD_NONE = "None"

View File

@@ -282,6 +282,100 @@ async def test_subscribe_topic(
unsub() unsub()
async def test_status_subscription_done(
hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
mqtt_mock_entry: MqttMockHAClientGenerator,
recorded_calls: list[ReceiveMessage],
record_calls: MessageCallbackType,
) -> None:
"""Test the on subscription status."""
await mqtt_mock_entry()
on_status = asyncio.Event()
on_status_calls: list[bool] = []
def _on_subscribe_status() -> None:
on_status.set()
on_status_calls.append(True)
subscribe_callback = await mqtt.async_subscribe(
hass, "test-topic", record_calls, qos=0
)
handler = mqtt.async_on_subscribe_done(
hass, "test-topic", 0, on_subscribe_status=_on_subscribe_status
)
await on_status.wait()
assert ("test-topic", 0) in help_all_subscribe_calls(mqtt_client_mock)
await mqtt.async_publish(hass, "test-topic", "beer ready", 0)
handler()
assert len(recorded_calls) == 1
assert recorded_calls[0].topic == "test-topic"
assert recorded_calls[0].payload == "beer ready"
assert recorded_calls[0].qos == 0
# Test as we have an existing subscription, test we get a callback
recorded_calls.clear()
on_status.clear()
handler = mqtt.async_on_subscribe_done(
hass, "test-topic", 0, on_subscribe_status=_on_subscribe_status
)
assert len(on_status_calls) == 1
await on_status.wait()
assert len(on_status_calls) == 2
# cleanup
handler()
subscribe_callback()
async def test_subscribe_topic_with_subscribe_done(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
recorded_calls: list[ReceiveMessage],
record_calls: MessageCallbackType,
) -> None:
"""Test the subscription of a topic."""
await mqtt_mock_entry()
on_status = asyncio.Event()
def _on_subscribe() -> None:
hass.async_create_task(mqtt.async_publish(hass, "test-topic", "beer ready", 0))
on_status.set()
# Start a first subscription
unsub1 = await mqtt.async_subscribe(
hass, "test-topic", record_calls, on_subscribe=_on_subscribe
)
await on_status.wait()
await hass.async_block_till_done()
assert len(recorded_calls) == 1
assert recorded_calls[0].topic == "test-topic"
assert recorded_calls[0].payload == "beer ready"
assert recorded_calls[0].qos == 0
recorded_calls.clear()
# Start a second subscription to the same topic
on_status.clear()
unsub2 = await mqtt.async_subscribe(
hass, "test-topic", record_calls, on_subscribe=_on_subscribe
)
await on_status.wait()
await hass.async_block_till_done()
assert len(recorded_calls) == 2
assert recorded_calls[0].topic == "test-topic"
assert recorded_calls[0].payload == "beer ready"
assert recorded_calls[0].qos == 0
assert recorded_calls[1].topic == "test-topic"
assert recorded_calls[1].payload == "beer ready"
assert recorded_calls[1].qos == 0
unsub1()
unsub2()
@pytest.mark.usefixtures("mqtt_mock_entry") @pytest.mark.usefixtures("mqtt_mock_entry")
async def test_subscribe_topic_not_initialize( async def test_subscribe_topic_not_initialize(
hass: HomeAssistant, record_calls: MessageCallbackType hass: HomeAssistant, record_calls: MessageCallbackType
@@ -292,6 +386,16 @@ async def test_subscribe_topic_not_initialize(
): ):
await mqtt.async_subscribe(hass, "test-topic", record_calls) await mqtt.async_subscribe(hass, "test-topic", record_calls)
def _on_subscribe_callback() -> None:
pass
with pytest.raises(
HomeAssistantError, match=r".*make sure MQTT is set up correctly"
):
await mqtt.async_subscribe(
hass, "test-topic", record_calls, on_subscribe=_on_subscribe_callback
)
async def test_subscribe_mqtt_config_entry_disabled( async def test_subscribe_mqtt_config_entry_disabled(
hass: HomeAssistant, mqtt_mock: MqttMockHAClient, record_calls: MessageCallbackType hass: HomeAssistant, mqtt_mock: MqttMockHAClient, record_calls: MessageCallbackType