Improve service response data APIs (#94819)

* Improve service response data APIs

Make the API naming more consistent, and require registration that a
service supports response data so that we can better integrate with
the UI and avoid user confusion with better error messages.

* Improve test coverage

* Add an enum for registering response values

* Assign enum values

* Convert SupportsResponse to StrEnum

* Update service call test docstrings

* Add tiny missing full stop in comment

---------

Co-authored-by: Franck Nijhof <frenck@frenck.nl>
This commit is contained in:
Allen Porter 2023-06-20 06:24:31 -07:00 committed by GitHub
parent 4a8adae146
commit 30e8f806c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 147 additions and 78 deletions

View File

@ -131,7 +131,7 @@ DOMAIN = "homeassistant"
# How long to wait to log tasks that are blocking # How long to wait to log tasks that are blocking
BLOCK_LOG_TIMEOUT = 60 BLOCK_LOG_TIMEOUT = 60
ServiceResult = JsonObjectType | None ServiceResponse = JsonObjectType | None
class ConfigSource(StrEnum): class ConfigSource(StrEnum):
@ -1655,28 +1655,43 @@ class StateMachine:
) )
class SupportsResponse(StrEnum):
"""Service call response configuration."""
NONE = "none"
"""The service does not support responses (the default)."""
OPTIONAL = "optional"
"""The service optionally returns response data when asked by the caller."""
ONLY = "only"
"""The service is read-only and the caller must always ask for response data."""
class Service: class Service:
"""Representation of a callable service.""" """Representation of a callable service."""
__slots__ = ["job", "schema", "domain", "service"] __slots__ = ["job", "schema", "domain", "service", "supports_response"]
def __init__( def __init__(
self, self,
func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResult] | None], func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResponse] | None],
schema: vol.Schema | None, schema: vol.Schema | None,
domain: str, domain: str,
service: str, service: str,
context: Context | None = None, context: Context | None = None,
supports_response: SupportsResponse = SupportsResponse.NONE,
) -> None: ) -> None:
"""Initialize a service.""" """Initialize a service."""
self.job = HassJob(func, f"service {domain}.{service}") self.job = HassJob(func, f"service {domain}.{service}")
self.schema = schema self.schema = schema
self.supports_response = supports_response
class ServiceCall: class ServiceCall:
"""Representation of a call to a service.""" """Representation of a call to a service."""
__slots__ = ["domain", "service", "data", "context", "return_values"] __slots__ = ["domain", "service", "data", "context", "return_response"]
def __init__( def __init__(
self, self,
@ -1684,14 +1699,14 @@ class ServiceCall:
service: str, service: str,
data: dict[str, Any] | None = None, data: dict[str, Any] | None = None,
context: Context | None = None, context: Context | None = None,
return_values: bool = False, return_response: bool = False,
) -> None: ) -> None:
"""Initialize a service call.""" """Initialize a service call."""
self.domain = domain.lower() self.domain = domain.lower()
self.service = service.lower() self.service = service.lower()
self.data = ReadOnlyDict(data or {}) self.data = ReadOnlyDict(data or {})
self.context = context or Context() self.context = context or Context()
self.return_values = return_values self.return_response = return_response
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return the representation of the service.""" """Return the representation of the service."""
@ -1738,7 +1753,7 @@ class ServiceRegistry:
service: str, service: str,
service_func: Callable[ service_func: Callable[
[ServiceCall], [ServiceCall],
Coroutine[Any, Any, ServiceResult] | None, Coroutine[Any, Any, ServiceResponse] | None,
], ],
schema: vol.Schema | None = None, schema: vol.Schema | None = None,
) -> None: ) -> None:
@ -1756,9 +1771,10 @@ class ServiceRegistry:
domain: str, domain: str,
service: str, service: str,
service_func: Callable[ service_func: Callable[
[ServiceCall], Coroutine[Any, Any, ServiceResult] | None [ServiceCall], Coroutine[Any, Any, ServiceResponse] | None
], ],
schema: vol.Schema | None = None, schema: vol.Schema | None = None,
supports_response: SupportsResponse = SupportsResponse.NONE,
) -> None: ) -> None:
"""Register a service. """Register a service.
@ -1768,7 +1784,9 @@ class ServiceRegistry:
""" """
domain = domain.lower() domain = domain.lower()
service = service.lower() service = service.lower()
service_obj = Service(service_func, schema, domain, service) service_obj = Service(
service_func, schema, domain, service, supports_response=supports_response
)
if domain in self._services: if domain in self._services:
self._services[domain][service] = service_obj self._services[domain][service] = service_obj
@ -1815,8 +1833,8 @@ class ServiceRegistry:
blocking: bool = False, blocking: bool = False,
context: Context | None = None, context: Context | None = None,
target: dict[str, Any] | None = None, target: dict[str, Any] | None = None,
return_values: bool = False, return_response: bool = False,
) -> ServiceResult: ) -> ServiceResponse:
"""Call a service. """Call a service.
See description of async_call for details. See description of async_call for details.
@ -1829,7 +1847,7 @@ class ServiceRegistry:
blocking, blocking,
context, context,
target, target,
return_values, return_response,
), ),
self._hass.loop, self._hass.loop,
).result() ).result()
@ -1842,13 +1860,13 @@ class ServiceRegistry:
blocking: bool = False, blocking: bool = False,
context: Context | None = None, context: Context | None = None,
target: dict[str, Any] | None = None, target: dict[str, Any] | None = None,
return_values: bool = False, return_response: bool = False,
) -> ServiceResult: ) -> ServiceResponse:
"""Call a service. """Call a service.
Specify blocking=True to wait until service is executed. Specify blocking=True to wait until service is executed.
If return_values=True, indicates that the caller can consume return values If return_response=True, indicates that the caller can consume return values
from the service, if any. Return values are a dict that can be returned by the from the service, if any. Return values are a dict that can be returned by the
standard JSON serialization process. Return values can only be used with blocking=True. standard JSON serialization process. Return values can only be used with blocking=True.
@ -1864,14 +1882,25 @@ class ServiceRegistry:
context = context or Context() context = context or Context()
service_data = service_data or {} service_data = service_data or {}
if return_values and not blocking:
raise ValueError("Invalid argument return_values=True when blocking=False")
try: try:
handler = self._services[domain][service] handler = self._services[domain][service]
except KeyError: except KeyError:
raise ServiceNotFound(domain, service) from None raise ServiceNotFound(domain, service) from None
if return_response:
if not blocking:
raise ValueError(
"Invalid argument return_response=True when blocking=False"
)
if handler.supports_response == SupportsResponse.NONE:
raise ValueError(
"Invalid argument return_response=True when handler does not support responses"
)
elif handler.supports_response == SupportsResponse.ONLY:
raise ValueError(
"Service call requires responses but caller did not ask for responses"
)
if target: if target:
service_data.update(target) service_data.update(target)
@ -1890,7 +1919,7 @@ class ServiceRegistry:
processed_data = service_data processed_data = service_data
service_call = ServiceCall( service_call = ServiceCall(
domain, service, processed_data, context, return_values domain, service, processed_data, context, return_response
) )
self._hass.bus.async_fire( self._hass.bus.async_fire(
@ -1909,7 +1938,7 @@ class ServiceRegistry:
return None return None
response_data = await coro response_data = await coro
if not return_values: if not return_response:
return None return None
if not isinstance(response_data, dict): if not isinstance(response_data, dict):
raise HomeAssistantError( raise HomeAssistantError(
@ -1945,19 +1974,19 @@ class ServiceRegistry:
async def _execute_service( async def _execute_service(
self, handler: Service, service_call: ServiceCall self, handler: Service, service_call: ServiceCall
) -> ServiceResult: ) -> ServiceResponse:
"""Execute a service.""" """Execute a service."""
if handler.job.job_type == HassJobType.Coroutinefunction: if handler.job.job_type == HassJobType.Coroutinefunction:
return await cast( return await cast(
Callable[[ServiceCall], Awaitable[ServiceResult]], Callable[[ServiceCall], Awaitable[ServiceResponse]],
handler.job.target, handler.job.target,
)(service_call) )(service_call)
if handler.job.job_type == HassJobType.Callback: if handler.job.job_type == HassJobType.Callback:
return cast(Callable[[ServiceCall], ServiceResult], handler.job.target)( return cast(Callable[[ServiceCall], ServiceResponse], handler.job.target)(
service_call service_call
) )
return await self._hass.async_add_executor_job( return await self._hass.async_add_executor_job(
cast(Callable[[ServiceCall], ServiceResult], handler.job.target), cast(Callable[[ServiceCall], ServiceResponse], handler.job.target),
service_call, service_call,
) )

View File

@ -674,7 +674,7 @@ class _ScriptRun:
**params, **params,
blocking=True, blocking=True,
context=self._context, context=self._context,
return_values=(response_variable is not None), return_response=(response_variable is not None),
) )
), ),
) )

View File

@ -27,7 +27,7 @@ from homeassistant.core import (
CoreState, CoreState,
HomeAssistant, HomeAssistant,
ServiceCall, ServiceCall,
ServiceResult, ServiceResponse,
callback, callback,
) )
from homeassistant.exceptions import ConditionError, HomeAssistantError, ServiceNotFound from homeassistant.exceptions import ConditionError, HomeAssistantError, ServiceNotFound
@ -330,19 +330,19 @@ async def test_calling_service_template(hass: HomeAssistant) -> None:
) )
async def test_calling_service_return_values( async def test_calling_service_response_data(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Test the calling of a service with return values.""" """Test the calling of a service with return values."""
context = Context() context = Context()
def mock_service(call: ServiceCall) -> ServiceResult: def mock_service(call: ServiceCall) -> ServiceResponse:
"""Mock service call.""" """Mock service call."""
if call.return_values: if call.return_response:
return {"data": "value-12345"} return {"data": "value-12345"}
return None return None
hass.services.async_register("test", "script", mock_service) hass.services.async_register("test", "script", mock_service, supports_response=True)
sequence = cv.SCRIPT_SCHEMA( sequence = cv.SCRIPT_SCHEMA(
[ [
{ {

View File

@ -33,7 +33,14 @@ from homeassistant.const import (
__version__, __version__,
) )
import homeassistant.core as ha import homeassistant.core as ha
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, ServiceResult, State from homeassistant.core import (
HassJob,
HomeAssistant,
ServiceCall,
ServiceResponse,
State,
SupportsResponse,
)
from homeassistant.exceptions import ( from homeassistant.exceptions import (
HomeAssistantError, HomeAssistantError,
InvalidEntityFormatError, InvalidEntityFormatError,
@ -1083,58 +1090,44 @@ async def test_serviceregistry_callback_service_raise_exception(
await hass.async_block_till_done() await hass.async_block_till_done()
async def test_serviceregistry_return_values(hass: HomeAssistant) -> None: @pytest.mark.parametrize(
"""Test service call for a service that has return values.""" "supports_response",
[
SupportsResponse.ONLY,
SupportsResponse.OPTIONAL,
],
)
async def test_serviceregistry_async_return_response(
hass: HomeAssistant, supports_response: SupportsResponse
) -> None:
"""Test service call for a service that returns response data."""
def service_handler(call: ServiceCall) -> ServiceResult: async def service_handler(call: ServiceCall) -> ServiceResponse:
"""Service handler coroutine.""" """Service handler coroutine."""
assert call.return_values assert call.return_response
return {"test-reply": "test-value1"} return {"test-reply": "test-value1"}
hass.services.async_register( hass.services.async_register(
"test_domain", "test_domain",
"test_service", "test_service",
service_handler, service_handler,
supports_response=supports_response,
) )
result = await hass.services.async_call( result = await hass.services.async_call(
"test_domain", "test_domain",
"test_service", "test_service",
service_data={}, service_data={},
blocking=True, blocking=True,
return_values=True, return_response=True,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert result == {"test-reply": "test-value1"} assert result == {"test-reply": "test-value1"}
async def test_serviceregistry_async_return_values(hass: HomeAssistant) -> None: async def test_services_call_return_response_requires_blocking(
"""Test service call for an async service that has return values."""
async def service_handler(call: ServiceCall) -> ServiceResult:
"""Service handler coroutine."""
assert call.return_values
return {"test-reply": "test-value1"}
hass.services.async_register(
"test_domain",
"test_service",
service_handler,
)
result = await hass.services.async_call(
"test_domain",
"test_service",
service_data={},
blocking=True,
return_values=True,
)
await hass.async_block_till_done()
assert result == {"test-reply": "test-value1"}
async def test_services_call_return_values_requires_blocking(
hass: HomeAssistant, hass: HomeAssistant,
) -> None: ) -> None:
"""Test that non-blocking service calls cannot return values.""" """Test that non-blocking service calls cannot ask for response data."""
async_mock_service(hass, "test_domain", "test_service") async_mock_service(hass, "test_domain", "test_service")
with pytest.raises(ValueError, match="when blocking=False"): with pytest.raises(ValueError, match="when blocking=False"):
await hass.services.async_call( await hass.services.async_call(
@ -1142,12 +1135,12 @@ async def test_services_call_return_values_requires_blocking(
"test_service", "test_service",
service_data={}, service_data={},
blocking=False, blocking=False,
return_values=True, return_response=True,
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
("return_value", "expected_error"), ("response_data", "expected_error"),
[ [
(True, "expected a dictionary"), (True, "expected a dictionary"),
(False, "expected a dictionary"), (False, "expected a dictionary"),
@ -1156,20 +1149,21 @@ async def test_services_call_return_values_requires_blocking(
(["some-list"], "expected a dictionary"), (["some-list"], "expected a dictionary"),
], ],
) )
async def test_serviceregistry_return_values_invalid( async def test_serviceregistry_return_response_invalid(
hass: HomeAssistant, return_value: Any, expected_error: str hass: HomeAssistant, response_data: Any, expected_error: str
) -> None: ) -> None:
"""Test service call return values are not returned when there is no result schema.""" """Test service call response data must be json serializable objects."""
def service_handler(call: ServiceCall) -> ServiceResult: def service_handler(call: ServiceCall) -> ServiceResponse:
"""Service handler coroutine.""" """Service handler coroutine."""
assert call.return_values assert call.return_response
return return_value return response_data
hass.services.async_register( hass.services.async_register(
"test_domain", "test_domain",
"test_service", "test_service",
service_handler, service_handler,
supports_response=SupportsResponse.ONLY,
) )
with pytest.raises(HomeAssistantError, match=expected_error): with pytest.raises(HomeAssistantError, match=expected_error):
await hass.services.async_call( await hass.services.async_call(
@ -1177,32 +1171,78 @@ async def test_serviceregistry_return_values_invalid(
"test_service", "test_service",
service_data={}, service_data={},
blocking=True, blocking=True,
return_values=True, return_response=True,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
async def test_serviceregistry_no_return_values(hass: HomeAssistant) -> None: @pytest.mark.parametrize(
"""Test service call data when not asked for return values.""" ("supports_response", "return_response", "expected_error"),
[
(SupportsResponse.NONE, True, "not support responses"),
(SupportsResponse.ONLY, False, "caller did not ask for responses"),
],
)
async def test_serviceregistry_return_response_arguments(
hass: HomeAssistant,
supports_response: SupportsResponse,
return_response: bool,
expected_error: str,
) -> None:
"""Test service call response data invalid arguments."""
def service_handler(call: ServiceCall) -> None: hass.services.async_register(
"test_domain",
"test_service",
"service_handler",
supports_response=supports_response,
)
with pytest.raises(ValueError, match=expected_error):
await hass.services.async_call(
"test_domain",
"test_service",
service_data={},
blocking=True,
return_response=return_response,
)
@pytest.mark.parametrize(
("return_response", "expected_response_data"),
[
(True, {"key": "value"}),
(False, None),
],
)
async def test_serviceregistry_return_response_optional(
hass: HomeAssistant,
return_response: bool,
expected_response_data: Any,
) -> None:
"""Test optional service call response data."""
def service_handler(call: ServiceCall) -> ServiceResponse:
"""Service handler coroutine.""" """Service handler coroutine."""
assert not call.return_values if call.return_response:
return return {"key": "value"}
return None
hass.services.async_register( hass.services.async_register(
"test_domain", "test_domain",
"test_service", "test_service",
service_handler, service_handler,
supports_response=SupportsResponse.OPTIONAL,
) )
result = await hass.services.async_call( response_data = await hass.services.async_call(
"test_domain", "test_domain",
"test_service", "test_service",
service_data={}, service_data={},
blocking=True, blocking=True,
return_response=return_response,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert not result assert response_data == expected_response_data
async def test_config_defaults() -> None: async def test_config_defaults() -> None: