mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Handle cancellation in ServiceRegistry.async_call (#33644)
This commit is contained in:
parent
d7e9959442
commit
bf1b408038
@ -28,6 +28,7 @@ from typing import (
|
||||
Optional,
|
||||
Set,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
import uuid
|
||||
|
||||
@ -1224,29 +1225,57 @@ class ServiceRegistry:
|
||||
context=context,
|
||||
)
|
||||
|
||||
coro = self._execute_service(handler, service_call)
|
||||
if not blocking:
|
||||
self._hass.async_create_task(self._safe_execute(handler, service_call))
|
||||
self._run_service_in_background(coro, service_call)
|
||||
return None
|
||||
|
||||
task = self._hass.async_create_task(coro)
|
||||
try:
|
||||
async with timeout(limit):
|
||||
await asyncio.shield(self._execute_service(handler, service_call))
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
return False
|
||||
await asyncio.wait({task}, timeout=limit)
|
||||
except asyncio.CancelledError:
|
||||
# Task calling us was cancelled, so cancel service call task, and wait for
|
||||
# it to be cancelled, within reason, before leaving.
|
||||
_LOGGER.debug("Service call was cancelled: %s", service_call)
|
||||
task.cancel()
|
||||
await asyncio.wait({task}, timeout=SERVICE_CALL_LIMIT)
|
||||
raise
|
||||
|
||||
async def _safe_execute(self, handler: Service, service_call: ServiceCall) -> None:
|
||||
"""Execute a service and catch exceptions."""
|
||||
try:
|
||||
await self._execute_service(handler, service_call)
|
||||
except Unauthorized:
|
||||
_LOGGER.warning(
|
||||
"Unauthorized service called %s/%s",
|
||||
service_call.domain,
|
||||
service_call.service,
|
||||
)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Error executing service %s", service_call)
|
||||
if task.cancelled():
|
||||
# Service call task was cancelled some other way, such as during shutdown.
|
||||
_LOGGER.debug("Service was cancelled: %s", service_call)
|
||||
raise asyncio.CancelledError
|
||||
if task.done():
|
||||
# Propagate any exceptions that might have happened during service call.
|
||||
task.result()
|
||||
# Service call completed successfully!
|
||||
return True
|
||||
# Service call task did not complete before timeout expired.
|
||||
# Let it keep running in background.
|
||||
self._run_service_in_background(task, service_call)
|
||||
_LOGGER.debug("Service did not complete before timeout: %s", service_call)
|
||||
return False
|
||||
|
||||
def _run_service_in_background(
|
||||
self, coro_or_task: Union[Coroutine, asyncio.Task], service_call: ServiceCall
|
||||
) -> None:
|
||||
"""Run service call in background, catching and logging any exceptions."""
|
||||
|
||||
async def catch_exceptions() -> None:
|
||||
try:
|
||||
await coro_or_task
|
||||
except Unauthorized:
|
||||
_LOGGER.warning(
|
||||
"Unauthorized service called %s/%s",
|
||||
service_call.domain,
|
||||
service_call.service,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
_LOGGER.debug("Service was cancelled: %s", service_call)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Error executing service: %s", service_call)
|
||||
|
||||
self._hass.async_create_task(catch_exceptions())
|
||||
|
||||
async def _execute_service(
|
||||
self, handler: Service, service_call: ServiceCall
|
||||
|
@ -1214,6 +1214,42 @@ async def test_async_functions_with_callback(hass):
|
||||
assert len(runs) == 3
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cancel_call", [True, False])
|
||||
async def test_cancel_service_task(hass, cancel_call):
|
||||
"""Test cancellation."""
|
||||
service_called = asyncio.Event()
|
||||
service_cancelled = False
|
||||
|
||||
async def service_handler(call):
|
||||
nonlocal service_cancelled
|
||||
service_called.set()
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
service_cancelled = True
|
||||
raise
|
||||
|
||||
hass.services.async_register("test_domain", "test_service", service_handler)
|
||||
call_task = hass.async_create_task(
|
||||
hass.services.async_call("test_domain", "test_service", blocking=True)
|
||||
)
|
||||
|
||||
tasks_1 = asyncio.all_tasks()
|
||||
await asyncio.wait_for(service_called.wait(), timeout=1)
|
||||
tasks_2 = asyncio.all_tasks() - tasks_1
|
||||
assert len(tasks_2) == 1
|
||||
service_task = tasks_2.pop()
|
||||
|
||||
if cancel_call:
|
||||
call_task.cancel()
|
||||
else:
|
||||
service_task.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await call_task
|
||||
|
||||
assert service_cancelled
|
||||
|
||||
|
||||
def test_valid_entity_id():
|
||||
"""Test valid entity ID."""
|
||||
for invalid in [
|
||||
|
Loading…
x
Reference in New Issue
Block a user