Simplify MQTT mid handling (#116522)

* Simplify MQTT mid handling

switch from asyncio.Event to asyncio.Future

* preen

* preen

* preen
This commit is contained in:
J. Nick Koston 2024-05-01 11:03:10 -05:00 committed by GitHub
parent 2fe17acaf7
commit 25df41475a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -405,8 +405,7 @@ class MQTT:
self._cleanup_on_unload: list[Callable[[], None]] = []
self._connection_lock = asyncio.Lock()
self._pending_operations: dict[int, asyncio.Event] = {}
self._pending_operations_condition = asyncio.Condition()
self._pending_operations: dict[int, asyncio.Future[None]] = {}
self._subscribe_debouncer = EnsureJobAfterCooldown(
INITIAL_SUBSCRIBE_COOLDOWN, self._async_perform_subscriptions
)
@ -679,10 +678,6 @@ class MQTT:
async def async_disconnect(self) -> None:
"""Stop the MQTT client."""
def no_more_acks() -> bool:
"""Return False if there are unprocessed ACKs."""
return not any(not op.is_set() for op in self._pending_operations.values())
# stop waiting for any pending subscriptions
await self._subscribe_debouncer.async_cleanup()
# reset timeout to initial subscribe cooldown
@ -693,8 +688,8 @@ class MQTT:
await self._async_perform_unsubscribes()
# wait for ACKs to be processed
async with self._pending_operations_condition:
await self._pending_operations_condition.wait_for(no_more_acks)
if pending := self._pending_operations.values():
await asyncio.wait(pending)
# stop the MQTT loop
async with self._connection_lock:
@ -1051,23 +1046,20 @@ class MQTT:
# The callback signature for on_unsubscribe is different from on_subscribe
# see https://github.com/eclipse/paho.mqtt.python/issues/687
# properties and reason codes are not used in Home Assistant
self.config_entry.async_create_task(
self.hass, self._mqtt_handle_mid(mid), name=f"mqtt handle mid {mid}"
)
future = self._async_get_mid_future(mid)
if future.done():
_LOGGER.warning("Received duplicate mid: %s", mid)
return
future.set_result(None)
async def _mqtt_handle_mid(self, mid: int) -> None:
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid
# may be executed first.
async with self._pending_operations_condition:
if mid not in self._pending_operations:
self._pending_operations[mid] = asyncio.Event()
self._pending_operations[mid].set()
async def _register_mid(self, mid: int) -> None:
"""Create Event for an expected ACK."""
async with self._pending_operations_condition:
if mid not in self._pending_operations:
self._pending_operations[mid] = asyncio.Event()
@callback
def _async_get_mid_future(self, mid: int) -> asyncio.Future[None]:
"""Get the future for a mid."""
if future := self._pending_operations.get(mid):
return future
future = self.hass.loop.create_future()
self._pending_operations[mid] = future
return future
@callback
def _async_mqtt_on_disconnect(
@ -1098,23 +1090,28 @@ class MQTT:
result_code,
)
@callback
def _async_timeout_mid(self, future: asyncio.Future[None]) -> None:
"""Timeout waiting for a mid."""
if not future.done():
future.set_exception(asyncio.TimeoutError)
async def _wait_for_mid(self, mid: int) -> None:
"""Wait for ACK from broker."""
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid
# may be executed first.
await self._register_mid(mid)
future = self._async_get_mid_future(mid)
loop = self.hass.loop
timer_handle = loop.call_later(TIMEOUT_ACK, self._async_timeout_mid, future)
try:
async with asyncio.timeout(TIMEOUT_ACK):
await self._pending_operations[mid].wait()
await future
except TimeoutError:
_LOGGER.warning(
"No ACK from MQTT server in %s seconds (mid: %s)", TIMEOUT_ACK, mid
)
finally:
async with self._pending_operations_condition:
# Cleanup ACK sync buffer
timer_handle.cancel()
del self._pending_operations[mid]
self._pending_operations_condition.notify_all()
async def _discovery_cooldown(self) -> None:
"""Wait until all discovery and subscriptions are processed."""