mirror of
https://github.com/home-assistant/core.git
synced 2025-11-13 13:00:11 +00:00
Compare commits
3 Commits
copilot/ad
...
mqtt-subsc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a784ec6454 | ||
|
|
8c8b1df11f | ||
|
|
aabcff9653 |
@@ -40,6 +40,7 @@ from homeassistant.util.async_ import create_eager_task
|
|||||||
from . import debug_info, discovery
|
from . import debug_info, discovery
|
||||||
from .client import (
|
from .client import (
|
||||||
MQTT,
|
MQTT,
|
||||||
|
async_on_subscribe_done,
|
||||||
async_publish,
|
async_publish,
|
||||||
async_subscribe,
|
async_subscribe,
|
||||||
async_subscribe_internal,
|
async_subscribe_internal,
|
||||||
@@ -163,6 +164,7 @@ __all__ = [
|
|||||||
"async_create_certificate_temp_files",
|
"async_create_certificate_temp_files",
|
||||||
"async_forward_entry_setup_and_setup_discovery",
|
"async_forward_entry_setup_and_setup_discovery",
|
||||||
"async_migrate_entry",
|
"async_migrate_entry",
|
||||||
|
"async_on_subscribe_done",
|
||||||
"async_prepare_subscribe_topics",
|
"async_prepare_subscribe_topics",
|
||||||
"async_publish",
|
"async_publish",
|
||||||
"async_remove_config_entry_device",
|
"async_remove_config_entry_device",
|
||||||
|
|||||||
@@ -38,7 +38,10 @@ from homeassistant.core import (
|
|||||||
get_hassjob_callable_job_type,
|
get_hassjob_callable_job_type,
|
||||||
)
|
)
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
from homeassistant.helpers.dispatcher import (
|
||||||
|
async_dispatcher_connect,
|
||||||
|
async_dispatcher_send,
|
||||||
|
)
|
||||||
from homeassistant.helpers.importlib import async_import_module
|
from homeassistant.helpers.importlib import async_import_module
|
||||||
from homeassistant.helpers.start import async_at_started
|
from homeassistant.helpers.start import async_at_started
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
@@ -71,6 +74,7 @@ from .const import (
|
|||||||
DEFAULT_WS_PATH,
|
DEFAULT_WS_PATH,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
MQTT_CONNECTION_STATE,
|
MQTT_CONNECTION_STATE,
|
||||||
|
MQTT_PROCESSED_SUBSCRIPTIONS,
|
||||||
PROTOCOL_5,
|
PROTOCOL_5,
|
||||||
PROTOCOL_31,
|
PROTOCOL_31,
|
||||||
TRANSPORT_WEBSOCKETS,
|
TRANSPORT_WEBSOCKETS,
|
||||||
@@ -109,6 +113,7 @@ INITIAL_SUBSCRIBE_COOLDOWN = 0.5
|
|||||||
SUBSCRIBE_COOLDOWN = 0.1
|
SUBSCRIBE_COOLDOWN = 0.1
|
||||||
UNSUBSCRIBE_COOLDOWN = 0.1
|
UNSUBSCRIBE_COOLDOWN = 0.1
|
||||||
TIMEOUT_ACK = 10
|
TIMEOUT_ACK = 10
|
||||||
|
SUBSCRIBE_TIMEOUT = 10
|
||||||
RECONNECT_INTERVAL_SECONDS = 10
|
RECONNECT_INTERVAL_SECONDS = 10
|
||||||
|
|
||||||
MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1
|
MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1
|
||||||
@@ -184,6 +189,38 @@ async def async_publish(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_on_subscribe_done(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
topic: str,
|
||||||
|
qos: int,
|
||||||
|
on_subscribe_status: CALLBACK_TYPE,
|
||||||
|
) -> CALLBACK_TYPE:
|
||||||
|
"""Call on_subscribe_done when the matched subscription was completed.
|
||||||
|
|
||||||
|
If a subscription is already present the callback will call
|
||||||
|
on_subscribe_status directly.
|
||||||
|
Call the returned callback to stop and cleanup status monitoring.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def _sync_mqtt_subscribe(subscriptions: list[tuple[str, int]]) -> None:
|
||||||
|
if (topic, qos) not in subscriptions:
|
||||||
|
return
|
||||||
|
hass.loop.call_soon(on_subscribe_status)
|
||||||
|
|
||||||
|
mqtt_data = hass.data[DATA_MQTT]
|
||||||
|
if (
|
||||||
|
mqtt_data.client.connected
|
||||||
|
and mqtt_data.client.is_active_subscription(topic)
|
||||||
|
and not mqtt_data.client.is_pending_subscription(topic)
|
||||||
|
):
|
||||||
|
hass.loop.call_soon(on_subscribe_status)
|
||||||
|
|
||||||
|
return async_dispatcher_connect(
|
||||||
|
hass, MQTT_PROCESSED_SUBSCRIPTIONS, _sync_mqtt_subscribe
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
async def async_subscribe(
|
async def async_subscribe(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
@@ -191,12 +228,32 @@ async def async_subscribe(
|
|||||||
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
|
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
|
||||||
qos: int = DEFAULT_QOS,
|
qos: int = DEFAULT_QOS,
|
||||||
encoding: str | None = DEFAULT_ENCODING,
|
encoding: str | None = DEFAULT_ENCODING,
|
||||||
|
on_subscribe: CALLBACK_TYPE | None = None,
|
||||||
) -> CALLBACK_TYPE:
|
) -> CALLBACK_TYPE:
|
||||||
"""Subscribe to an MQTT topic.
|
"""Subscribe to an MQTT topic.
|
||||||
|
|
||||||
|
If the on_subcribe callback hook is set, it will be called once
|
||||||
|
when the subscription has been completed.
|
||||||
|
|
||||||
Call the return value to unsubscribe.
|
Call the return value to unsubscribe.
|
||||||
"""
|
"""
|
||||||
return async_subscribe_internal(hass, topic, msg_callback, qos, encoding)
|
handler: CALLBACK_TYPE | None = None
|
||||||
|
|
||||||
|
def _on_subscribe_done() -> None:
|
||||||
|
"""Call once when the subscription was completed."""
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
assert on_subscribe is not None and handler is not None
|
||||||
|
|
||||||
|
handler()
|
||||||
|
on_subscribe()
|
||||||
|
|
||||||
|
subscription_handler = async_subscribe_internal(
|
||||||
|
hass, topic, msg_callback, qos, encoding
|
||||||
|
)
|
||||||
|
if on_subscribe is not None:
|
||||||
|
handler = async_on_subscribe_done(hass, topic, qos, _on_subscribe_done)
|
||||||
|
|
||||||
|
return subscription_handler
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@@ -640,12 +697,16 @@ class MQTT:
|
|||||||
if fileno > -1:
|
if fileno > -1:
|
||||||
self.loop.remove_writer(sock)
|
self.loop.remove_writer(sock)
|
||||||
|
|
||||||
def _is_active_subscription(self, topic: str) -> bool:
|
def is_active_subscription(self, topic: str) -> bool:
|
||||||
"""Check if a topic has an active subscription."""
|
"""Check if a topic has an active subscription."""
|
||||||
return topic in self._simple_subscriptions or any(
|
return topic in self._simple_subscriptions or any(
|
||||||
other.topic == topic for other in self._wildcard_subscriptions
|
other.topic == topic for other in self._wildcard_subscriptions
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def is_pending_subscription(self, topic: str) -> bool:
|
||||||
|
"""Check if a topic has a pending subscription."""
|
||||||
|
return topic in self._pending_subscriptions
|
||||||
|
|
||||||
async def async_publish(
|
async def async_publish(
|
||||||
self, topic: str, payload: PublishPayloadType, qos: int, retain: bool
|
self, topic: str, payload: PublishPayloadType, qos: int, retain: bool
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -899,7 +960,7 @@ class MQTT:
|
|||||||
@callback
|
@callback
|
||||||
def _async_unsubscribe(self, topic: str) -> None:
|
def _async_unsubscribe(self, topic: str) -> None:
|
||||||
"""Unsubscribe from a topic."""
|
"""Unsubscribe from a topic."""
|
||||||
if self._is_active_subscription(topic):
|
if self.is_active_subscription(topic):
|
||||||
if self._max_qos[topic] == 0:
|
if self._max_qos[topic] == 0:
|
||||||
return
|
return
|
||||||
subs = self._matching_subscriptions(topic)
|
subs = self._matching_subscriptions(topic)
|
||||||
@@ -963,6 +1024,7 @@ class MQTT:
|
|||||||
self._last_subscribe = time.monotonic()
|
self._last_subscribe = time.monotonic()
|
||||||
|
|
||||||
await self._async_wait_for_mid_or_raise(mid, result)
|
await self._async_wait_for_mid_or_raise(mid, result)
|
||||||
|
async_dispatcher_send(self.hass, MQTT_PROCESSED_SUBSCRIPTIONS, chunk_list)
|
||||||
|
|
||||||
async def _async_perform_unsubscribes(self) -> None:
|
async def _async_perform_unsubscribes(self) -> None:
|
||||||
"""Perform pending MQTT client unsubscribes."""
|
"""Perform pending MQTT client unsubscribes."""
|
||||||
|
|||||||
@@ -373,6 +373,7 @@ DOMAIN = "mqtt"
|
|||||||
LOGGER = logging.getLogger(__package__)
|
LOGGER = logging.getLogger(__package__)
|
||||||
|
|
||||||
MQTT_CONNECTION_STATE = "mqtt_connection_state"
|
MQTT_CONNECTION_STATE = "mqtt_connection_state"
|
||||||
|
MQTT_PROCESSED_SUBSCRIPTIONS = "mqtt_processed_subscriptions"
|
||||||
|
|
||||||
PAYLOAD_EMPTY_JSON = "{}"
|
PAYLOAD_EMPTY_JSON = "{}"
|
||||||
PAYLOAD_NONE = "None"
|
PAYLOAD_NONE = "None"
|
||||||
|
|||||||
@@ -282,6 +282,100 @@ async def test_subscribe_topic(
|
|||||||
unsub()
|
unsub()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_status_subscription_done(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mqtt_client_mock: MqttMockPahoClient,
|
||||||
|
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||||
|
recorded_calls: list[ReceiveMessage],
|
||||||
|
record_calls: MessageCallbackType,
|
||||||
|
) -> None:
|
||||||
|
"""Test the on subscription status."""
|
||||||
|
await mqtt_mock_entry()
|
||||||
|
|
||||||
|
on_status = asyncio.Event()
|
||||||
|
on_status_calls: list[bool] = []
|
||||||
|
|
||||||
|
def _on_subscribe_status() -> None:
|
||||||
|
on_status.set()
|
||||||
|
on_status_calls.append(True)
|
||||||
|
|
||||||
|
subscribe_callback = await mqtt.async_subscribe(
|
||||||
|
hass, "test-topic", record_calls, qos=0
|
||||||
|
)
|
||||||
|
handler = mqtt.async_on_subscribe_done(
|
||||||
|
hass, "test-topic", 0, on_subscribe_status=_on_subscribe_status
|
||||||
|
)
|
||||||
|
await on_status.wait()
|
||||||
|
assert ("test-topic", 0) in help_all_subscribe_calls(mqtt_client_mock)
|
||||||
|
|
||||||
|
await mqtt.async_publish(hass, "test-topic", "beer ready", 0)
|
||||||
|
handler()
|
||||||
|
assert len(recorded_calls) == 1
|
||||||
|
assert recorded_calls[0].topic == "test-topic"
|
||||||
|
assert recorded_calls[0].payload == "beer ready"
|
||||||
|
assert recorded_calls[0].qos == 0
|
||||||
|
|
||||||
|
# Test as we have an existing subscription, test we get a callback
|
||||||
|
recorded_calls.clear()
|
||||||
|
on_status.clear()
|
||||||
|
handler = mqtt.async_on_subscribe_done(
|
||||||
|
hass, "test-topic", 0, on_subscribe_status=_on_subscribe_status
|
||||||
|
)
|
||||||
|
assert len(on_status_calls) == 1
|
||||||
|
await on_status.wait()
|
||||||
|
assert len(on_status_calls) == 2
|
||||||
|
|
||||||
|
# cleanup
|
||||||
|
handler()
|
||||||
|
subscribe_callback()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_subscribe_topic_with_subscribe_done(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||||
|
recorded_calls: list[ReceiveMessage],
|
||||||
|
record_calls: MessageCallbackType,
|
||||||
|
) -> None:
|
||||||
|
"""Test the subscription of a topic."""
|
||||||
|
await mqtt_mock_entry()
|
||||||
|
|
||||||
|
on_status = asyncio.Event()
|
||||||
|
|
||||||
|
def _on_subscribe() -> None:
|
||||||
|
hass.async_create_task(mqtt.async_publish(hass, "test-topic", "beer ready", 0))
|
||||||
|
on_status.set()
|
||||||
|
|
||||||
|
# Start a first subscription
|
||||||
|
unsub1 = await mqtt.async_subscribe(
|
||||||
|
hass, "test-topic", record_calls, on_subscribe=_on_subscribe
|
||||||
|
)
|
||||||
|
await on_status.wait()
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(recorded_calls) == 1
|
||||||
|
assert recorded_calls[0].topic == "test-topic"
|
||||||
|
assert recorded_calls[0].payload == "beer ready"
|
||||||
|
assert recorded_calls[0].qos == 0
|
||||||
|
recorded_calls.clear()
|
||||||
|
|
||||||
|
# Start a second subscription to the same topic
|
||||||
|
on_status.clear()
|
||||||
|
unsub2 = await mqtt.async_subscribe(
|
||||||
|
hass, "test-topic", record_calls, on_subscribe=_on_subscribe
|
||||||
|
)
|
||||||
|
await on_status.wait()
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(recorded_calls) == 2
|
||||||
|
assert recorded_calls[0].topic == "test-topic"
|
||||||
|
assert recorded_calls[0].payload == "beer ready"
|
||||||
|
assert recorded_calls[0].qos == 0
|
||||||
|
assert recorded_calls[1].topic == "test-topic"
|
||||||
|
assert recorded_calls[1].payload == "beer ready"
|
||||||
|
assert recorded_calls[1].qos == 0
|
||||||
|
|
||||||
|
unsub1()
|
||||||
|
unsub2()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("mqtt_mock_entry")
|
@pytest.mark.usefixtures("mqtt_mock_entry")
|
||||||
async def test_subscribe_topic_not_initialize(
|
async def test_subscribe_topic_not_initialize(
|
||||||
hass: HomeAssistant, record_calls: MessageCallbackType
|
hass: HomeAssistant, record_calls: MessageCallbackType
|
||||||
@@ -292,6 +386,16 @@ async def test_subscribe_topic_not_initialize(
|
|||||||
):
|
):
|
||||||
await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
||||||
|
|
||||||
|
def _on_subscribe_callback() -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
HomeAssistantError, match=r".*make sure MQTT is set up correctly"
|
||||||
|
):
|
||||||
|
await mqtt.async_subscribe(
|
||||||
|
hass, "test-topic", record_calls, on_subscribe=_on_subscribe_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_subscribe_mqtt_config_entry_disabled(
|
async def test_subscribe_mqtt_config_entry_disabled(
|
||||||
hass: HomeAssistant, mqtt_mock: MqttMockHAClient, record_calls: MessageCallbackType
|
hass: HomeAssistant, mqtt_mock: MqttMockHAClient, record_calls: MessageCallbackType
|
||||||
|
|||||||
Reference in New Issue
Block a user