Simplify MQTT mid handling (#116522)

* Simplify MQTT mid handling

switch from asyncio.Event to asyncio.Future

* preen

* preen

* preen
This commit is contained in:
J. Nick Koston 2024-05-01 11:03:10 -05:00 committed by GitHub
parent 2fe17acaf7
commit 25df41475a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -405,8 +405,7 @@ class MQTT:
self._cleanup_on_unload: list[Callable[[], None]] = [] self._cleanup_on_unload: list[Callable[[], None]] = []
self._connection_lock = asyncio.Lock() self._connection_lock = asyncio.Lock()
self._pending_operations: dict[int, asyncio.Event] = {} self._pending_operations: dict[int, asyncio.Future[None]] = {}
self._pending_operations_condition = asyncio.Condition()
self._subscribe_debouncer = EnsureJobAfterCooldown( self._subscribe_debouncer = EnsureJobAfterCooldown(
INITIAL_SUBSCRIBE_COOLDOWN, self._async_perform_subscriptions INITIAL_SUBSCRIBE_COOLDOWN, self._async_perform_subscriptions
) )
@ -679,10 +678,6 @@ class MQTT:
async def async_disconnect(self) -> None: async def async_disconnect(self) -> None:
"""Stop the MQTT client.""" """Stop the MQTT client."""
def no_more_acks() -> bool:
"""Return False if there are unprocessed ACKs."""
return not any(not op.is_set() for op in self._pending_operations.values())
# stop waiting for any pending subscriptions # stop waiting for any pending subscriptions
await self._subscribe_debouncer.async_cleanup() await self._subscribe_debouncer.async_cleanup()
# reset timeout to initial subscribe cooldown # reset timeout to initial subscribe cooldown
@ -693,8 +688,8 @@ class MQTT:
await self._async_perform_unsubscribes() await self._async_perform_unsubscribes()
# wait for ACKs to be processed # wait for ACKs to be processed
async with self._pending_operations_condition: if pending := self._pending_operations.values():
await self._pending_operations_condition.wait_for(no_more_acks) await asyncio.wait(pending)
# stop the MQTT loop # stop the MQTT loop
async with self._connection_lock: async with self._connection_lock:
@ -1051,23 +1046,20 @@ class MQTT:
# The callback signature for on_unsubscribe is different from on_subscribe # The callback signature for on_unsubscribe is different from on_subscribe
# see https://github.com/eclipse/paho.mqtt.python/issues/687 # see https://github.com/eclipse/paho.mqtt.python/issues/687
# properties and reason codes are not used in Home Assistant # properties and reason codes are not used in Home Assistant
self.config_entry.async_create_task( future = self._async_get_mid_future(mid)
self.hass, self._mqtt_handle_mid(mid), name=f"mqtt handle mid {mid}" if future.done():
) _LOGGER.warning("Received duplicate mid: %s", mid)
return
future.set_result(None)
async def _mqtt_handle_mid(self, mid: int) -> None: @callback
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid def _async_get_mid_future(self, mid: int) -> asyncio.Future[None]:
# may be executed first. """Get the future for a mid."""
async with self._pending_operations_condition: if future := self._pending_operations.get(mid):
if mid not in self._pending_operations: return future
self._pending_operations[mid] = asyncio.Event() future = self.hass.loop.create_future()
self._pending_operations[mid].set() self._pending_operations[mid] = future
return future
async def _register_mid(self, mid: int) -> None:
"""Create Event for an expected ACK."""
async with self._pending_operations_condition:
if mid not in self._pending_operations:
self._pending_operations[mid] = asyncio.Event()
@callback @callback
def _async_mqtt_on_disconnect( def _async_mqtt_on_disconnect(
@ -1098,23 +1090,28 @@ class MQTT:
result_code, result_code,
) )
@callback
def _async_timeout_mid(self, future: asyncio.Future[None]) -> None:
"""Timeout waiting for a mid."""
if not future.done():
future.set_exception(asyncio.TimeoutError)
async def _wait_for_mid(self, mid: int) -> None: async def _wait_for_mid(self, mid: int) -> None:
"""Wait for ACK from broker.""" """Wait for ACK from broker."""
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid # Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid
# may be executed first. # may be executed first.
await self._register_mid(mid) future = self._async_get_mid_future(mid)
loop = self.hass.loop
timer_handle = loop.call_later(TIMEOUT_ACK, self._async_timeout_mid, future)
try: try:
async with asyncio.timeout(TIMEOUT_ACK): await future
await self._pending_operations[mid].wait()
except TimeoutError: except TimeoutError:
_LOGGER.warning( _LOGGER.warning(
"No ACK from MQTT server in %s seconds (mid: %s)", TIMEOUT_ACK, mid "No ACK from MQTT server in %s seconds (mid: %s)", TIMEOUT_ACK, mid
) )
finally: finally:
async with self._pending_operations_condition: timer_handle.cancel()
# Cleanup ACK sync buffer
del self._pending_operations[mid] del self._pending_operations[mid]
self._pending_operations_condition.notify_all()
async def _discovery_cooldown(self) -> None: async def _discovery_cooldown(self) -> None:
"""Wait until all discovery and subscriptions are processed.""" """Wait until all discovery and subscriptions are processed."""