Handle cancellation in ServiceRegistry.async_call (#33644)

This commit is contained in:
Phil Bruckner 2020-04-04 17:36:33 -05:00 committed by GitHub
parent d7e9959442
commit bf1b408038
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 83 additions and 18 deletions

View File

@ -28,6 +28,7 @@ from typing import (
Optional, Optional,
Set, Set,
TypeVar, TypeVar,
Union,
) )
import uuid import uuid
@ -1224,29 +1225,57 @@ class ServiceRegistry:
context=context, context=context,
) )
coro = self._execute_service(handler, service_call)
if not blocking: if not blocking:
self._hass.async_create_task(self._safe_execute(handler, service_call)) self._run_service_in_background(coro, service_call)
return None return None
task = self._hass.async_create_task(coro)
try: try:
async with timeout(limit): await asyncio.wait({task}, timeout=limit)
await asyncio.shield(self._execute_service(handler, service_call)) except asyncio.CancelledError:
return True # Task calling us was cancelled, so cancel service call task, and wait for
except asyncio.TimeoutError: # it to be cancelled, within reason, before leaving.
return False _LOGGER.debug("Service call was cancelled: %s", service_call)
task.cancel()
await asyncio.wait({task}, timeout=SERVICE_CALL_LIMIT)
raise
async def _safe_execute(self, handler: Service, service_call: ServiceCall) -> None: if task.cancelled():
"""Execute a service and catch exceptions.""" # Service call task was cancelled some other way, such as during shutdown.
try: _LOGGER.debug("Service was cancelled: %s", service_call)
await self._execute_service(handler, service_call) raise asyncio.CancelledError
except Unauthorized: if task.done():
_LOGGER.warning( # Propagate any exceptions that might have happened during service call.
"Unauthorized service called %s/%s", task.result()
service_call.domain, # Service call completed successfully!
service_call.service, return True
) # Service call task did not complete before timeout expired.
except Exception: # pylint: disable=broad-except # Let it keep running in background.
_LOGGER.exception("Error executing service %s", service_call) self._run_service_in_background(task, service_call)
_LOGGER.debug("Service did not complete before timeout: %s", service_call)
return False
def _run_service_in_background(
self, coro_or_task: Union[Coroutine, asyncio.Task], service_call: ServiceCall
) -> None:
"""Run service call in background, catching and logging any exceptions."""
async def catch_exceptions() -> None:
try:
await coro_or_task
except Unauthorized:
_LOGGER.warning(
"Unauthorized service called %s/%s",
service_call.domain,
service_call.service,
)
except asyncio.CancelledError:
_LOGGER.debug("Service was cancelled: %s", service_call)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error executing service: %s", service_call)
self._hass.async_create_task(catch_exceptions())
async def _execute_service( async def _execute_service(
self, handler: Service, service_call: ServiceCall self, handler: Service, service_call: ServiceCall

View File

@ -1214,6 +1214,42 @@ async def test_async_functions_with_callback(hass):
assert len(runs) == 3 assert len(runs) == 3
@pytest.mark.parametrize("cancel_call", [True, False])
async def test_cancel_service_task(hass, cancel_call):
"""Test cancellation."""
service_called = asyncio.Event()
service_cancelled = False
async def service_handler(call):
nonlocal service_cancelled
service_called.set()
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
service_cancelled = True
raise
hass.services.async_register("test_domain", "test_service", service_handler)
call_task = hass.async_create_task(
hass.services.async_call("test_domain", "test_service", blocking=True)
)
tasks_1 = asyncio.all_tasks()
await asyncio.wait_for(service_called.wait(), timeout=1)
tasks_2 = asyncio.all_tasks() - tasks_1
assert len(tasks_2) == 1
service_task = tasks_2.pop()
if cancel_call:
call_task.cancel()
else:
service_task.cancel()
with pytest.raises(asyncio.CancelledError):
await call_task
assert service_cancelled
def test_valid_entity_id(): def test_valid_entity_id():
"""Test valid entity ID.""" """Test valid entity ID."""
for invalid in [ for invalid in [