Improve service response data APIs (#94819)

* Improve service response data APIs

Make the API naming more consistent, and require registration that a
service supports response data so that we can better integrate with
the UI and avoid user confusion with better error messages.

* Improve test coverage

* Add an enum for registering response values

* Assign enum values

* Convert SupportsResponse to StrEnum

* Update service call test docstrings

* Add tiny missing full stop in comment

---------

Co-authored-by: Franck Nijhof <frenck@frenck.nl>
This commit is contained in:
Allen Porter 2023-06-20 06:24:31 -07:00 committed by GitHub
parent 4a8adae146
commit 30e8f806c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 147 additions and 78 deletions

View File

@ -131,7 +131,7 @@ DOMAIN = "homeassistant"
# How long to wait to log tasks that are blocking
BLOCK_LOG_TIMEOUT = 60
ServiceResult = JsonObjectType | None
ServiceResponse = JsonObjectType | None
class ConfigSource(StrEnum):
@ -1655,28 +1655,43 @@ class StateMachine:
)
class SupportsResponse(StrEnum):
"""Service call response configuration."""
NONE = "none"
"""The service does not support responses (the default)."""
OPTIONAL = "optional"
"""The service optionally returns response data when asked by the caller."""
ONLY = "only"
"""The service is read-only and the caller must always ask for response data."""
class Service:
"""Representation of a callable service."""
__slots__ = ["job", "schema", "domain", "service"]
__slots__ = ["job", "schema", "domain", "service", "supports_response"]
def __init__(
self,
func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResult] | None],
func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResponse] | None],
schema: vol.Schema | None,
domain: str,
service: str,
context: Context | None = None,
supports_response: SupportsResponse = SupportsResponse.NONE,
) -> None:
"""Initialize a service."""
self.job = HassJob(func, f"service {domain}.{service}")
self.schema = schema
self.supports_response = supports_response
class ServiceCall:
"""Representation of a call to a service."""
__slots__ = ["domain", "service", "data", "context", "return_values"]
__slots__ = ["domain", "service", "data", "context", "return_response"]
def __init__(
self,
@ -1684,14 +1699,14 @@ class ServiceCall:
service: str,
data: dict[str, Any] | None = None,
context: Context | None = None,
return_values: bool = False,
return_response: bool = False,
) -> None:
"""Initialize a service call."""
self.domain = domain.lower()
self.service = service.lower()
self.data = ReadOnlyDict(data or {})
self.context = context or Context()
self.return_values = return_values
self.return_response = return_response
def __repr__(self) -> str:
"""Return the representation of the service."""
@ -1738,7 +1753,7 @@ class ServiceRegistry:
service: str,
service_func: Callable[
[ServiceCall],
Coroutine[Any, Any, ServiceResult] | None,
Coroutine[Any, Any, ServiceResponse] | None,
],
schema: vol.Schema | None = None,
) -> None:
@ -1756,9 +1771,10 @@ class ServiceRegistry:
domain: str,
service: str,
service_func: Callable[
[ServiceCall], Coroutine[Any, Any, ServiceResult] | None
[ServiceCall], Coroutine[Any, Any, ServiceResponse] | None
],
schema: vol.Schema | None = None,
supports_response: SupportsResponse = SupportsResponse.NONE,
) -> None:
"""Register a service.
@ -1768,7 +1784,9 @@ class ServiceRegistry:
"""
domain = domain.lower()
service = service.lower()
service_obj = Service(service_func, schema, domain, service)
service_obj = Service(
service_func, schema, domain, service, supports_response=supports_response
)
if domain in self._services:
self._services[domain][service] = service_obj
@ -1815,8 +1833,8 @@ class ServiceRegistry:
blocking: bool = False,
context: Context | None = None,
target: dict[str, Any] | None = None,
return_values: bool = False,
) -> ServiceResult:
return_response: bool = False,
) -> ServiceResponse:
"""Call a service.
See description of async_call for details.
@ -1829,7 +1847,7 @@ class ServiceRegistry:
blocking,
context,
target,
return_values,
return_response,
),
self._hass.loop,
).result()
@ -1842,13 +1860,13 @@ class ServiceRegistry:
blocking: bool = False,
context: Context | None = None,
target: dict[str, Any] | None = None,
return_values: bool = False,
) -> ServiceResult:
return_response: bool = False,
) -> ServiceResponse:
"""Call a service.
Specify blocking=True to wait until service is executed.
If return_values=True, indicates that the caller can consume return values
If return_response=True, indicates that the caller can consume return values
from the service, if any. Return values are a dict that can be returned by the
standard JSON serialization process. Return values can only be used with blocking=True.
@ -1864,14 +1882,25 @@ class ServiceRegistry:
context = context or Context()
service_data = service_data or {}
if return_values and not blocking:
raise ValueError("Invalid argument return_values=True when blocking=False")
try:
handler = self._services[domain][service]
except KeyError:
raise ServiceNotFound(domain, service) from None
if return_response:
if not blocking:
raise ValueError(
"Invalid argument return_response=True when blocking=False"
)
if handler.supports_response == SupportsResponse.NONE:
raise ValueError(
"Invalid argument return_response=True when handler does not support responses"
)
elif handler.supports_response == SupportsResponse.ONLY:
raise ValueError(
"Service call requires responses but caller did not ask for responses"
)
if target:
service_data.update(target)
@ -1890,7 +1919,7 @@ class ServiceRegistry:
processed_data = service_data
service_call = ServiceCall(
domain, service, processed_data, context, return_values
domain, service, processed_data, context, return_response
)
self._hass.bus.async_fire(
@ -1909,7 +1938,7 @@ class ServiceRegistry:
return None
response_data = await coro
if not return_values:
if not return_response:
return None
if not isinstance(response_data, dict):
raise HomeAssistantError(
@ -1945,19 +1974,19 @@ class ServiceRegistry:
async def _execute_service(
self, handler: Service, service_call: ServiceCall
) -> ServiceResult:
) -> ServiceResponse:
"""Execute a service."""
if handler.job.job_type == HassJobType.Coroutinefunction:
return await cast(
Callable[[ServiceCall], Awaitable[ServiceResult]],
Callable[[ServiceCall], Awaitable[ServiceResponse]],
handler.job.target,
)(service_call)
if handler.job.job_type == HassJobType.Callback:
return cast(Callable[[ServiceCall], ServiceResult], handler.job.target)(
return cast(Callable[[ServiceCall], ServiceResponse], handler.job.target)(
service_call
)
return await self._hass.async_add_executor_job(
cast(Callable[[ServiceCall], ServiceResult], handler.job.target),
cast(Callable[[ServiceCall], ServiceResponse], handler.job.target),
service_call,
)

View File

@ -674,7 +674,7 @@ class _ScriptRun:
**params,
blocking=True,
context=self._context,
return_values=(response_variable is not None),
return_response=(response_variable is not None),
)
),
)

