mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Convert MQTT to use asyncio (#115910)
This commit is contained in:
parent
5a24690d79
commit
423544401e
@ -265,7 +265,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
conf: dict[str, Any]
|
||||
mqtt_data: MqttData
|
||||
|
||||
async def _setup_client() -> tuple[MqttData, dict[str, Any]]:
|
||||
async def _setup_client(
|
||||
client_available: asyncio.Future[bool],
|
||||
) -> tuple[MqttData, dict[str, Any]]:
|
||||
"""Set up the MQTT client."""
|
||||
# Fetch configuration
|
||||
conf = dict(entry.data)
|
||||
@ -294,7 +296,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
entry.add_update_listener(_async_config_entry_updated)
|
||||
)
|
||||
|
||||
await mqtt_data.client.async_connect()
|
||||
await mqtt_data.client.async_connect(client_available)
|
||||
return (mqtt_data, conf)
|
||||
|
||||
client_available: asyncio.Future[bool]
|
||||
@ -303,13 +305,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
else:
|
||||
client_available = hass.data[DATA_MQTT_AVAILABLE]
|
||||
|
||||
setup_ok: bool = False
|
||||
try:
|
||||
mqtt_data, conf = await _setup_client()
|
||||
setup_ok = True
|
||||
finally:
|
||||
if not client_available.done():
|
||||
client_available.set_result(setup_ok)
|
||||
mqtt_data, conf = await _setup_client(client_available)
|
||||
|
||||
async def async_publish_service(call: ServiceCall) -> None:
|
||||
"""Handle MQTT publish service calls."""
|
||||
|
@ -3,12 +3,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine, Iterable
|
||||
from collections.abc import AsyncGenerator, Callable, Coroutine, Iterable
|
||||
import contextlib
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from functools import lru_cache, partial
|
||||
from itertools import chain, groupby
|
||||
import logging
|
||||
from operator import attrgetter
|
||||
import socket
|
||||
import ssl
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@ -35,7 +37,7 @@ from homeassistant.core import (
|
||||
callback,
|
||||
)
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.dispatcher import dispatcher_send
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import dt as dt_util
|
||||
@ -92,6 +94,9 @@ INITIAL_SUBSCRIBE_COOLDOWN = 1.0
|
||||
SUBSCRIBE_COOLDOWN = 0.1
|
||||
UNSUBSCRIBE_COOLDOWN = 0.1
|
||||
TIMEOUT_ACK = 10
|
||||
RECONNECT_INTERVAL_SECONDS = 10
|
||||
|
||||
SocketType = socket.socket | ssl.SSLSocket | Any
|
||||
|
||||
SubscribePayloadType = str | bytes # Only bytes if encoding is None
|
||||
|
||||
@ -258,7 +263,9 @@ class MqttClientSetup:
|
||||
# However, that feature is not mandatory so we generate our own.
|
||||
client_id = mqtt.base62(uuid.uuid4().int, padding=22)
|
||||
transport = config.get(CONF_TRANSPORT, DEFAULT_TRANSPORT)
|
||||
self._client = mqtt.Client(client_id, protocol=proto, transport=transport)
|
||||
self._client = mqtt.Client(
|
||||
client_id, protocol=proto, transport=transport, reconnect_on_failure=False
|
||||
)
|
||||
|
||||
# Enable logging
|
||||
self._client.enable_logger()
|
||||
@ -404,12 +411,17 @@ class MQTT:
|
||||
self._ha_started = asyncio.Event()
|
||||
self._cleanup_on_unload: list[Callable[[], None]] = []
|
||||
|
||||
self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client
|
||||
self._connection_lock = asyncio.Lock()
|
||||
self._pending_operations: dict[int, asyncio.Event] = {}
|
||||
self._pending_operations_condition = asyncio.Condition()
|
||||
self._subscribe_debouncer = EnsureJobAfterCooldown(
|
||||
INITIAL_SUBSCRIBE_COOLDOWN, self._async_perform_subscriptions
|
||||
)
|
||||
self._misc_task: asyncio.Task | None = None
|
||||
self._reconnect_task: asyncio.Task | None = None
|
||||
self._should_reconnect: bool = True
|
||||
self._available_future: asyncio.Future[bool] | None = None
|
||||
|
||||
self._max_qos: dict[str, int] = {} # topic, max qos
|
||||
self._pending_subscriptions: dict[str, int] = {} # topic, qos
|
||||
self._unsubscribe_debouncer = EnsureJobAfterCooldown(
|
||||
@ -456,25 +468,140 @@ class MQTT:
|
||||
while self._cleanup_on_unload:
|
||||
self._cleanup_on_unload.pop()()
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _async_connect_in_executor(self) -> AsyncGenerator[None, None]:
|
||||
# While we are connecting in the executor we need to
|
||||
# handle on_socket_open and on_socket_register_write
|
||||
# in the executor as well.
|
||||
mqttc = self._mqttc
|
||||
try:
|
||||
mqttc.on_socket_open = self._on_socket_open
|
||||
mqttc.on_socket_register_write = self._on_socket_register_write
|
||||
yield
|
||||
finally:
|
||||
# Once the executor job is done, we can switch back to
|
||||
# handling these in the event loop.
|
||||
mqttc.on_socket_open = self._async_on_socket_open
|
||||
mqttc.on_socket_register_write = self._async_on_socket_register_write
|
||||
|
||||
def init_client(self) -> None:
|
||||
"""Initialize paho client."""
|
||||
self._mqttc = MqttClientSetup(self.conf).client
|
||||
self._mqttc.on_connect = self._mqtt_on_connect
|
||||
self._mqttc.on_disconnect = self._mqtt_on_disconnect
|
||||
self._mqttc.on_message = self._mqtt_on_message
|
||||
self._mqttc.on_publish = self._mqtt_on_callback
|
||||
self._mqttc.on_subscribe = self._mqtt_on_callback
|
||||
self._mqttc.on_unsubscribe = self._mqtt_on_callback
|
||||
mqttc = MqttClientSetup(self.conf).client
|
||||
# on_socket_unregister_write and _async_on_socket_close
|
||||
# are only ever called in the event loop
|
||||
mqttc.on_socket_close = self._async_on_socket_close
|
||||
mqttc.on_socket_unregister_write = self._async_on_socket_unregister_write
|
||||
|
||||
# These will be called in the event loop
|
||||
mqttc.on_connect = self._async_mqtt_on_connect
|
||||
mqttc.on_disconnect = self._async_mqtt_on_disconnect
|
||||
mqttc.on_message = self._async_mqtt_on_message
|
||||
mqttc.on_publish = self._async_mqtt_on_callback
|
||||
mqttc.on_subscribe = self._async_mqtt_on_callback
|
||||
mqttc.on_unsubscribe = self._async_mqtt_on_callback
|
||||
|
||||
if will := self.conf.get(CONF_WILL_MESSAGE, DEFAULT_WILL):
|
||||
will_message = PublishMessage(**will)
|
||||
self._mqttc.will_set(
|
||||
mqttc.will_set(
|
||||
topic=will_message.topic,
|
||||
payload=will_message.payload,
|
||||
qos=will_message.qos,
|
||||
retain=will_message.retain,
|
||||
)
|
||||
|
||||
self._mqttc = mqttc
|
||||
|
||||
async def _misc_loop(self) -> None:
|
||||
"""Start the MQTT client misc loop."""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
while self._mqttc.loop_misc() == mqtt.MQTT_ERR_SUCCESS:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
@callback
|
||||
def _async_reader_callback(self, client: mqtt.Client) -> None:
|
||||
"""Handle reading data from the socket."""
|
||||
if (status := client.loop_read()) != 0:
|
||||
self._async_on_disconnect(status)
|
||||
|
||||
@callback
|
||||
def _async_start_misc_loop(self) -> None:
|
||||
"""Start the misc loop."""
|
||||
if self._misc_task is None or self._misc_task.done():
|
||||
_LOGGER.debug("%s: Starting client misc loop", self.config_entry.title)
|
||||
self._misc_task = self.config_entry.async_create_background_task(
|
||||
self.hass, self._misc_loop(), name="mqtt misc loop"
|
||||
)
|
||||
|
||||
def _on_socket_open(
|
||||
self, client: mqtt.Client, userdata: Any, sock: SocketType
|
||||
) -> None:
|
||||
"""Handle socket open."""
|
||||
self.loop.call_soon_threadsafe(
|
||||
self._async_on_socket_open, client, userdata, sock
|
||||
)
|
||||
|
||||
@callback
|
||||
def _async_on_socket_open(
|
||||
self, client: mqtt.Client, userdata: Any, sock: SocketType
|
||||
) -> None:
|
||||
"""Handle socket open."""
|
||||
fileno = sock.fileno()
|
||||
_LOGGER.debug("%s: connection opened %s", self.config_entry.title, fileno)
|
||||
if fileno > -1:
|
||||
self.loop.add_reader(sock, partial(self._async_reader_callback, client))
|
||||
self._async_start_misc_loop()
|
||||
|
||||
@callback
|
||||
def _async_on_socket_close(
|
||||
self, client: mqtt.Client, userdata: Any, sock: SocketType
|
||||
) -> None:
|
||||
"""Handle socket close."""
|
||||
fileno = sock.fileno()
|
||||
_LOGGER.debug("%s: connection closed %s", self.config_entry.title, fileno)
|
||||
# If socket close is called before the connect
|
||||
# result is set make sure the first connection result is set
|
||||
self._async_connection_result(False)
|
||||
if fileno > -1:
|
||||
self.loop.remove_reader(sock)
|
||||
if self._misc_task is not None and not self._misc_task.done():
|
||||
self._misc_task.cancel()
|
||||
|
||||
@callback
|
||||
def _async_writer_callback(self, client: mqtt.Client) -> None:
|
||||
"""Handle writing data to the socket."""
|
||||
if (status := client.loop_write()) != 0:
|
||||
self._async_on_disconnect(status)
|
||||
|
||||
def _on_socket_register_write(
|
||||
self, client: mqtt.Client, userdata: Any, sock: SocketType
|
||||
) -> None:
|
||||
"""Register the socket for writing."""
|
||||
self.loop.call_soon_threadsafe(
|
||||
self._async_on_socket_register_write, client, None, sock
|
||||
)
|
||||
|
||||
@callback
|
||||
def _async_on_socket_register_write(
|
||||
self, client: mqtt.Client, userdata: Any, sock: SocketType
|
||||
) -> None:
|
||||
"""Register the socket for writing."""
|
||||
fileno = sock.fileno()
|
||||
_LOGGER.debug("%s: register write %s", self.config_entry.title, fileno)
|
||||
if fileno > -1:
|
||||
self.loop.add_writer(sock, partial(self._async_writer_callback, client))
|
||||
|
||||
@callback
|
||||
def _async_on_socket_unregister_write(
|
||||
self, client: mqtt.Client, userdata: Any, sock: SocketType
|
||||
) -> None:
|
||||
"""Unregister the socket for writing."""
|
||||
fileno = sock.fileno()
|
||||
_LOGGER.debug("%s: unregister write %s", self.config_entry.title, fileno)
|
||||
if fileno > -1:
|
||||
self.loop.remove_writer(sock)
|
||||
|
||||
def _is_active_subscription(self, topic: str) -> bool:
|
||||
"""Check if a topic has an active subscription."""
|
||||
return topic in self._simple_subscriptions or any(
|
||||
@ -485,10 +612,7 @@ class MQTT:
|
||||
self, topic: str, payload: PublishPayloadType, qos: int, retain: bool
|
||||
) -> None:
|
||||
"""Publish a MQTT message."""
|
||||
async with self._paho_lock:
|
||||
msg_info = await self.hass.async_add_executor_job(
|
||||
self._mqttc.publish, topic, payload, qos, retain
|
||||
)
|
||||
msg_info = self._mqttc.publish(topic, payload, qos, retain)
|
||||
_LOGGER.debug(
|
||||
"Transmitting%s message on %s: '%s', mid: %s, qos: %s",
|
||||
" retained" if retain else "",
|
||||
@ -500,37 +624,71 @@ class MQTT:
|
||||
_raise_on_error(msg_info.rc)
|
||||
await self._wait_for_mid(msg_info.mid)
|
||||
|
||||
async def async_connect(self) -> None:
|
||||
async def async_connect(self, client_available: asyncio.Future[bool]) -> None:
|
||||
"""Connect to the host. Does not process messages yet."""
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
result: int | None = None
|
||||
self._available_future = client_available
|
||||
self._should_reconnect = True
|
||||
try:
|
||||
result = await self.hass.async_add_executor_job(
|
||||
self._mqttc.connect,
|
||||
self.conf[CONF_BROKER],
|
||||
self.conf.get(CONF_PORT, DEFAULT_PORT),
|
||||
self.conf.get(CONF_KEEPALIVE, DEFAULT_KEEPALIVE),
|
||||
)
|
||||
async with self._connection_lock, self._async_connect_in_executor():
|
||||
result = await self.hass.async_add_executor_job(
|
||||
self._mqttc.connect,
|
||||
self.conf[CONF_BROKER],
|
||||
self.conf.get(CONF_PORT, DEFAULT_PORT),
|
||||
self.conf.get(CONF_KEEPALIVE, DEFAULT_KEEPALIVE),
|
||||
)
|
||||
except OSError as err:
|
||||
_LOGGER.error("Failed to connect to MQTT server due to exception: %s", err)
|
||||
self._async_connection_result(False)
|
||||
finally:
|
||||
if result is not None and result != 0:
|
||||
if result is not None:
|
||||
_LOGGER.error(
|
||||
"Failed to connect to MQTT server: %s",
|
||||
mqtt.error_string(result),
|
||||
)
|
||||
self._async_connection_result(False)
|
||||
|
||||
if result is not None and result != 0:
|
||||
_LOGGER.error(
|
||||
"Failed to connect to MQTT server: %s", mqtt.error_string(result)
|
||||
@callback
|
||||
def _async_connection_result(self, connected: bool) -> None:
|
||||
"""Handle a connection result."""
|
||||
if self._available_future and not self._available_future.done():
|
||||
self._available_future.set_result(connected)
|
||||
|
||||
if connected:
|
||||
self._async_cancel_reconnect()
|
||||
elif self._should_reconnect and not self._reconnect_task:
|
||||
self._reconnect_task = self.config_entry.async_create_background_task(
|
||||
self.hass, self._reconnect_loop(), "mqtt reconnect loop"
|
||||
)
|
||||
|
||||
self._mqttc.loop_start()
|
||||
@callback
|
||||
def _async_cancel_reconnect(self) -> None:
|
||||
"""Cancel the reconnect task."""
|
||||
if self._reconnect_task:
|
||||
self._reconnect_task.cancel()
|
||||
self._reconnect_task = None
|
||||
|
||||
async def _reconnect_loop(self) -> None:
|
||||
"""Reconnect to the MQTT server."""
|
||||
while True:
|
||||
if not self.connected:
|
||||
try:
|
||||
async with self._connection_lock, self._async_connect_in_executor():
|
||||
await self.hass.async_add_executor_job(self._mqttc.reconnect)
|
||||
except OSError as err:
|
||||
_LOGGER.debug(
|
||||
"Error re-connecting to MQTT server due to exception: %s", err
|
||||
)
|
||||
|
||||
await asyncio.sleep(RECONNECT_INTERVAL_SECONDS)
|
||||
|
||||
async def async_disconnect(self) -> None:
|
||||
"""Stop the MQTT client."""
|
||||
|
||||
def stop() -> None:
|
||||
"""Stop the MQTT client."""
|
||||
# Do not disconnect, we want the broker to always publish will
|
||||
self._mqttc.loop_stop()
|
||||
|
||||
def no_more_acks() -> bool:
|
||||
"""Return False if there are unprocessed ACKs."""
|
||||
return not any(not op.is_set() for op in self._pending_operations.values())
|
||||
@ -549,8 +707,10 @@ class MQTT:
|
||||
await self._pending_operations_condition.wait_for(no_more_acks)
|
||||
|
||||
# stop the MQTT loop
|
||||
async with self._paho_lock:
|
||||
await self.hass.async_add_executor_job(stop)
|
||||
async with self._connection_lock:
|
||||
self._should_reconnect = False
|
||||
self._async_cancel_reconnect()
|
||||
self._mqttc.disconnect()
|
||||
|
||||
@callback
|
||||
def async_restore_tracked_subscriptions(
|
||||
@ -689,11 +849,8 @@ class MQTT:
|
||||
subscriptions: dict[str, int] = self._pending_subscriptions
|
||||
self._pending_subscriptions = {}
|
||||
|
||||
async with self._paho_lock:
|
||||
subscription_list = list(subscriptions.items())
|
||||
result, mid = await self.hass.async_add_executor_job(
|
||||
self._mqttc.subscribe, subscription_list
|
||||
)
|
||||
subscription_list = list(subscriptions.items())
|
||||
result, mid = self._mqttc.subscribe(subscription_list)
|
||||
|
||||
for topic, qos in subscriptions.items():
|
||||
_LOGGER.debug("Subscribing to %s, mid: %s, qos: %s", topic, mid, qos)
|
||||
@ -712,17 +869,15 @@ class MQTT:
|
||||
topics = list(self._pending_unsubscribes)
|
||||
self._pending_unsubscribes = set()
|
||||
|
||||
async with self._paho_lock:
|
||||
result, mid = await self.hass.async_add_executor_job(
|
||||
self._mqttc.unsubscribe, topics
|
||||
)
|
||||
result, mid = self._mqttc.unsubscribe(topics)
|
||||
_raise_on_error(result)
|
||||
for topic in topics:
|
||||
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
|
||||
|
||||
await self._wait_for_mid(mid)
|
||||
|
||||
def _mqtt_on_connect(
|
||||
@callback
|
||||
def _async_mqtt_on_connect(
|
||||
self,
|
||||
_mqttc: mqtt.Client,
|
||||
_userdata: None,
|
||||
@ -746,7 +901,7 @@ class MQTT:
|
||||
return
|
||||
|
||||
self.connected = True
|
||||
dispatcher_send(self.hass, MQTT_CONNECTED)
|
||||
async_dispatcher_send(self.hass, MQTT_CONNECTED)
|
||||
_LOGGER.info(
|
||||
"Connected to MQTT server %s:%s (%s)",
|
||||
self.conf[CONF_BROKER],
|
||||
@ -754,7 +909,7 @@ class MQTT:
|
||||
result_code,
|
||||
)
|
||||
|
||||
self.hass.create_task(self._async_resubscribe())
|
||||
self.hass.async_create_task(self._async_resubscribe())
|
||||
|
||||
if birth := self.conf.get(CONF_BIRTH_MESSAGE, DEFAULT_BIRTH):
|
||||
|
||||
@ -771,13 +926,17 @@ class MQTT:
|
||||
)
|
||||
|
||||
birth_message = PublishMessage(**birth)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
publish_birth_message(birth_message), self.hass.loop
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
publish_birth_message(birth_message),
|
||||
name="mqtt birth message",
|
||||
)
|
||||
else:
|
||||
# Update subscribe cooldown period to a shorter time
|
||||
self._subscribe_debouncer.set_timeout(SUBSCRIBE_COOLDOWN)
|
||||
|
||||
self._async_connection_result(True)
|
||||
|
||||
async def _async_resubscribe(self) -> None:
|
||||
"""Resubscribe on reconnect."""
|
||||
self._max_qos.clear()
|
||||
@ -796,16 +955,6 @@ class MQTT:
|
||||
)
|
||||
await self._async_perform_subscriptions()
|
||||
|
||||
def _mqtt_on_message(
|
||||
self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
|
||||
) -> None:
|
||||
"""Message received callback."""
|
||||
# MQTT messages tend to be high volume,
|
||||
# and since they come in via a thread and need to be processed in the event loop,
|
||||
# we want to avoid hass.add_job since most of the time is spent calling
|
||||
# inspect to figure out how to run the callback.
|
||||
self.loop.call_soon_threadsafe(self._mqtt_handle_message, msg)
|
||||
|
||||
@lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
def _matching_subscriptions(self, topic: str) -> list[Subscription]:
|
||||
subscriptions: list[Subscription] = []
|
||||
@ -819,7 +968,9 @@ class MQTT:
|
||||
return subscriptions
|
||||
|
||||
@callback
|
||||
def _mqtt_handle_message(self, msg: mqtt.MQTTMessage) -> None:
|
||||
def _async_mqtt_on_message(
|
||||
self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
|
||||
) -> None:
|
||||
topic = msg.topic
|
||||
# msg.topic is a property that decodes the topic to a string
|
||||
# every time it is accessed. Save the result to avoid
|
||||
@ -878,7 +1029,8 @@ class MQTT:
|
||||
self.hass.async_run_hass_job(subscription.job, receive_msg)
|
||||
self._mqtt_data.state_write_requests.process_write_state_requests(msg)
|
||||
|
||||
def _mqtt_on_callback(
|
||||
@callback
|
||||
def _async_mqtt_on_callback(
|
||||
self,
|
||||
_mqttc: mqtt.Client,
|
||||
_userdata: None,
|
||||
@ -890,7 +1042,7 @@ class MQTT:
|
||||
# The callback signature for on_unsubscribe is different from on_subscribe
|
||||
# see https://github.com/eclipse/paho.mqtt.python/issues/687
|
||||
# properties and reasoncodes are not used in Home Assistant
|
||||
self.hass.create_task(self._mqtt_handle_mid(mid))
|
||||
self.hass.async_create_task(self._mqtt_handle_mid(mid))
|
||||
|
||||
async def _mqtt_handle_mid(self, mid: int) -> None:
|
||||
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid
|
||||
@ -906,7 +1058,8 @@ class MQTT:
|
||||
if mid not in self._pending_operations:
|
||||
self._pending_operations[mid] = asyncio.Event()
|
||||
|
||||
def _mqtt_on_disconnect(
|
||||
@callback
|
||||
def _async_mqtt_on_disconnect(
|
||||
self,
|
||||
_mqttc: mqtt.Client,
|
||||
_userdata: None,
|
||||
@ -914,8 +1067,19 @@ class MQTT:
|
||||
properties: mqtt.Properties | None = None,
|
||||
) -> None:
|
||||
"""Disconnected callback."""
|
||||
self._async_on_disconnect(result_code)
|
||||
|
||||
@callback
|
||||
def _async_on_disconnect(self, result_code: int) -> None:
|
||||
if not self.connected:
|
||||
# This function is re-entrant and may be called multiple times
|
||||
# when there is a broken pipe error.
|
||||
return
|
||||
# If disconnect is called before the connect
|
||||
# result is set make sure the first connection result is set
|
||||
self._async_connection_result(False)
|
||||
self.connected = False
|
||||
dispatcher_send(self.hass, MQTT_DISCONNECTED)
|
||||
async_dispatcher_send(self.hass, MQTT_DISCONNECTED)
|
||||
_LOGGER.warning(
|
||||
"Disconnected from MQTT server %s:%s (%s)",
|
||||
self.conf[CONF_BROKER],
|
||||
|
@ -452,7 +452,7 @@ def async_fire_mqtt_message(
|
||||
|
||||
mqtt_data: MqttData = hass.data["mqtt"]
|
||||
assert mqtt_data.client
|
||||
mqtt_data.client._mqtt_handle_message(msg)
|
||||
mqtt_data.client._async_mqtt_on_message(Mock(), None, msg)
|
||||
|
||||
|
||||
fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message)
|
||||
|
@ -4,17 +4,22 @@ import asyncio
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
import socket
|
||||
import ssl
|
||||
from typing import Any, TypedDict
|
||||
from unittest.mock import ANY, MagicMock, call, mock_open, patch
|
||||
|
||||
from freezegun.api import FrozenDateTimeFactory
|
||||
import paho.mqtt.client as paho_mqtt
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import mqtt
|
||||
from homeassistant.components.mqtt import debug_info
|
||||
from homeassistant.components.mqtt.client import EnsureJobAfterCooldown
|
||||
from homeassistant.components.mqtt.client import (
|
||||
RECONNECT_INTERVAL_SECONDS,
|
||||
EnsureJobAfterCooldown,
|
||||
)
|
||||
from homeassistant.components.mqtt.mixins import MQTT_ENTITY_DEVICE_INFO_SCHEMA
|
||||
from homeassistant.components.mqtt.models import (
|
||||
MessageCallbackType,
|
||||
@ -146,7 +151,7 @@ async def test_mqtt_disconnects_on_home_assistant_stop(
|
||||
hass.bus.fire(EVENT_HOMEASSISTANT_STOP)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
assert mqtt_client_mock.loop_stop.call_count == 1
|
||||
assert mqtt_client_mock.disconnect.call_count == 1
|
||||
|
||||
|
||||
async def test_mqtt_await_ack_at_disconnect(
|
||||
@ -161,8 +166,14 @@ async def test_mqtt_await_ack_at_disconnect(
|
||||
rc = 0
|
||||
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
mock_client().connect = MagicMock(return_value=0)
|
||||
mock_client().publish = MagicMock(return_value=FakeInfo())
|
||||
mqtt_client = mock_client.return_value
|
||||
mqtt_client.connect = MagicMock(
|
||||
return_value=0,
|
||||
side_effect=lambda *args, **kwargs: hass.loop.call_soon_threadsafe(
|
||||
mqtt_client.on_connect, mqtt_client, None, 0, 0, 0
|
||||
),
|
||||
)
|
||||
mqtt_client.publish = MagicMock(return_value=FakeInfo())
|
||||
entry = MockConfigEntry(
|
||||
domain=mqtt.DOMAIN,
|
||||
data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"},
|
||||
@ -1669,6 +1680,7 @@ async def test_not_calling_subscribe_when_unsubscribed_within_cooldown(
|
||||
the subscribe cool down period has ended.
|
||||
"""
|
||||
mqtt_mock = await mqtt_mock_entry()
|
||||
mqtt_client_mock.subscribe.reset_mock()
|
||||
# Fake that the client is connected
|
||||
mqtt_mock().connected = True
|
||||
|
||||
@ -1925,6 +1937,7 @@ async def test_canceling_debouncer_on_shutdown(
|
||||
"""Test canceling the debouncer when HA shuts down."""
|
||||
|
||||
mqtt_mock = await mqtt_mock_entry()
|
||||
mqtt_client_mock.subscribe.reset_mock()
|
||||
|
||||
# Fake that the client is connected
|
||||
mqtt_mock().connected = True
|
||||
@ -2008,7 +2021,7 @@ async def test_initial_setup_logs_error(
|
||||
"""Test for setup failure if initial client connection fails."""
|
||||
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"})
|
||||
entry.add_to_hass(hass)
|
||||
mqtt_client_mock.connect.return_value = 1
|
||||
mqtt_client_mock.connect.side_effect = MagicMock(return_value=1)
|
||||
try:
|
||||
assert await hass.config_entries.async_setup(entry.entry_id)
|
||||
except HomeAssistantError:
|
||||
@ -2230,7 +2243,12 @@ async def test_handle_mqtt_timeout_on_callback(
|
||||
mock_client = mock_client.return_value
|
||||
mock_client.publish.return_value = FakeInfo()
|
||||
mock_client.subscribe.side_effect = _mock_ack
|
||||
mock_client.connect.return_value = 0
|
||||
mock_client.connect = MagicMock(
|
||||
return_value=0,
|
||||
side_effect=lambda *args, **kwargs: hass.loop.call_soon_threadsafe(
|
||||
mock_client.on_connect, mock_client, None, 0, 0, 0
|
||||
),
|
||||
)
|
||||
|
||||
entry = MockConfigEntry(
|
||||
domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"}
|
||||
@ -4144,3 +4162,179 @@ async def test_multi_platform_discovery(
|
||||
)
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
|
||||
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
|
||||
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
|
||||
async def test_auto_reconnect(
|
||||
hass: HomeAssistant,
|
||||
mqtt_client_mock: MqttMockPahoClient,
|
||||
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test reconnection is automatically done."""
|
||||
mqtt_mock = await mqtt_mock_entry()
|
||||
await hass.async_block_till_done()
|
||||
assert mqtt_mock.connected is True
|
||||
mqtt_client_mock.reconnect.reset_mock()
|
||||
|
||||
mqtt_client_mock.disconnect()
|
||||
mqtt_client_mock.on_disconnect(None, None, 0)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
mqtt_client_mock.reconnect.side_effect = OSError("foo")
|
||||
async_fire_time_changed(
|
||||
hass, utcnow() + timedelta(seconds=RECONNECT_INTERVAL_SECONDS)
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert len(mqtt_client_mock.reconnect.mock_calls) == 1
|
||||
assert "Error re-connecting to MQTT server due to exception: foo" in caplog.text
|
||||
|
||||
mqtt_client_mock.reconnect.side_effect = None
|
||||
async_fire_time_changed(
|
||||
hass, utcnow() + timedelta(seconds=RECONNECT_INTERVAL_SECONDS)
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert len(mqtt_client_mock.reconnect.mock_calls) == 2
|
||||
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
|
||||
|
||||
mqtt_client_mock.disconnect()
|
||||
mqtt_client_mock.on_disconnect(None, None, 0)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
async_fire_time_changed(
|
||||
hass, utcnow() + timedelta(seconds=RECONNECT_INTERVAL_SECONDS)
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
# Should not reconnect after stop
|
||||
assert len(mqtt_client_mock.reconnect.mock_calls) == 2
|
||||
|
||||
|
||||
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
|
||||
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
|
||||
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
|
||||
async def test_server_sock_connect_and_disconnect(
|
||||
hass: HomeAssistant,
|
||||
mqtt_client_mock: MqttMockPahoClient,
|
||||
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||
calls: list[ReceiveMessage],
|
||||
record_calls: MessageCallbackType,
|
||||
) -> None:
|
||||
"""Test handling the socket connected and disconnected."""
|
||||
mqtt_mock = await mqtt_mock_entry()
|
||||
await hass.async_block_till_done()
|
||||
assert mqtt_mock.connected is True
|
||||
|
||||
mqtt_client_mock.loop_misc.return_value = paho_mqtt.MQTT_ERR_SUCCESS
|
||||
|
||||
client, server = socket.socketpair(
|
||||
family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0
|
||||
)
|
||||
client.setblocking(False)
|
||||
server.setblocking(False)
|
||||
mqtt_client_mock.on_socket_open(mqtt_client_mock, None, client)
|
||||
mqtt_client_mock.on_socket_register_write(mqtt_client_mock, None, client)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
server.close() # mock the server closing the connection on us
|
||||
|
||||
unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
||||
|
||||
mqtt_client_mock.loop_misc.return_value = paho_mqtt.MQTT_ERR_CONN_LOST
|
||||
mqtt_client_mock.on_socket_unregister_write(mqtt_client_mock, None, client)
|
||||
mqtt_client_mock.on_socket_close(mqtt_client_mock, None, client)
|
||||
mqtt_client_mock.on_disconnect(mqtt_client_mock, None, client)
|
||||
await hass.async_block_till_done()
|
||||
unsub()
|
||||
|
||||
# Should have failed
|
||||
assert len(calls) == 0
|
||||
|
||||
|
||||
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
|
||||
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
|
||||
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
|
||||
async def test_client_sock_failure_after_connect(
|
||||
hass: HomeAssistant,
|
||||
mqtt_client_mock: MqttMockPahoClient,
|
||||
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||
calls: list[ReceiveMessage],
|
||||
record_calls: MessageCallbackType,
|
||||
) -> None:
|
||||
"""Test handling the socket connected and disconnected."""
|
||||
mqtt_mock = await mqtt_mock_entry()
|
||||
# Fake that the client is connected
|
||||
mqtt_mock().connected = True
|
||||
await hass.async_block_till_done()
|
||||
assert mqtt_mock.connected is True
|
||||
|
||||
mqtt_client_mock.loop_misc.return_value = paho_mqtt.MQTT_ERR_SUCCESS
|
||||
|
||||
client, server = socket.socketpair(
|
||||
family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0
|
||||
)
|
||||
client.setblocking(False)
|
||||
server.setblocking(False)
|
||||
mqtt_client_mock.on_socket_open(mqtt_client_mock, None, client)
|
||||
mqtt_client_mock.on_socket_register_writer(mqtt_client_mock, None, client)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
mqtt_client_mock.loop_write.side_effect = OSError("foo")
|
||||
client.close() # close the client socket out from under the client
|
||||
|
||||
assert mqtt_mock.connected is True
|
||||
unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
||||
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5))
|
||||
await hass.async_block_till_done()
|
||||
|
||||
unsub()
|
||||
# Should have failed
|
||||
assert len(calls) == 0
|
||||
|
||||
|
||||
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
|
||||
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
|
||||
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
|
||||
async def test_loop_write_failure(
|
||||
hass: HomeAssistant,
|
||||
mqtt_client_mock: MqttMockPahoClient,
|
||||
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test handling the socket connected and disconnected."""
|
||||
mqtt_mock = await mqtt_mock_entry()
|
||||
await hass.async_block_till_done()
|
||||
assert mqtt_mock.connected is True
|
||||
|
||||
mqtt_client_mock.loop_misc.return_value = paho_mqtt.MQTT_ERR_SUCCESS
|
||||
|
||||
client, server = socket.socketpair(
|
||||
family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0
|
||||
)
|
||||
client.setblocking(False)
|
||||
server.setblocking(False)
|
||||
mqtt_client_mock.on_socket_open(mqtt_client_mock, None, client)
|
||||
mqtt_client_mock.on_socket_register_write(mqtt_client_mock, None, client)
|
||||
mqtt_client_mock.loop_write.return_value = paho_mqtt.MQTT_ERR_CONN_LOST
|
||||
mqtt_client_mock.loop_read.return_value = paho_mqtt.MQTT_ERR_CONN_LOST
|
||||
|
||||
# Fill up the outgoing buffer to ensure that loop_write
|
||||
# and loop_read are called that next time control is
|
||||
# returned to the event loop
|
||||
try:
|
||||
for _ in range(1000):
|
||||
server.send(b"long" * 100)
|
||||
except BlockingIOError:
|
||||
pass
|
||||
|
||||
server.close()
|
||||
# Once for the reader callback
|
||||
await hass.async_block_till_done()
|
||||
# Another for the writer callback
|
||||
await hass.async_block_till_done()
|
||||
# Final for the disconnect callback
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert "Disconnected from MQTT server mock-broker:1883 (7)" in caplog.text
|
||||
|
@ -163,7 +163,7 @@ async def help_test_availability_when_connection_lost(
|
||||
|
||||
# Disconnected from MQTT server -> state changed to unavailable
|
||||
mqtt_mock.connected = False
|
||||
await hass.async_add_executor_job(mqtt_client_mock.on_disconnect, None, None, 0)
|
||||
mqtt_client_mock.on_disconnect(None, None, 0)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
@ -172,7 +172,7 @@ async def help_test_availability_when_connection_lost(
|
||||
|
||||
# Reconnected to MQTT server -> state still unavailable
|
||||
mqtt_mock.connected = True
|
||||
await hass.async_add_executor_job(mqtt_client_mock.on_connect, None, None, None, 0)
|
||||
mqtt_client_mock.on_connect(None, None, None, 0)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
@ -224,7 +224,7 @@ async def help_test_deep_sleep_availability_when_connection_lost(
|
||||
|
||||
# Disconnected from MQTT server -> state changed to unavailable
|
||||
mqtt_mock.connected = False
|
||||
await hass.async_add_executor_job(mqtt_client_mock.on_disconnect, None, None, 0)
|
||||
mqtt_client_mock.on_disconnect(None, None, 0)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
@ -233,7 +233,7 @@ async def help_test_deep_sleep_availability_when_connection_lost(
|
||||
|
||||
# Reconnected to MQTT server -> state no longer unavailable
|
||||
mqtt_mock.connected = True
|
||||
await hass.async_add_executor_job(mqtt_client_mock.on_connect, None, None, None, 0)
|
||||
mqtt_client_mock.on_connect(None, None, None, 0)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
@ -476,7 +476,7 @@ async def help_test_availability_poll_state(
|
||||
|
||||
# Disconnected from MQTT server
|
||||
mqtt_mock.connected = False
|
||||
await hass.async_add_executor_job(mqtt_client_mock.on_disconnect, None, None, 0)
|
||||
mqtt_client_mock.on_disconnect(None, None, 0)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
@ -484,7 +484,7 @@ async def help_test_availability_poll_state(
|
||||
|
||||
# Reconnected to MQTT server
|
||||
mqtt_mock.connected = True
|
||||
await hass.async_add_executor_job(mqtt_client_mock.on_connect, None, None, None, 0)
|
||||
mqtt_client_mock.on_connect(None, None, None, 0)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
|
@ -904,26 +904,45 @@ def mqtt_client_mock(hass: HomeAssistant) -> Generator[MqttMockPahoClient, None,
|
||||
self.rc = 0
|
||||
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
# The below use a call_soon for the on_publish/on_subscribe/on_unsubscribe
|
||||
# callbacks to simulate the behavior of the real MQTT client which will
|
||||
# not be synchronous.
|
||||
|
||||
@ha.callback
|
||||
def _async_fire_mqtt_message(topic, payload, qos, retain):
|
||||
async_fire_mqtt_message(hass, topic, payload, qos, retain)
|
||||
mid = get_mid()
|
||||
mock_client.on_publish(0, 0, mid)
|
||||
hass.loop.call_soon(mock_client.on_publish, 0, 0, mid)
|
||||
return FakeInfo(mid)
|
||||
|
||||
def _subscribe(topic, qos=0):
|
||||
mid = get_mid()
|
||||
mock_client.on_subscribe(0, 0, mid)
|
||||
hass.loop.call_soon(mock_client.on_subscribe, 0, 0, mid)
|
||||
return (0, mid)
|
||||
|
||||
def _unsubscribe(topic):
|
||||
mid = get_mid()
|
||||
mock_client.on_unsubscribe(0, 0, mid)
|
||||
hass.loop.call_soon(mock_client.on_unsubscribe, 0, 0, mid)
|
||||
return (0, mid)
|
||||
|
||||
def _connect(*args, **kwargs):
|
||||
# Connect always calls reconnect once, but we
|
||||
# mock it out so we call reconnect to simulate
|
||||
# the behavior.
|
||||
mock_client.reconnect()
|
||||
hass.loop.call_soon_threadsafe(
|
||||
mock_client.on_connect, mock_client, None, 0, 0, 0
|
||||
)
|
||||
mock_client.on_socket_open(
|
||||
mock_client, None, Mock(fileno=Mock(return_value=-1))
|
||||
)
|
||||
mock_client.on_socket_register_write(
|
||||
mock_client, None, Mock(fileno=Mock(return_value=-1))
|
||||
)
|
||||
return 0
|
||||
|
||||
mock_client = mock_client.return_value
|
||||
mock_client.connect.return_value = 0
|
||||
mock_client.connect.side_effect = _connect
|
||||
mock_client.subscribe.side_effect = _subscribe
|
||||
mock_client.unsubscribe.side_effect = _unsubscribe
|
||||
mock_client.publish.side_effect = _async_fire_mqtt_message
|
||||
@ -985,6 +1004,7 @@ async def _mqtt_mock_entry(
|
||||
|
||||
# connected set to True to get a more realistic behavior when subscribing
|
||||
mock_mqtt_instance.connected = True
|
||||
mqtt_client_mock.on_connect(mqtt_client_mock, None, 0, 0, 0)
|
||||
|
||||
async_dispatcher_send(hass, mqtt.MQTT_CONNECTED)
|
||||
await hass.async_block_till_done()
|
||||
|
Loading…
x
Reference in New Issue
Block a user