mirror of
https://github.com/home-assistant/core.git
synced 2025-07-11 07:17:12 +00:00
Handle cancellation in ServiceRegistry.async_call (#33644)
This commit is contained in:
parent
d7e9959442
commit
bf1b408038
@ -28,6 +28,7 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
|
Union,
|
||||||
)
|
)
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
@ -1224,29 +1225,57 @@ class ServiceRegistry:
|
|||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
coro = self._execute_service(handler, service_call)
|
||||||
if not blocking:
|
if not blocking:
|
||||||
self._hass.async_create_task(self._safe_execute(handler, service_call))
|
self._run_service_in_background(coro, service_call)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
task = self._hass.async_create_task(coro)
|
||||||
try:
|
try:
|
||||||
async with timeout(limit):
|
await asyncio.wait({task}, timeout=limit)
|
||||||
await asyncio.shield(self._execute_service(handler, service_call))
|
except asyncio.CancelledError:
|
||||||
|
# Task calling us was cancelled, so cancel service call task, and wait for
|
||||||
|
# it to be cancelled, within reason, before leaving.
|
||||||
|
_LOGGER.debug("Service call was cancelled: %s", service_call)
|
||||||
|
task.cancel()
|
||||||
|
await asyncio.wait({task}, timeout=SERVICE_CALL_LIMIT)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if task.cancelled():
|
||||||
|
# Service call task was cancelled some other way, such as during shutdown.
|
||||||
|
_LOGGER.debug("Service was cancelled: %s", service_call)
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
if task.done():
|
||||||
|
# Propagate any exceptions that might have happened during service call.
|
||||||
|
task.result()
|
||||||
|
# Service call completed successfully!
|
||||||
return True
|
return True
|
||||||
except asyncio.TimeoutError:
|
# Service call task did not complete before timeout expired.
|
||||||
|
# Let it keep running in background.
|
||||||
|
self._run_service_in_background(task, service_call)
|
||||||
|
_LOGGER.debug("Service did not complete before timeout: %s", service_call)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _safe_execute(self, handler: Service, service_call: ServiceCall) -> None:
|
def _run_service_in_background(
|
||||||
"""Execute a service and catch exceptions."""
|
self, coro_or_task: Union[Coroutine, asyncio.Task], service_call: ServiceCall
|
||||||
|
) -> None:
|
||||||
|
"""Run service call in background, catching and logging any exceptions."""
|
||||||
|
|
||||||
|
async def catch_exceptions() -> None:
|
||||||
try:
|
try:
|
||||||
await self._execute_service(handler, service_call)
|
await coro_or_task
|
||||||
except Unauthorized:
|
except Unauthorized:
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
"Unauthorized service called %s/%s",
|
"Unauthorized service called %s/%s",
|
||||||
service_call.domain,
|
service_call.domain,
|
||||||
service_call.service,
|
service_call.service,
|
||||||
)
|
)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
_LOGGER.debug("Service was cancelled: %s", service_call)
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
_LOGGER.exception("Error executing service %s", service_call)
|
_LOGGER.exception("Error executing service: %s", service_call)
|
||||||
|
|
||||||
|
self._hass.async_create_task(catch_exceptions())
|
||||||
|
|
||||||
async def _execute_service(
|
async def _execute_service(
|
||||||
self, handler: Service, service_call: ServiceCall
|
self, handler: Service, service_call: ServiceCall
|
||||||
|
@ -1214,6 +1214,42 @@ async def test_async_functions_with_callback(hass):
|
|||||||
assert len(runs) == 3
|
assert len(runs) == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("cancel_call", [True, False])
|
||||||
|
async def test_cancel_service_task(hass, cancel_call):
|
||||||
|
"""Test cancellation."""
|
||||||
|
service_called = asyncio.Event()
|
||||||
|
service_cancelled = False
|
||||||
|
|
||||||
|
async def service_handler(call):
|
||||||
|
nonlocal service_cancelled
|
||||||
|
service_called.set()
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
service_cancelled = True
|
||||||
|
raise
|
||||||
|
|
||||||
|
hass.services.async_register("test_domain", "test_service", service_handler)
|
||||||
|
call_task = hass.async_create_task(
|
||||||
|
hass.services.async_call("test_domain", "test_service", blocking=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
tasks_1 = asyncio.all_tasks()
|
||||||
|
await asyncio.wait_for(service_called.wait(), timeout=1)
|
||||||
|
tasks_2 = asyncio.all_tasks() - tasks_1
|
||||||
|
assert len(tasks_2) == 1
|
||||||
|
service_task = tasks_2.pop()
|
||||||
|
|
||||||
|
if cancel_call:
|
||||||
|
call_task.cancel()
|
||||||
|
else:
|
||||||
|
service_task.cancel()
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await call_task
|
||||||
|
|
||||||
|
assert service_cancelled
|
||||||
|
|
||||||
|
|
||||||
def test_valid_entity_id():
|
def test_valid_entity_id():
|
||||||
"""Test valid entity ID."""
|
"""Test valid entity ID."""
|
||||||
for invalid in [
|
for invalid in [
|
||||||
|
Loading…
x
Reference in New Issue
Block a user