Small cleanups to service calls (#95873)

This commit is contained in:
J. Nick Koston 2023-07-05 02:25:38 -05:00 committed by GitHub
parent 9109b5fead
commit 91f334ca59
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1693,7 +1693,7 @@ class Service:
class ServiceCall: class ServiceCall:
"""Representation of a call to a service.""" """Representation of a call to a service."""
__slots__ = ["domain", "service", "data", "context", "return_response"] __slots__ = ("domain", "service", "data", "context", "return_response")
def __init__( def __init__(
self, self,
@ -1704,8 +1704,8 @@ class ServiceCall:
return_response: bool = False, return_response: bool = False,
) -> None: ) -> None:
"""Initialize a service call.""" """Initialize a service call."""
self.domain = domain.lower() self.domain = domain
self.service = service.lower() self.service = service
self.data = ReadOnlyDict(data or {}) self.data = ReadOnlyDict(data or {})
self.context = context or Context() self.context = context or Context()
self.return_response = return_response self.return_response = return_response
@ -1890,15 +1890,20 @@ class ServiceRegistry:
This method is a coroutine. This method is a coroutine.
""" """
domain = domain.lower()
service = service.lower()
context = context or Context() context = context or Context()
service_data = service_data or {} service_data = service_data or {}
try: try:
handler = self._services[domain][service] handler = self._services[domain][service]
except KeyError: except KeyError:
raise ServiceNotFound(domain, service) from None # Almost all calls are already lower case, so we avoid
# calling lower() on the arguments in the common case.
domain = domain.lower()
service = service.lower()
try:
handler = self._services[domain][service]
except KeyError:
raise ServiceNotFound(domain, service) from None
if return_response: if return_response:
if not blocking: if not blocking:
@ -1938,8 +1943,8 @@ class ServiceRegistry:
self._hass.bus.async_fire( self._hass.bus.async_fire(
EVENT_CALL_SERVICE, EVENT_CALL_SERVICE,
{ {
ATTR_DOMAIN: domain.lower(), ATTR_DOMAIN: domain,
ATTR_SERVICE: service.lower(), ATTR_SERVICE: service,
ATTR_SERVICE_DATA: service_data, ATTR_SERVICE_DATA: service_data,
}, },
context=context, context=context,
@ -1947,7 +1952,10 @@ class ServiceRegistry:
coro = self._execute_service(handler, service_call) coro = self._execute_service(handler, service_call)
if not blocking: if not blocking:
self._run_service_in_background(coro, service_call) self._hass.async_create_task(
self._run_service_call_catch_exceptions(coro, service_call),
f"service call background {service_call.domain}.{service_call.service}",
)
return None return None
response_data = await coro response_data = await coro
@ -1959,49 +1967,42 @@ class ServiceRegistry:
) )
return response_data return response_data
def _run_service_in_background( async def _run_service_call_catch_exceptions(
self, self,
coro_or_task: Coroutine[Any, Any, Any] | asyncio.Task[Any], coro_or_task: Coroutine[Any, Any, Any] | asyncio.Task[Any],
service_call: ServiceCall, service_call: ServiceCall,
) -> None: ) -> None:
"""Run service call in background, catching and logging any exceptions.""" """Run service call in background, catching and logging any exceptions."""
try:
async def catch_exceptions() -> None: await coro_or_task
try: except Unauthorized:
await coro_or_task _LOGGER.warning(
except Unauthorized: "Unauthorized service called %s/%s",
_LOGGER.warning( service_call.domain,
"Unauthorized service called %s/%s", service_call.service,
service_call.domain, )
service_call.service, except asyncio.CancelledError:
) _LOGGER.debug("Service was cancelled: %s", service_call)
except asyncio.CancelledError: except Exception: # pylint: disable=broad-except
_LOGGER.debug("Service was cancelled: %s", service_call) _LOGGER.exception("Error executing service: %s", service_call)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error executing service: %s", service_call)
self._hass.async_create_task(
catch_exceptions(),
f"service call background {service_call.domain}.{service_call.service}",
)
async def _execute_service( async def _execute_service(
self, handler: Service, service_call: ServiceCall self, handler: Service, service_call: ServiceCall
) -> ServiceResponse: ) -> ServiceResponse:
"""Execute a service.""" """Execute a service."""
if handler.job.job_type == HassJobType.Coroutinefunction: job = handler.job
return await cast( target = job.target
Callable[[ServiceCall], Awaitable[ServiceResponse]], if job.job_type == HassJobType.Coroutinefunction:
handler.job.target, if TYPE_CHECKING:
)(service_call) target = cast(Callable[..., Coroutine[Any, Any, _R]], target)
if handler.job.job_type == HassJobType.Callback: return await target(service_call)
return cast(Callable[[ServiceCall], ServiceResponse], handler.job.target)( if job.job_type == HassJobType.Callback:
service_call if TYPE_CHECKING:
) target = cast(Callable[..., _R], target)
return await self._hass.async_add_executor_job( return target(service_call)
cast(Callable[[ServiceCall], ServiceResponse], handler.job.target), if TYPE_CHECKING:
service_call, target = cast(Callable[..., _R], target)
) return await self._hass.async_add_executor_job(target, service_call)
class Config: class Config: