Fix ESPHome bluetooth client cancellation when the operation is cancelled externally (#96804)

This commit is contained in:
J. Nick Koston 2023-07-18 03:39:26 -10:00 committed by GitHub
parent 9a8fe04907
commit 6bd4ace3c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
import contextlib import contextlib
from functools import partial
import logging import logging
from typing import Any, TypeVar, cast from typing import Any, TypeVar, cast
import uuid import uuid
@ -56,38 +57,40 @@ def mac_to_int(address: str) -> int:
return int(address.replace(":", ""), 16) return int(address.replace(":", ""), 16)
def _on_disconnected(task: asyncio.Task[Any], _: asyncio.Future[None]) -> None:
if task and not task.done():
task.cancel()
def verify_connected(func: _WrapFuncType) -> _WrapFuncType: def verify_connected(func: _WrapFuncType) -> _WrapFuncType:
"""Define a wrapper throw BleakError if not connected.""" """Define a wrapper throw BleakError if not connected."""
async def _async_wrap_bluetooth_connected_operation( async def _async_wrap_bluetooth_connected_operation(
self: ESPHomeClient, *args: Any, **kwargs: Any self: ESPHomeClient, *args: Any, **kwargs: Any
) -> Any: ) -> Any:
loop = self._loop # pylint: disable=protected-access # pylint: disable=protected-access
disconnected_futures = ( loop = self._loop
self._disconnected_futures # pylint: disable=protected-access disconnected_futures = self._disconnected_futures
)
disconnected_future = loop.create_future() disconnected_future = loop.create_future()
disconnect_handler = partial(_on_disconnected, asyncio.current_task(loop))
disconnected_future.add_done_callback(disconnect_handler)
disconnected_futures.add(disconnected_future) disconnected_futures.add(disconnected_future)
task = asyncio.current_task(loop)
def _on_disconnected(fut: asyncio.Future[None]) -> None:
if task and not task.done():
task.cancel()
disconnected_future.add_done_callback(_on_disconnected)
try: try:
return await func(self, *args, **kwargs) return await func(self, *args, **kwargs)
except asyncio.CancelledError as ex: except asyncio.CancelledError as ex:
source_name = self._source_name # pylint: disable=protected-access if not disconnected_future.done():
ble_device = self._ble_device # pylint: disable=protected-access # If the disconnected future is not done, the task was cancelled
# externally and we need to raise cancelled error to avoid
# blocking the cancellation.
raise
ble_device = self._ble_device
raise BleakError( raise BleakError(
f"{source_name}: {ble_device.name} - {ble_device.address}: " f"{self._source_name }: {ble_device.name} - {ble_device.address}: "
"Disconnected during operation" "Disconnected during operation"
) from ex ) from ex
finally: finally:
disconnected_futures.discard(disconnected_future) disconnected_futures.discard(disconnected_future)
disconnected_future.remove_done_callback(_on_disconnected) disconnected_future.remove_done_callback(disconnect_handler)
return cast(_WrapFuncType, _async_wrap_bluetooth_connected_operation) return cast(_WrapFuncType, _async_wrap_bluetooth_connected_operation)