diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 28cb7d0944b..cc1ae3ddce1 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -265,7 +265,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: conf: dict[str, Any] mqtt_data: MqttData - async def _setup_client() -> tuple[MqttData, dict[str, Any]]: + async def _setup_client( + client_available: asyncio.Future[bool], + ) -> tuple[MqttData, dict[str, Any]]: """Set up the MQTT client.""" # Fetch configuration conf = dict(entry.data) @@ -294,7 +296,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: entry.add_update_listener(_async_config_entry_updated) ) - await mqtt_data.client.async_connect() + await mqtt_data.client.async_connect(client_available) return (mqtt_data, conf) client_available: asyncio.Future[bool] @@ -303,13 +305,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: else: client_available = hass.data[DATA_MQTT_AVAILABLE] - setup_ok: bool = False - try: - mqtt_data, conf = await _setup_client() - setup_ok = True - finally: - if not client_available.done(): - client_available.set_result(setup_ok) + mqtt_data, conf = await _setup_client(client_available) async def async_publish_service(call: ServiceCall) -> None: """Handle MQTT publish service calls.""" diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index 978123e169c..021ecf1cc36 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -3,12 +3,14 @@ from __future__ import annotations import asyncio -from collections.abc import Callable, Coroutine, Iterable +from collections.abc import AsyncGenerator, Callable, Coroutine, Iterable +import contextlib from dataclasses import dataclass -from functools import lru_cache +from functools import lru_cache, partial from itertools import chain, groupby import logging from operator import attrgetter +import socket import ssl import time from typing import TYPE_CHECKING, Any @@ -35,7 +37,7 @@ from homeassistant.core import ( callback, ) from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers.dispatcher import dispatcher_send +from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass from homeassistant.util import dt as dt_util @@ -92,6 +94,9 @@ INITIAL_SUBSCRIBE_COOLDOWN = 1.0 SUBSCRIBE_COOLDOWN = 0.1 UNSUBSCRIBE_COOLDOWN = 0.1 TIMEOUT_ACK = 10 +RECONNECT_INTERVAL_SECONDS = 10 + +SocketType = socket.socket | ssl.SSLSocket | Any SubscribePayloadType = str | bytes # Only bytes if encoding is None @@ -258,7 +263,9 @@ class MqttClientSetup: # However, that feature is not mandatory so we generate our own. client_id = mqtt.base62(uuid.uuid4().int, padding=22) transport = config.get(CONF_TRANSPORT, DEFAULT_TRANSPORT) - self._client = mqtt.Client(client_id, protocol=proto, transport=transport) + self._client = mqtt.Client( + client_id, protocol=proto, transport=transport, reconnect_on_failure=False + ) # Enable logging self._client.enable_logger() @@ -404,12 +411,17 @@ class MQTT: self._ha_started = asyncio.Event() self._cleanup_on_unload: list[Callable[[], None]] = [] - self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client + self._connection_lock = asyncio.Lock() self._pending_operations: dict[int, asyncio.Event] = {} self._pending_operations_condition = asyncio.Condition() self._subscribe_debouncer = EnsureJobAfterCooldown( INITIAL_SUBSCRIBE_COOLDOWN, self._async_perform_subscriptions ) + self._misc_task: asyncio.Task | None = None + self._reconnect_task: asyncio.Task | None = None + self._should_reconnect: bool = True + self._available_future: asyncio.Future[bool] | None = None + self._max_qos: dict[str, int] = {} # topic, max qos self._pending_subscriptions: dict[str, int] = {} # topic, qos self._unsubscribe_debouncer = EnsureJobAfterCooldown( @@ -456,25 +468,140 @@ class MQTT: while self._cleanup_on_unload: self._cleanup_on_unload.pop()() + @contextlib.asynccontextmanager + async def _async_connect_in_executor(self) -> AsyncGenerator[None, None]: + # While we are connecting in the executor we need to + # handle on_socket_open and on_socket_register_write + # in the executor as well. + mqttc = self._mqttc + try: + mqttc.on_socket_open = self._on_socket_open + mqttc.on_socket_register_write = self._on_socket_register_write + yield + finally: + # Once the executor job is done, we can switch back to + # handling these in the event loop. + mqttc.on_socket_open = self._async_on_socket_open + mqttc.on_socket_register_write = self._async_on_socket_register_write + def init_client(self) -> None: """Initialize paho client.""" - self._mqttc = MqttClientSetup(self.conf).client - self._mqttc.on_connect = self._mqtt_on_connect - self._mqttc.on_disconnect = self._mqtt_on_disconnect - self._mqttc.on_message = self._mqtt_on_message - self._mqttc.on_publish = self._mqtt_on_callback - self._mqttc.on_subscribe = self._mqtt_on_callback - self._mqttc.on_unsubscribe = self._mqtt_on_callback + mqttc = MqttClientSetup(self.conf).client + # on_socket_unregister_write and _async_on_socket_close + # are only ever called in the event loop + mqttc.on_socket_close = self._async_on_socket_close + mqttc.on_socket_unregister_write = self._async_on_socket_unregister_write + + # These will be called in the event loop + mqttc.on_connect = self._async_mqtt_on_connect + mqttc.on_disconnect = self._async_mqtt_on_disconnect + mqttc.on_message = self._async_mqtt_on_message + mqttc.on_publish = self._async_mqtt_on_callback + mqttc.on_subscribe = self._async_mqtt_on_callback + mqttc.on_unsubscribe = self._async_mqtt_on_callback if will := self.conf.get(CONF_WILL_MESSAGE, DEFAULT_WILL): will_message = PublishMessage(**will) - self._mqttc.will_set( + mqttc.will_set( topic=will_message.topic, payload=will_message.payload, qos=will_message.qos, retain=will_message.retain, ) + self._mqttc = mqttc + + async def _misc_loop(self) -> None: + """Start the MQTT client misc loop.""" + # pylint: disable=import-outside-toplevel + import paho.mqtt.client as mqtt + + while self._mqttc.loop_misc() == mqtt.MQTT_ERR_SUCCESS: + await asyncio.sleep(1) + + @callback + def _async_reader_callback(self, client: mqtt.Client) -> None: + """Handle reading data from the socket.""" + if (status := client.loop_read()) != 0: + self._async_on_disconnect(status) + + @callback + def _async_start_misc_loop(self) -> None: + """Start the misc loop.""" + if self._misc_task is None or self._misc_task.done(): + _LOGGER.debug("%s: Starting client misc loop", self.config_entry.title) + self._misc_task = self.config_entry.async_create_background_task( + self.hass, self._misc_loop(), name="mqtt misc loop" + ) + + def _on_socket_open( + self, client: mqtt.Client, userdata: Any, sock: SocketType + ) -> None: + """Handle socket open.""" + self.loop.call_soon_threadsafe( + self._async_on_socket_open, client, userdata, sock + ) + + @callback + def _async_on_socket_open( + self, client: mqtt.Client, userdata: Any, sock: SocketType + ) -> None: + """Handle socket open.""" + fileno = sock.fileno() + _LOGGER.debug("%s: connection opened %s", self.config_entry.title, fileno) + if fileno > -1: + self.loop.add_reader(sock, partial(self._async_reader_callback, client)) + self._async_start_misc_loop() + + @callback + def _async_on_socket_close( + self, client: mqtt.Client, userdata: Any, sock: SocketType + ) -> None: + """Handle socket close.""" + fileno = sock.fileno() + _LOGGER.debug("%s: connection closed %s", self.config_entry.title, fileno) + # If socket close is called before the connect + # result is set make sure the first connection result is set + self._async_connection_result(False) + if fileno > -1: + self.loop.remove_reader(sock) + if self._misc_task is not None and not self._misc_task.done(): + self._misc_task.cancel() + + @callback + def _async_writer_callback(self, client: mqtt.Client) -> None: + """Handle writing data to the socket.""" + if (status := client.loop_write()) != 0: + self._async_on_disconnect(status) + + def _on_socket_register_write( + self, client: mqtt.Client, userdata: Any, sock: SocketType + ) -> None: + """Register the socket for writing.""" + self.loop.call_soon_threadsafe( + self._async_on_socket_register_write, client, None, sock + ) + + @callback + def _async_on_socket_register_write( + self, client: mqtt.Client, userdata: Any, sock: SocketType + ) -> None: + """Register the socket for writing.""" + fileno = sock.fileno() + _LOGGER.debug("%s: register write %s", self.config_entry.title, fileno) + if fileno > -1: + self.loop.add_writer(sock, partial(self._async_writer_callback, client)) + + @callback + def _async_on_socket_unregister_write( + self, client: mqtt.Client, userdata: Any, sock: SocketType + ) -> None: + """Unregister the socket for writing.""" + fileno = sock.fileno() + _LOGGER.debug("%s: unregister write %s", self.config_entry.title, fileno) + if fileno > -1: + self.loop.remove_writer(sock) + def _is_active_subscription(self, topic: str) -> bool: """Check if a topic has an active subscription.""" return topic in self._simple_subscriptions or any( @@ -485,10 +612,7 @@ class MQTT: self, topic: str, payload: PublishPayloadType, qos: int, retain: bool ) -> None: """Publish a MQTT message.""" - async with self._paho_lock: - msg_info = await self.hass.async_add_executor_job( - self._mqttc.publish, topic, payload, qos, retain - ) + msg_info = self._mqttc.publish(topic, payload, qos, retain) _LOGGER.debug( "Transmitting%s message on %s: '%s', mid: %s, qos: %s", " retained" if retain else "", @@ -500,37 +624,71 @@ class MQTT: _raise_on_error(msg_info.rc) await self._wait_for_mid(msg_info.mid) - async def async_connect(self) -> None: + async def async_connect(self, client_available: asyncio.Future[bool]) -> None: """Connect to the host. Does not process messages yet.""" # pylint: disable-next=import-outside-toplevel import paho.mqtt.client as mqtt result: int | None = None + self._available_future = client_available + self._should_reconnect = True try: - result = await self.hass.async_add_executor_job( - self._mqttc.connect, - self.conf[CONF_BROKER], - self.conf.get(CONF_PORT, DEFAULT_PORT), - self.conf.get(CONF_KEEPALIVE, DEFAULT_KEEPALIVE), - ) + async with self._connection_lock, self._async_connect_in_executor(): + result = await self.hass.async_add_executor_job( + self._mqttc.connect, + self.conf[CONF_BROKER], + self.conf.get(CONF_PORT, DEFAULT_PORT), + self.conf.get(CONF_KEEPALIVE, DEFAULT_KEEPALIVE), + ) except OSError as err: _LOGGER.error("Failed to connect to MQTT server due to exception: %s", err) + self._async_connection_result(False) + finally: + if result is not None and result != 0: + if result is not None: + _LOGGER.error( + "Failed to connect to MQTT server: %s", + mqtt.error_string(result), + ) + self._async_connection_result(False) - if result is not None and result != 0: - _LOGGER.error( - "Failed to connect to MQTT server: %s", mqtt.error_string(result) + @callback + def _async_connection_result(self, connected: bool) -> None: + """Handle a connection result.""" + if self._available_future and not self._available_future.done(): + self._available_future.set_result(connected) + + if connected: + self._async_cancel_reconnect() + elif self._should_reconnect and not self._reconnect_task: + self._reconnect_task = self.config_entry.async_create_background_task( + self.hass, self._reconnect_loop(), "mqtt reconnect loop" ) - self._mqttc.loop_start() + @callback + def _async_cancel_reconnect(self) -> None: + """Cancel the reconnect task.""" + if self._reconnect_task: + self._reconnect_task.cancel() + self._reconnect_task = None + + async def _reconnect_loop(self) -> None: + """Reconnect to the MQTT server.""" + while True: + if not self.connected: + try: + async with self._connection_lock, self._async_connect_in_executor(): + await self.hass.async_add_executor_job(self._mqttc.reconnect) + except OSError as err: + _LOGGER.debug( + "Error re-connecting to MQTT server due to exception: %s", err + ) + + await asyncio.sleep(RECONNECT_INTERVAL_SECONDS) async def async_disconnect(self) -> None: """Stop the MQTT client.""" - def stop() -> None: - """Stop the MQTT client.""" - # Do not disconnect, we want the broker to always publish will - self._mqttc.loop_stop() - def no_more_acks() -> bool: """Return False if there are unprocessed ACKs.""" return not any(not op.is_set() for op in self._pending_operations.values()) @@ -549,8 +707,10 @@ class MQTT: await self._pending_operations_condition.wait_for(no_more_acks) # stop the MQTT loop - async with self._paho_lock: - await self.hass.async_add_executor_job(stop) + async with self._connection_lock: + self._should_reconnect = False + self._async_cancel_reconnect() + self._mqttc.disconnect() @callback def async_restore_tracked_subscriptions( @@ -689,11 +849,8 @@ class MQTT: subscriptions: dict[str, int] = self._pending_subscriptions self._pending_subscriptions = {} - async with self._paho_lock: - subscription_list = list(subscriptions.items()) - result, mid = await self.hass.async_add_executor_job( - self._mqttc.subscribe, subscription_list - ) + subscription_list = list(subscriptions.items()) + result, mid = self._mqttc.subscribe(subscription_list) for topic, qos in subscriptions.items(): _LOGGER.debug("Subscribing to %s, mid: %s, qos: %s", topic, mid, qos) @@ -712,17 +869,15 @@ class MQTT: topics = list(self._pending_unsubscribes) self._pending_unsubscribes = set() - async with self._paho_lock: - result, mid = await self.hass.async_add_executor_job( - self._mqttc.unsubscribe, topics - ) + result, mid = self._mqttc.unsubscribe(topics) _raise_on_error(result) for topic in topics: _LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid) await self._wait_for_mid(mid) - def _mqtt_on_connect( + @callback + def _async_mqtt_on_connect( self, _mqttc: mqtt.Client, _userdata: None, @@ -746,7 +901,7 @@ class MQTT: return self.connected = True - dispatcher_send(self.hass, MQTT_CONNECTED) + async_dispatcher_send(self.hass, MQTT_CONNECTED) _LOGGER.info( "Connected to MQTT server %s:%s (%s)", self.conf[CONF_BROKER], @@ -754,7 +909,7 @@ class MQTT: result_code, ) - self.hass.create_task(self._async_resubscribe()) + self.hass.async_create_task(self._async_resubscribe()) if birth := self.conf.get(CONF_BIRTH_MESSAGE, DEFAULT_BIRTH): @@ -771,13 +926,17 @@ class MQTT: ) birth_message = PublishMessage(**birth) - asyncio.run_coroutine_threadsafe( - publish_birth_message(birth_message), self.hass.loop + self.config_entry.async_create_background_task( + self.hass, + publish_birth_message(birth_message), + name="mqtt birth message", ) else: # Update subscribe cooldown period to a shorter time self._subscribe_debouncer.set_timeout(SUBSCRIBE_COOLDOWN) + self._async_connection_result(True) + async def _async_resubscribe(self) -> None: """Resubscribe on reconnect.""" self._max_qos.clear() @@ -796,16 +955,6 @@ class MQTT: ) await self._async_perform_subscriptions() - def _mqtt_on_message( - self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage - ) -> None: - """Message received callback.""" - # MQTT messages tend to be high volume, - # and since they come in via a thread and need to be processed in the event loop, - # we want to avoid hass.add_job since most of the time is spent calling - # inspect to figure out how to run the callback. - self.loop.call_soon_threadsafe(self._mqtt_handle_message, msg) - @lru_cache(None) # pylint: disable=method-cache-max-size-none def _matching_subscriptions(self, topic: str) -> list[Subscription]: subscriptions: list[Subscription] = [] @@ -819,7 +968,9 @@ class MQTT: return subscriptions @callback - def _mqtt_handle_message(self, msg: mqtt.MQTTMessage) -> None: + def _async_mqtt_on_message( + self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage + ) -> None: topic = msg.topic # msg.topic is a property that decodes the topic to a string # every time it is accessed. Save the result to avoid @@ -878,7 +1029,8 @@ class MQTT: self.hass.async_run_hass_job(subscription.job, receive_msg) self._mqtt_data.state_write_requests.process_write_state_requests(msg) - def _mqtt_on_callback( + @callback + def _async_mqtt_on_callback( self, _mqttc: mqtt.Client, _userdata: None, @@ -890,7 +1042,7 @@ class MQTT: # The callback signature for on_unsubscribe is different from on_subscribe # see https://github.com/eclipse/paho.mqtt.python/issues/687 # properties and reasoncodes are not used in Home Assistant - self.hass.create_task(self._mqtt_handle_mid(mid)) + self.hass.async_create_task(self._mqtt_handle_mid(mid)) async def _mqtt_handle_mid(self, mid: int) -> None: # Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid @@ -906,7 +1058,8 @@ class MQTT: if mid not in self._pending_operations: self._pending_operations[mid] = asyncio.Event() - def _mqtt_on_disconnect( + @callback + def _async_mqtt_on_disconnect( self, _mqttc: mqtt.Client, _userdata: None, @@ -914,8 +1067,19 @@ class MQTT: properties: mqtt.Properties | None = None, ) -> None: """Disconnected callback.""" + self._async_on_disconnect(result_code) + + @callback + def _async_on_disconnect(self, result_code: int) -> None: + if not self.connected: + # This function is re-entrant and may be called multiple times + # when there is a broken pipe error. + return + # If disconnect is called before the connect + # result is set make sure the first connection result is set + self._async_connection_result(False) self.connected = False - dispatcher_send(self.hass, MQTT_DISCONNECTED) + async_dispatcher_send(self.hass, MQTT_DISCONNECTED) _LOGGER.warning( "Disconnected from MQTT server %s:%s (%s)", self.conf[CONF_BROKER], diff --git a/tests/common.py b/tests/common.py index d53db1beb37..b5fe0f7bae1 100644 --- a/tests/common.py +++ b/tests/common.py @@ -452,7 +452,7 @@ def async_fire_mqtt_message( mqtt_data: MqttData = hass.data["mqtt"] assert mqtt_data.client - mqtt_data.client._mqtt_handle_message(msg) + mqtt_data.client._async_mqtt_on_message(Mock(), None, msg) fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message) diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 3e444e8d4c8..37f7e0cf587 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -4,17 +4,22 @@ import asyncio from copy import deepcopy from datetime import datetime, timedelta import json +import socket import ssl from typing import Any, TypedDict from unittest.mock import ANY, MagicMock, call, mock_open, patch from freezegun.api import FrozenDateTimeFactory +import paho.mqtt.client as paho_mqtt import pytest import voluptuous as vol from homeassistant.components import mqtt from homeassistant.components.mqtt import debug_info -from homeassistant.components.mqtt.client import EnsureJobAfterCooldown +from homeassistant.components.mqtt.client import ( + RECONNECT_INTERVAL_SECONDS, + EnsureJobAfterCooldown, +) from homeassistant.components.mqtt.mixins import MQTT_ENTITY_DEVICE_INFO_SCHEMA from homeassistant.components.mqtt.models import ( MessageCallbackType, @@ -146,7 +151,7 @@ async def test_mqtt_disconnects_on_home_assistant_stop( hass.bus.fire(EVENT_HOMEASSISTANT_STOP) await hass.async_block_till_done() await hass.async_block_till_done() - assert mqtt_client_mock.loop_stop.call_count == 1 + assert mqtt_client_mock.disconnect.call_count == 1 async def test_mqtt_await_ack_at_disconnect( @@ -161,8 +166,14 @@ async def test_mqtt_await_ack_at_disconnect( rc = 0 with patch("paho.mqtt.client.Client") as mock_client: - mock_client().connect = MagicMock(return_value=0) - mock_client().publish = MagicMock(return_value=FakeInfo()) + mqtt_client = mock_client.return_value + mqtt_client.connect = MagicMock( + return_value=0, + side_effect=lambda *args, **kwargs: hass.loop.call_soon_threadsafe( + mqtt_client.on_connect, mqtt_client, None, 0, 0, 0 + ), + ) + mqtt_client.publish = MagicMock(return_value=FakeInfo()) entry = MockConfigEntry( domain=mqtt.DOMAIN, data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"}, @@ -1669,6 +1680,7 @@ async def test_not_calling_subscribe_when_unsubscribed_within_cooldown( the subscribe cool down period has ended. """ mqtt_mock = await mqtt_mock_entry() + mqtt_client_mock.subscribe.reset_mock() # Fake that the client is connected mqtt_mock().connected = True @@ -1925,6 +1937,7 @@ async def test_canceling_debouncer_on_shutdown( """Test canceling the debouncer when HA shuts down.""" mqtt_mock = await mqtt_mock_entry() + mqtt_client_mock.subscribe.reset_mock() # Fake that the client is connected mqtt_mock().connected = True @@ -2008,7 +2021,7 @@ async def test_initial_setup_logs_error( """Test for setup failure if initial client connection fails.""" entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"}) entry.add_to_hass(hass) - mqtt_client_mock.connect.return_value = 1 + mqtt_client_mock.connect.side_effect = MagicMock(return_value=1) try: assert await hass.config_entries.async_setup(entry.entry_id) except HomeAssistantError: @@ -2230,7 +2243,12 @@ async def test_handle_mqtt_timeout_on_callback( mock_client = mock_client.return_value mock_client.publish.return_value = FakeInfo() mock_client.subscribe.side_effect = _mock_ack - mock_client.connect.return_value = 0 + mock_client.connect = MagicMock( + return_value=0, + side_effect=lambda *args, **kwargs: hass.loop.call_soon_threadsafe( + mock_client.on_connect, mock_client, None, 0, 0, 0 + ), + ) entry = MockConfigEntry( domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"} @@ -4144,3 +4162,179 @@ async def test_multi_platform_discovery( ) is not None ) + + +@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) +async def test_auto_reconnect( + hass: HomeAssistant, + mqtt_client_mock: MqttMockPahoClient, + mqtt_mock_entry: MqttMockHAClientGenerator, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test reconnection is automatically done.""" + mqtt_mock = await mqtt_mock_entry() + await hass.async_block_till_done() + assert mqtt_mock.connected is True + mqtt_client_mock.reconnect.reset_mock() + + mqtt_client_mock.disconnect() + mqtt_client_mock.on_disconnect(None, None, 0) + await hass.async_block_till_done() + + mqtt_client_mock.reconnect.side_effect = OSError("foo") + async_fire_time_changed( + hass, utcnow() + timedelta(seconds=RECONNECT_INTERVAL_SECONDS) + ) + await hass.async_block_till_done() + assert len(mqtt_client_mock.reconnect.mock_calls) == 1 + assert "Error re-connecting to MQTT server due to exception: foo" in caplog.text + + mqtt_client_mock.reconnect.side_effect = None + async_fire_time_changed( + hass, utcnow() + timedelta(seconds=RECONNECT_INTERVAL_SECONDS) + ) + await hass.async_block_till_done() + assert len(mqtt_client_mock.reconnect.mock_calls) == 2 + + hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) + + mqtt_client_mock.disconnect() + mqtt_client_mock.on_disconnect(None, None, 0) + await hass.async_block_till_done() + + async_fire_time_changed( + hass, utcnow() + timedelta(seconds=RECONNECT_INTERVAL_SECONDS) + ) + await hass.async_block_till_done() + # Should not reconnect after stop + assert len(mqtt_client_mock.reconnect.mock_calls) == 2 + + +@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) +async def test_server_sock_connect_and_disconnect( + hass: HomeAssistant, + mqtt_client_mock: MqttMockPahoClient, + mqtt_mock_entry: MqttMockHAClientGenerator, + calls: list[ReceiveMessage], + record_calls: MessageCallbackType, +) -> None: + """Test handling the socket connected and disconnected.""" + mqtt_mock = await mqtt_mock_entry() + await hass.async_block_till_done() + assert mqtt_mock.connected is True + + mqtt_client_mock.loop_misc.return_value = paho_mqtt.MQTT_ERR_SUCCESS + + client, server = socket.socketpair( + family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0 + ) + client.setblocking(False) + server.setblocking(False) + mqtt_client_mock.on_socket_open(mqtt_client_mock, None, client) + mqtt_client_mock.on_socket_register_write(mqtt_client_mock, None, client) + await hass.async_block_till_done() + + server.close() # mock the server closing the connection on us + + unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls) + + mqtt_client_mock.loop_misc.return_value = paho_mqtt.MQTT_ERR_CONN_LOST + mqtt_client_mock.on_socket_unregister_write(mqtt_client_mock, None, client) + mqtt_client_mock.on_socket_close(mqtt_client_mock, None, client) + mqtt_client_mock.on_disconnect(mqtt_client_mock, None, client) + await hass.async_block_till_done() + unsub() + + # Should have failed + assert len(calls) == 0 + + +@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) +async def test_client_sock_failure_after_connect( + hass: HomeAssistant, + mqtt_client_mock: MqttMockPahoClient, + mqtt_mock_entry: MqttMockHAClientGenerator, + calls: list[ReceiveMessage], + record_calls: MessageCallbackType, +) -> None: + """Test handling the socket connected and disconnected.""" + mqtt_mock = await mqtt_mock_entry() + # Fake that the client is connected + mqtt_mock().connected = True + await hass.async_block_till_done() + assert mqtt_mock.connected is True + + mqtt_client_mock.loop_misc.return_value = paho_mqtt.MQTT_ERR_SUCCESS + + client, server = socket.socketpair( + family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0 + ) + client.setblocking(False) + server.setblocking(False) + mqtt_client_mock.on_socket_open(mqtt_client_mock, None, client) + mqtt_client_mock.on_socket_register_writer(mqtt_client_mock, None, client) + await hass.async_block_till_done() + + mqtt_client_mock.loop_write.side_effect = OSError("foo") + client.close() # close the client socket out from under the client + + assert mqtt_mock.connected is True + unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls) + async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) + await hass.async_block_till_done() + + unsub() + # Should have failed + assert len(calls) == 0 + + +@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) +async def test_loop_write_failure( + hass: HomeAssistant, + mqtt_client_mock: MqttMockPahoClient, + mqtt_mock_entry: MqttMockHAClientGenerator, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test handling the socket connected and disconnected.""" + mqtt_mock = await mqtt_mock_entry() + await hass.async_block_till_done() + assert mqtt_mock.connected is True + + mqtt_client_mock.loop_misc.return_value = paho_mqtt.MQTT_ERR_SUCCESS + + client, server = socket.socketpair( + family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0 + ) + client.setblocking(False) + server.setblocking(False) + mqtt_client_mock.on_socket_open(mqtt_client_mock, None, client) + mqtt_client_mock.on_socket_register_write(mqtt_client_mock, None, client) + mqtt_client_mock.loop_write.return_value = paho_mqtt.MQTT_ERR_CONN_LOST + mqtt_client_mock.loop_read.return_value = paho_mqtt.MQTT_ERR_CONN_LOST + + # Fill up the outgoing buffer to ensure that loop_write + # and loop_read are called that next time control is + # returned to the event loop + try: + for _ in range(1000): + server.send(b"long" * 100) + except BlockingIOError: + pass + + server.close() + # Once for the reader callback + await hass.async_block_till_done() + # Another for the writer callback + await hass.async_block_till_done() + # Final for the disconnect callback + await hass.async_block_till_done() + + assert "Disconnected from MQTT server mock-broker:1883 (7)" in caplog.text diff --git a/tests/components/tasmota/test_common.py b/tests/components/tasmota/test_common.py index 360794e280f..499e732719c 100644 --- a/tests/components/tasmota/test_common.py +++ b/tests/components/tasmota/test_common.py @@ -163,7 +163,7 @@ async def help_test_availability_when_connection_lost( # Disconnected from MQTT server -> state changed to unavailable mqtt_mock.connected = False - await hass.async_add_executor_job(mqtt_client_mock.on_disconnect, None, None, 0) + mqtt_client_mock.on_disconnect(None, None, 0) await hass.async_block_till_done() await hass.async_block_till_done() await hass.async_block_till_done() @@ -172,7 +172,7 @@ async def help_test_availability_when_connection_lost( # Reconnected to MQTT server -> state still unavailable mqtt_mock.connected = True - await hass.async_add_executor_job(mqtt_client_mock.on_connect, None, None, None, 0) + mqtt_client_mock.on_connect(None, None, None, 0) await hass.async_block_till_done() await hass.async_block_till_done() await hass.async_block_till_done() @@ -224,7 +224,7 @@ async def help_test_deep_sleep_availability_when_connection_lost( # Disconnected from MQTT server -> state changed to unavailable mqtt_mock.connected = False - await hass.async_add_executor_job(mqtt_client_mock.on_disconnect, None, None, 0) + mqtt_client_mock.on_disconnect(None, None, 0) await hass.async_block_till_done() await hass.async_block_till_done() await hass.async_block_till_done() @@ -233,7 +233,7 @@ async def help_test_deep_sleep_availability_when_connection_lost( # Reconnected to MQTT server -> state no longer unavailable mqtt_mock.connected = True - await hass.async_add_executor_job(mqtt_client_mock.on_connect, None, None, None, 0) + mqtt_client_mock.on_connect(None, None, None, 0) await hass.async_block_till_done() await hass.async_block_till_done() await hass.async_block_till_done() @@ -476,7 +476,7 @@ async def help_test_availability_poll_state( # Disconnected from MQTT server mqtt_mock.connected = False - await hass.async_add_executor_job(mqtt_client_mock.on_disconnect, None, None, 0) + mqtt_client_mock.on_disconnect(None, None, 0) await hass.async_block_till_done() await hass.async_block_till_done() await hass.async_block_till_done() @@ -484,7 +484,7 @@ async def help_test_availability_poll_state( # Reconnected to MQTT server mqtt_mock.connected = True - await hass.async_add_executor_job(mqtt_client_mock.on_connect, None, None, None, 0) + mqtt_client_mock.on_connect(None, None, None, 0) await hass.async_block_till_done() await hass.async_block_till_done() await hass.async_block_till_done() diff --git a/tests/conftest.py b/tests/conftest.py index a38da17f44b..3a95e0e58b3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -904,26 +904,45 @@ def mqtt_client_mock(hass: HomeAssistant) -> Generator[MqttMockPahoClient, None, self.rc = 0 with patch("paho.mqtt.client.Client") as mock_client: + # The below use a call_soon for the on_publish/on_subscribe/on_unsubscribe + # callbacks to simulate the behavior of the real MQTT client which will + # not be synchronous. @ha.callback def _async_fire_mqtt_message(topic, payload, qos, retain): async_fire_mqtt_message(hass, topic, payload, qos, retain) mid = get_mid() - mock_client.on_publish(0, 0, mid) + hass.loop.call_soon(mock_client.on_publish, 0, 0, mid) return FakeInfo(mid) def _subscribe(topic, qos=0): mid = get_mid() - mock_client.on_subscribe(0, 0, mid) + hass.loop.call_soon(mock_client.on_subscribe, 0, 0, mid) return (0, mid) def _unsubscribe(topic): mid = get_mid() - mock_client.on_unsubscribe(0, 0, mid) + hass.loop.call_soon(mock_client.on_unsubscribe, 0, 0, mid) return (0, mid) + def _connect(*args, **kwargs): + # Connect always calls reconnect once, but we + # mock it out so we call reconnect to simulate + # the behavior. + mock_client.reconnect() + hass.loop.call_soon_threadsafe( + mock_client.on_connect, mock_client, None, 0, 0, 0 + ) + mock_client.on_socket_open( + mock_client, None, Mock(fileno=Mock(return_value=-1)) + ) + mock_client.on_socket_register_write( + mock_client, None, Mock(fileno=Mock(return_value=-1)) + ) + return 0 + mock_client = mock_client.return_value - mock_client.connect.return_value = 0 + mock_client.connect.side_effect = _connect mock_client.subscribe.side_effect = _subscribe mock_client.unsubscribe.side_effect = _unsubscribe mock_client.publish.side_effect = _async_fire_mqtt_message @@ -985,6 +1004,7 @@ async def _mqtt_mock_entry( # connected set to True to get a more realistic behavior when subscribing mock_mqtt_instance.connected = True + mqtt_client_mock.on_connect(mqtt_client_mock, None, 0, 0, 0) async_dispatcher_send(hass, mqtt.MQTT_CONNECTED) await hass.async_block_till_done()