mirror of
https://github.com/home-assistant/core.git
synced 2025-11-08 18:39:30 +00:00
Handle cancellation in ServiceRegistry.async_call (#33644)
This commit is contained in:
@@ -28,6 +28,7 @@ from typing import (
|
||||
Optional,
|
||||
Set,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
import uuid
|
||||
|
||||
@@ -1224,29 +1225,57 @@ class ServiceRegistry:
|
||||
context=context,
|
||||
)
|
||||
|
||||
coro = self._execute_service(handler, service_call)
|
||||
if not blocking:
|
||||
self._hass.async_create_task(self._safe_execute(handler, service_call))
|
||||
self._run_service_in_background(coro, service_call)
|
||||
return None
|
||||
|
||||
task = self._hass.async_create_task(coro)
|
||||
try:
|
||||
async with timeout(limit):
|
||||
await asyncio.shield(self._execute_service(handler, service_call))
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
return False
|
||||
await asyncio.wait({task}, timeout=limit)
|
||||
except asyncio.CancelledError:
|
||||
# Task calling us was cancelled, so cancel service call task, and wait for
|
||||
# it to be cancelled, within reason, before leaving.
|
||||
_LOGGER.debug("Service call was cancelled: %s", service_call)
|
||||
task.cancel()
|
||||
await asyncio.wait({task}, timeout=SERVICE_CALL_LIMIT)
|
||||
raise
|
||||
|
||||
async def _safe_execute(self, handler: Service, service_call: ServiceCall) -> None:
|
||||
"""Execute a service and catch exceptions."""
|
||||
try:
|
||||
await self._execute_service(handler, service_call)
|
||||
except Unauthorized:
|
||||
_LOGGER.warning(
|
||||
"Unauthorized service called %s/%s",
|
||||
service_call.domain,
|
||||
service_call.service,
|
||||
)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Error executing service %s", service_call)
|
||||
if task.cancelled():
|
||||
# Service call task was cancelled some other way, such as during shutdown.
|
||||
_LOGGER.debug("Service was cancelled: %s", service_call)
|
||||
raise asyncio.CancelledError
|
||||
if task.done():
|
||||
# Propagate any exceptions that might have happened during service call.
|
||||
task.result()
|
||||
# Service call completed successfully!
|
||||
return True
|
||||
# Service call task did not complete before timeout expired.
|
||||
# Let it keep running in background.
|
||||
self._run_service_in_background(task, service_call)
|
||||
_LOGGER.debug("Service did not complete before timeout: %s", service_call)
|
||||
return False
|
||||
|
||||
def _run_service_in_background(
|
||||
self, coro_or_task: Union[Coroutine, asyncio.Task], service_call: ServiceCall
|
||||
) -> None:
|
||||
"""Run service call in background, catching and logging any exceptions."""
|
||||
|
||||
async def catch_exceptions() -> None:
|
||||
try:
|
||||
await coro_or_task
|
||||
except Unauthorized:
|
||||
_LOGGER.warning(
|
||||
"Unauthorized service called %s/%s",
|
||||
service_call.domain,
|
||||
service_call.service,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
_LOGGER.debug("Service was cancelled: %s", service_call)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Error executing service: %s", service_call)
|
||||
|
||||
self._hass.async_create_task(catch_exceptions())
|
||||
|
||||
async def _execute_service(
|
||||
self, handler: Service, service_call: ServiceCall
|
||||
|
||||
Reference in New Issue
Block a user