View File

@ -27,7 +27,7 @@ from homeassistant.core import (
CoreState,
HomeAssistant,
ServiceCall,
ServiceResult,
ServiceResponse,
callback,
)
from homeassistant.exceptions import ConditionError, HomeAssistantError, ServiceNotFound
@ -330,19 +330,19 @@ async def test_calling_service_template(hass: HomeAssistant) -> None:
)
async def test_calling_service_return_values(
async def test_calling_service_response_data(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test the calling of a service with return values."""
context = Context()
def mock_service(call: ServiceCall) -> ServiceResult:
def mock_service(call: ServiceCall) -> ServiceResponse:
"""Mock service call."""
if call.return_values:
if call.return_response:
return {"data": "value-12345"}
return None
hass.services.async_register("test", "script", mock_service)
hass.services.async_register("test", "script", mock_service, supports_response=True)
sequence = cv.SCRIPT_SCHEMA(
[
{

View File

@ -33,7 +33,14 @@ from homeassistant.const import (
__version__,
)
import homeassistant.core as ha
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, ServiceResult, State
from homeassistant.core import (
HassJob,
HomeAssistant,
ServiceCall,
ServiceResponse,
State,
SupportsResponse,
)
from homeassistant.exceptions import (
HomeAssistantError,
InvalidEntityFormatError,
@ -1083,58 +1090,44 @@ async def test_serviceregistry_callback_service_raise_exception(
await hass.async_block_till_done()
async def test_serviceregistry_return_values(hass: HomeAssistant) -> None:
"""Test service call for a service that has return values."""
@pytest.mark.parametrize(
"supports_response",
[
SupportsResponse.ONLY,
SupportsResponse.OPTIONAL,
],
)
async def test_serviceregistry_async_return_response(
hass: HomeAssistant, supports_response: SupportsResponse
) -> None:
"""Test service call for a service that returns response data."""
def service_handler(call: ServiceCall) -> ServiceResult:
async def service_handler(call: ServiceCall) -> ServiceResponse:
"""Service handler coroutine."""
assert call.return_values
assert call.return_response
return {"test-reply": "test-value1"}
hass.services.async_register(
"test_domain",
"test_service",
service_handler,
supports_response=supports_response,
)
result = await hass.services.async_call(
"test_domain",
"test_service",
service_data={},
blocking=True,
return_values=True,
return_response=True,
)
await hass.async_block_till_done()
assert result == {"test-reply": "test-value1"}
async def test_serviceregistry_async_return_values(hass: HomeAssistant) -> None:
"""Test service call for an async service that has return values."""
async def service_handler(call: ServiceCall) -> ServiceResult:
"""Service handler coroutine."""
assert call.return_values
return {"test-reply": "test-value1"}
hass.services.async_register(
"test_domain",
"test_service",
service_handler,
)
result = await hass.services.async_call(
"test_domain",
"test_service",
service_data={},
blocking=True,
return_values=True,
)
await hass.async_block_till_done()
assert result == {"test-reply": "test-value1"}
async def test_services_call_return_values_requires_blocking(
async def test_services_call_return_response_requires_blocking(
hass: HomeAssistant,
) -> None:
"""Test that non-blocking service calls cannot return values."""
"""Test that non-blocking service calls cannot ask for response data."""
async_mock_service(hass, "test_domain", "test_service")
with pytest.raises(ValueError, match="when blocking=False"):
await hass.services.async_call(
@ -1142,12 +1135,12 @@ async def test_services_call_return_values_requires_blocking(
"test_service",
service_data={},
blocking=False,
return_values=True,
return_response=True,
)
@pytest.mark.parametrize(
("return_value", "expected_error"),
("response_data", "expected_error"),
[
(True, "expected a dictionary"),
(False, "expected a dictionary"),
@ -1156,20 +1149,21 @@ async def test_services_call_return_values_requires_blocking(
(["some-list"], "expected a dictionary"),
],
)
async def test_serviceregistry_return_values_invalid(
hass: HomeAssistant, return_value: Any, expected_error: str
async def test_serviceregistry_return_response_invalid(
hass: HomeAssistant, response_data: Any, expected_error: str
) -> None:
"""Test service call return values are not returned when there is no result schema."""
"""Test service call response data must be json serializable objects."""
def service_handler(call: ServiceCall) -> ServiceResult:
def service_handler(call: ServiceCall) -> ServiceResponse:
"""Service handler coroutine."""
assert call.return_values
return return_value
assert call.return_response
return response_data
hass.services.async_register(
"test_domain",
"test_service",
service_handler,
supports_response=SupportsResponse.ONLY,
)
with pytest.raises(HomeAssistantError, match=expected_error):
await hass.services.async_call(
@ -1177,32 +1171,78 @@ async def test_serviceregistry_return_values_invalid(
"test_service",
service_data={},
blocking=True,
return_values=True,
return_response=True,
)
await hass.async_block_till_done()
async def test_serviceregistry_no_return_values(hass: HomeAssistant) -> None:
"""Test service call data when not asked for return values."""
@pytest.mark.parametrize(
("supports_response", "return_response", "expected_error"),
[
(SupportsResponse.NONE, True, "not support responses"),
(SupportsResponse.ONLY, False, "caller did not ask for responses"),
],
)
async def test_serviceregistry_return_response_arguments(
hass: HomeAssistant,
supports_response: SupportsResponse,
return_response: bool,
expected_error: str,
) -> None:
"""Test service call response data invalid arguments."""
def service_handler(call: ServiceCall) -> None:
hass.services.async_register(
"test_domain",
"test_service",
"service_handler",
supports_response=supports_response,
)
with pytest.raises(ValueError, match=expected_error):
await hass.services.async_call(
"test_domain",
"test_service",
service_data={},
blocking=True,
return_response=return_response,
)
@pytest.mark.parametrize(
("return_response", "expected_response_data"),
[
(True, {"key": "value"}),
(False, None),
],
)
async def test_serviceregistry_return_response_optional(
hass: HomeAssistant,
return_response: bool,
expected_response_data: Any,
) -> None:
"""Test optional service call response data."""
def service_handler(call: ServiceCall) -> ServiceResponse:
"""Service handler coroutine."""
assert not call.return_values
return
if call.return_response:
return {"key": "value"}
return None
hass.services.async_register(
"test_domain",
"test_service",
service_handler,
supports_response=SupportsResponse.OPTIONAL,
)
result = await hass.services.async_call(
response_data = await hass.services.async_call(
"test_domain",
"test_service",
service_data={},
blocking=True,
return_response=return_response,
)
await hass.async_block_till_done()
assert not result
assert response_data == expected_response_data
async def test_config_defaults() -> None: