From 3031e4733b7c4903c70d0fab4edf3e510e8013f2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 24 May 2024 14:33:21 -1000 Subject: [PATCH] Reduce duplicate code to handle mqtt message replies (#118067) --- homeassistant/components/mqtt/client.py | 42 +++++++++++-------------- 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index 70582f5c107..5b38838ae39 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -681,8 +681,7 @@ class MQTT: msg_info.mid, qos, ) - _raise_on_error(msg_info.rc) - await self._async_wait_for_mid(msg_info.mid) + await self._async_wait_for_mid_or_raise(msg_info.mid, msg_info.rc) async def async_connect(self, client_available: asyncio.Future[bool]) -> None: """Connect to the host. Does not process messages yet.""" @@ -930,21 +929,19 @@ class MQTT: self._pending_subscriptions = {} subscription_list = list(subscriptions.items()) + debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) for chunk in chunked_or_all(subscription_list, MAX_SUBSCRIBES_PER_CALL): result, mid = self._mqttc.subscribe(chunk) - if _LOGGER.isEnabledFor(logging.DEBUG): + if debug_enabled: for topic, qos in subscriptions.items(): _LOGGER.debug( "Subscribing to %s, mid: %s, qos: %s", topic, mid, qos ) self._last_subscribe = time.monotonic() - if result == 0: - await self._async_wait_for_mid(mid) - else: - _raise_on_error(result) + await self._async_wait_for_mid_or_raise(mid, result) async def _async_perform_unsubscribes(self) -> None: """Perform pending MQTT client unsubscribes.""" @@ -953,15 +950,15 @@ class MQTT: topics = list(self._pending_unsubscribes) self._pending_unsubscribes = set() + debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) for chunk in chunked_or_all(topics, MAX_UNSUBSCRIBES_PER_CALL): result, mid = self._mqttc.unsubscribe(chunk) - _raise_on_error(result) - if _LOGGER.isEnabledFor(logging.DEBUG): + if debug_enabled: for topic in chunk: _LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid) - await self._async_wait_for_mid(mid) + await self._async_wait_for_mid_or_raise(mid, result) async def _async_resubscribe_and_publish_birth_message( self, birth_message: PublishMessage @@ -1225,10 +1222,18 @@ class MQTT: if not future.done(): future.set_exception(asyncio.TimeoutError) - async def _async_wait_for_mid(self, mid: int) -> None: - """Wait for ACK from broker.""" - # Create the mid event if not created, either _mqtt_handle_mid or _async_wait_for_mid - # may be executed first. + async def _async_wait_for_mid_or_raise(self, mid: int, result_code: int) -> None: + """Wait for ACK from broker or raise on error.""" + if result_code != 0: + # pylint: disable-next=import-outside-toplevel + import paho.mqtt.client as mqtt + + raise HomeAssistantError( + f"Error talking to MQTT: {mqtt.error_string(result_code)}" + ) + + # Create the mid event if not created, either _mqtt_handle_mid or + # _async_wait_for_mid_or_raise may be executed first. future = self._async_get_mid_future(mid) loop = self.hass.loop timer_handle = loop.call_later(TIMEOUT_ACK, self._async_timeout_mid, future) @@ -1266,15 +1271,6 @@ class MQTT: ) -def _raise_on_error(result_code: int) -> None: - """Raise error if error result.""" - # pylint: disable-next=import-outside-toplevel - import paho.mqtt.client as mqtt - - if result_code and (message := mqtt.error_string(result_code)): - raise HomeAssistantError(f"Error talking to MQTT: {message}") - - def _matcher_for_topic(subscription: str) -> Callable[[str], bool]: # pylint: disable-next=import-outside-toplevel from paho.mqtt.matcher import MQTTMatcher