diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index 6c7e0934a4e..5eb85b1679c 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -405,8 +405,7 @@ class MQTT: self._cleanup_on_unload: list[Callable[[], None]] = [] self._connection_lock = asyncio.Lock() - self._pending_operations: dict[int, asyncio.Event] = {} - self._pending_operations_condition = asyncio.Condition() + self._pending_operations: dict[int, asyncio.Future[None]] = {} self._subscribe_debouncer = EnsureJobAfterCooldown( INITIAL_SUBSCRIBE_COOLDOWN, self._async_perform_subscriptions ) @@ -679,10 +678,6 @@ class MQTT: async def async_disconnect(self) -> None: """Stop the MQTT client.""" - def no_more_acks() -> bool: - """Return False if there are unprocessed ACKs.""" - return not any(not op.is_set() for op in self._pending_operations.values()) - # stop waiting for any pending subscriptions await self._subscribe_debouncer.async_cleanup() # reset timeout to initial subscribe cooldown @@ -693,8 +688,8 @@ class MQTT: await self._async_perform_unsubscribes() # wait for ACKs to be processed - async with self._pending_operations_condition: - await self._pending_operations_condition.wait_for(no_more_acks) + if pending := self._pending_operations.values(): + await asyncio.wait(pending) # stop the MQTT loop async with self._connection_lock: @@ -1050,24 +1045,21 @@ class MQTT: """Publish / Subscribe / Unsubscribe callback.""" # The callback signature for on_unsubscribe is different from on_subscribe # see https://github.com/eclipse/paho.mqtt.python/issues/687 - # properties and reasoncodes are not used in Home Assistant - self.config_entry.async_create_task( - self.hass, self._mqtt_handle_mid(mid), name=f"mqtt handle mid {mid}" - ) + # properties and reason codes are not used in Home Assistant + future = self._async_get_mid_future(mid) + if future.done(): + _LOGGER.warning("Received duplicate mid: %s", mid) + return + future.set_result(None) - async def _mqtt_handle_mid(self, mid: int) -> None: - # Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid - # may be executed first. - async with self._pending_operations_condition: - if mid not in self._pending_operations: - self._pending_operations[mid] = asyncio.Event() - self._pending_operations[mid].set() - - async def _register_mid(self, mid: int) -> None: - """Create Event for an expected ACK.""" - async with self._pending_operations_condition: - if mid not in self._pending_operations: - self._pending_operations[mid] = asyncio.Event() + @callback + def _async_get_mid_future(self, mid: int) -> asyncio.Future[None]: + """Get the future for a mid.""" + if future := self._pending_operations.get(mid): + return future + future = self.hass.loop.create_future() + self._pending_operations[mid] = future + return future @callback def _async_mqtt_on_disconnect( @@ -1098,23 +1090,28 @@ class MQTT: result_code, ) + @callback + def _async_timeout_mid(self, future: asyncio.Future[None]) -> None: + """Timeout waiting for a mid.""" + if not future.done(): + future.set_exception(asyncio.TimeoutError) + async def _wait_for_mid(self, mid: int) -> None: """Wait for ACK from broker.""" # Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid # may be executed first. - await self._register_mid(mid) + future = self._async_get_mid_future(mid) + loop = self.hass.loop + timer_handle = loop.call_later(TIMEOUT_ACK, self._async_timeout_mid, future) try: - async with asyncio.timeout(TIMEOUT_ACK): - await self._pending_operations[mid].wait() + await future except TimeoutError: _LOGGER.warning( "No ACK from MQTT server in %s seconds (mid: %s)", TIMEOUT_ACK, mid ) finally: - async with self._pending_operations_condition: - # Cleanup ACK sync buffer - del self._pending_operations[mid] - self._pending_operations_condition.notify_all() + timer_handle.cancel() + del self._pending_operations[mid] async def _discovery_cooldown(self) -> None: """Wait until all discovery and subscriptions are processed."""