mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 08:47:57 +00:00
Improve service response data APIs (#94819)
* Improve service response data APIs Make the API naming more consistent, and require registration that a service supports response data so that we can better integrate with the UI and avoid user confusion with better error messages. * Improve test coverage * Add an enum for registering response values * Assign enum values * Convert SupportsResponse to StrEnum * Update service call test docstrings * Add tiny missing full stop in comment --------- Co-authored-by: Franck Nijhof <frenck@frenck.nl>
This commit is contained in:
parent
4a8adae146
commit
30e8f806c1
@ -131,7 +131,7 @@ DOMAIN = "homeassistant"
|
||||
# How long to wait to log tasks that are blocking
|
||||
BLOCK_LOG_TIMEOUT = 60
|
||||
|
||||
ServiceResult = JsonObjectType | None
|
||||
ServiceResponse = JsonObjectType | None
|
||||
|
||||
|
||||
class ConfigSource(StrEnum):
|
||||
@ -1655,28 +1655,43 @@ class StateMachine:
|
||||
)
|
||||
|
||||
|
||||
class SupportsResponse(StrEnum):
|
||||
"""Service call response configuration."""
|
||||
|
||||
NONE = "none"
|
||||
"""The service does not support responses (the default)."""
|
||||
|
||||
OPTIONAL = "optional"
|
||||
"""The service optionally returns response data when asked by the caller."""
|
||||
|
||||
ONLY = "only"
|
||||
"""The service is read-only and the caller must always ask for response data."""
|
||||
|
||||
|
||||
class Service:
|
||||
"""Representation of a callable service."""
|
||||
|
||||
__slots__ = ["job", "schema", "domain", "service"]
|
||||
__slots__ = ["job", "schema", "domain", "service", "supports_response"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResult] | None],
|
||||
func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResponse] | None],
|
||||
schema: vol.Schema | None,
|
||||
domain: str,
|
||||
service: str,
|
||||
context: Context | None = None,
|
||||
supports_response: SupportsResponse = SupportsResponse.NONE,
|
||||
) -> None:
|
||||
"""Initialize a service."""
|
||||
self.job = HassJob(func, f"service {domain}.{service}")
|
||||
self.schema = schema
|
||||
self.supports_response = supports_response
|
||||
|
||||
|
||||
class ServiceCall:
|
||||
"""Representation of a call to a service."""
|
||||
|
||||
__slots__ = ["domain", "service", "data", "context", "return_values"]
|
||||
__slots__ = ["domain", "service", "data", "context", "return_response"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -1684,14 +1699,14 @@ class ServiceCall:
|
||||
service: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
context: Context | None = None,
|
||||
return_values: bool = False,
|
||||
return_response: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a service call."""
|
||||
self.domain = domain.lower()
|
||||
self.service = service.lower()
|
||||
self.data = ReadOnlyDict(data or {})
|
||||
self.context = context or Context()
|
||||
self.return_values = return_values
|
||||
self.return_response = return_response
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return the representation of the service."""
|
||||
@ -1738,7 +1753,7 @@ class ServiceRegistry:
|
||||
service: str,
|
||||
service_func: Callable[
|
||||
[ServiceCall],
|
||||
Coroutine[Any, Any, ServiceResult] | None,
|
||||
Coroutine[Any, Any, ServiceResponse] | None,
|
||||
],
|
||||
schema: vol.Schema | None = None,
|
||||
) -> None:
|
||||
@ -1756,9 +1771,10 @@ class ServiceRegistry:
|
||||
domain: str,
|
||||
service: str,
|
||||
service_func: Callable[
|
||||
[ServiceCall], Coroutine[Any, Any, ServiceResult] | None
|
||||
[ServiceCall], Coroutine[Any, Any, ServiceResponse] | None
|
||||
],
|
||||
schema: vol.Schema | None = None,
|
||||
supports_response: SupportsResponse = SupportsResponse.NONE,
|
||||
) -> None:
|
||||
"""Register a service.
|
||||
|
||||
@ -1768,7 +1784,9 @@ class ServiceRegistry:
|
||||
"""
|
||||
domain = domain.lower()
|
||||
service = service.lower()
|
||||
service_obj = Service(service_func, schema, domain, service)
|
||||
service_obj = Service(
|
||||
service_func, schema, domain, service, supports_response=supports_response
|
||||
)
|
||||
|
||||
if domain in self._services:
|
||||
self._services[domain][service] = service_obj
|
||||
@ -1815,8 +1833,8 @@ class ServiceRegistry:
|
||||
blocking: bool = False,
|
||||
context: Context | None = None,
|
||||
target: dict[str, Any] | None = None,
|
||||
return_values: bool = False,
|
||||
) -> ServiceResult:
|
||||
return_response: bool = False,
|
||||
) -> ServiceResponse:
|
||||
"""Call a service.
|
||||
|
||||
See description of async_call for details.
|
||||
@ -1829,7 +1847,7 @@ class ServiceRegistry:
|
||||
blocking,
|
||||
context,
|
||||
target,
|
||||
return_values,
|
||||
return_response,
|
||||
),
|
||||
self._hass.loop,
|
||||
).result()
|
||||
@ -1842,13 +1860,13 @@ class ServiceRegistry:
|
||||
blocking: bool = False,
|
||||
context: Context | None = None,
|
||||
target: dict[str, Any] | None = None,
|
||||
return_values: bool = False,
|
||||
) -> ServiceResult:
|
||||
return_response: bool = False,
|
||||
) -> ServiceResponse:
|
||||
"""Call a service.
|
||||
|
||||
Specify blocking=True to wait until service is executed.
|
||||
|
||||
If return_values=True, indicates that the caller can consume return values
|
||||
If return_response=True, indicates that the caller can consume return values
|
||||
from the service, if any. Return values are a dict that can be returned by the
|
||||
standard JSON serialization process. Return values can only be used with blocking=True.
|
||||
|
||||
@ -1864,14 +1882,25 @@ class ServiceRegistry:
|
||||
context = context or Context()
|
||||
service_data = service_data or {}
|
||||
|
||||
if return_values and not blocking:
|
||||
raise ValueError("Invalid argument return_values=True when blocking=False")
|
||||
|
||||
try:
|
||||
handler = self._services[domain][service]
|
||||
except KeyError:
|
||||
raise ServiceNotFound(domain, service) from None
|
||||
|
||||
if return_response:
|
||||
if not blocking:
|
||||
raise ValueError(
|
||||
"Invalid argument return_response=True when blocking=False"
|
||||
)
|
||||
if handler.supports_response == SupportsResponse.NONE:
|
||||
raise ValueError(
|
||||
"Invalid argument return_response=True when handler does not support responses"
|
||||
)
|
||||
elif handler.supports_response == SupportsResponse.ONLY:
|
||||
raise ValueError(
|
||||
"Service call requires responses but caller did not ask for responses"
|
||||
)
|
||||
|
||||
if target:
|
||||
service_data.update(target)
|
||||
|
||||
@ -1890,7 +1919,7 @@ class ServiceRegistry:
|
||||
processed_data = service_data
|
||||
|
||||
service_call = ServiceCall(
|
||||
domain, service, processed_data, context, return_values
|
||||
domain, service, processed_data, context, return_response
|
||||
)
|
||||
|
||||
self._hass.bus.async_fire(
|
||||
@ -1909,7 +1938,7 @@ class ServiceRegistry:
|
||||
return None
|
||||
|
||||
response_data = await coro
|
||||
if not return_values:
|
||||
if not return_response:
|
||||
return None
|
||||
if not isinstance(response_data, dict):
|
||||
raise HomeAssistantError(
|
||||
@ -1945,19 +1974,19 @@ class ServiceRegistry:
|
||||
|
||||
async def _execute_service(
|
||||
self, handler: Service, service_call: ServiceCall
|
||||
) -> ServiceResult:
|
||||
) -> ServiceResponse:
|
||||
"""Execute a service."""
|
||||
if handler.job.job_type == HassJobType.Coroutinefunction:
|
||||
return await cast(
|
||||
Callable[[ServiceCall], Awaitable[ServiceResult]],
|
||||
Callable[[ServiceCall], Awaitable[ServiceResponse]],
|
||||
handler.job.target,
|
||||
)(service_call)
|
||||
if handler.job.job_type == HassJobType.Callback:
|
||||
return cast(Callable[[ServiceCall], ServiceResult], handler.job.target)(
|
||||
return cast(Callable[[ServiceCall], ServiceResponse], handler.job.target)(
|
||||
service_call
|
||||
)
|
||||
return await self._hass.async_add_executor_job(
|
||||
cast(Callable[[ServiceCall], ServiceResult], handler.job.target),
|
||||
cast(Callable[[ServiceCall], ServiceResponse], handler.job.target),
|
||||
service_call,
|
||||
)
|
||||
|
||||
|
@ -674,7 +674,7 @@ class _ScriptRun:
|
||||
**params,
|
||||
blocking=True,
|
||||
context=self._context,
|
||||
return_values=(response_variable is not None),
|
||||
return_response=(response_variable is not None),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
@ -27,7 +27,7 @@ from homeassistant.core import (
|
||||
CoreState,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
ServiceResult,
|
||||
ServiceResponse,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.exceptions import ConditionError, HomeAssistantError, ServiceNotFound
|
||||
@ -330,19 +330,19 @@ async def test_calling_service_template(hass: HomeAssistant) -> None:
|
||||
)
|
||||
|
||||
|
||||
async def test_calling_service_return_values(
|
||||
async def test_calling_service_response_data(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test the calling of a service with return values."""
|
||||
context = Context()
|
||||
|
||||
def mock_service(call: ServiceCall) -> ServiceResult:
|
||||
def mock_service(call: ServiceCall) -> ServiceResponse:
|
||||
"""Mock service call."""
|
||||
if call.return_values:
|
||||
if call.return_response:
|
||||
return {"data": "value-12345"}
|
||||
return None
|
||||
|
||||
hass.services.async_register("test", "script", mock_service)
|
||||
hass.services.async_register("test", "script", mock_service, supports_response=True)
|
||||
sequence = cv.SCRIPT_SCHEMA(
|
||||
[
|
||||
{
|
||||
|
@ -33,7 +33,14 @@ from homeassistant.const import (
|
||||
__version__,
|
||||
)
|
||||
import homeassistant.core as ha
|
||||
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, ServiceResult, State
|
||||
from homeassistant.core import (
|
||||
HassJob,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
ServiceResponse,
|
||||
State,
|
||||
SupportsResponse,
|
||||
)
|
||||
from homeassistant.exceptions import (
|
||||
HomeAssistantError,
|
||||
InvalidEntityFormatError,
|
||||
@ -1083,58 +1090,44 @@ async def test_serviceregistry_callback_service_raise_exception(
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_serviceregistry_return_values(hass: HomeAssistant) -> None:
|
||||
"""Test service call for a service that has return values."""
|
||||
@pytest.mark.parametrize(
|
||||
"supports_response",
|
||||
[
|
||||
SupportsResponse.ONLY,
|
||||
SupportsResponse.OPTIONAL,
|
||||
],
|
||||
)
|
||||
async def test_serviceregistry_async_return_response(
|
||||
hass: HomeAssistant, supports_response: SupportsResponse
|
||||
) -> None:
|
||||
"""Test service call for a service that returns response data."""
|
||||
|
||||
def service_handler(call: ServiceCall) -> ServiceResult:
|
||||
async def service_handler(call: ServiceCall) -> ServiceResponse:
|
||||
"""Service handler coroutine."""
|
||||
assert call.return_values
|
||||
assert call.return_response
|
||||
return {"test-reply": "test-value1"}
|
||||
|
||||
hass.services.async_register(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
service_handler,
|
||||
supports_response=supports_response,
|
||||
)
|
||||
result = await hass.services.async_call(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
service_data={},
|
||||
blocking=True,
|
||||
return_values=True,
|
||||
return_response=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert result == {"test-reply": "test-value1"}
|
||||
|
||||
|
||||
async def test_serviceregistry_async_return_values(hass: HomeAssistant) -> None:
|
||||
"""Test service call for an async service that has return values."""
|
||||
|
||||
async def service_handler(call: ServiceCall) -> ServiceResult:
|
||||
"""Service handler coroutine."""
|
||||
assert call.return_values
|
||||
return {"test-reply": "test-value1"}
|
||||
|
||||
hass.services.async_register(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
service_handler,
|
||||
)
|
||||
result = await hass.services.async_call(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
service_data={},
|
||||
blocking=True,
|
||||
return_values=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert result == {"test-reply": "test-value1"}
|
||||
|
||||
|
||||
async def test_services_call_return_values_requires_blocking(
|
||||
async def test_services_call_return_response_requires_blocking(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test that non-blocking service calls cannot return values."""
|
||||
"""Test that non-blocking service calls cannot ask for response data."""
|
||||
async_mock_service(hass, "test_domain", "test_service")
|
||||
with pytest.raises(ValueError, match="when blocking=False"):
|
||||
await hass.services.async_call(
|
||||
@ -1142,12 +1135,12 @@ async def test_services_call_return_values_requires_blocking(
|
||||
"test_service",
|
||||
service_data={},
|
||||
blocking=False,
|
||||
return_values=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("return_value", "expected_error"),
|
||||
("response_data", "expected_error"),
|
||||
[
|
||||
(True, "expected a dictionary"),
|
||||
(False, "expected a dictionary"),
|
||||
@ -1156,20 +1149,21 @@ async def test_services_call_return_values_requires_blocking(
|
||||
(["some-list"], "expected a dictionary"),
|
||||
],
|
||||
)
|
||||
async def test_serviceregistry_return_values_invalid(
|
||||
hass: HomeAssistant, return_value: Any, expected_error: str
|
||||
async def test_serviceregistry_return_response_invalid(
|
||||
hass: HomeAssistant, response_data: Any, expected_error: str
|
||||
) -> None:
|
||||
"""Test service call return values are not returned when there is no result schema."""
|
||||
"""Test service call response data must be json serializable objects."""
|
||||
|
||||
def service_handler(call: ServiceCall) -> ServiceResult:
|
||||
def service_handler(call: ServiceCall) -> ServiceResponse:
|
||||
"""Service handler coroutine."""
|
||||
assert call.return_values
|
||||
return return_value
|
||||
assert call.return_response
|
||||
return response_data
|
||||
|
||||
hass.services.async_register(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
service_handler,
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
)
|
||||
with pytest.raises(HomeAssistantError, match=expected_error):
|
||||
await hass.services.async_call(
|
||||
@ -1177,32 +1171,78 @@ async def test_serviceregistry_return_values_invalid(
|
||||
"test_service",
|
||||
service_data={},
|
||||
blocking=True,
|
||||
return_values=True,
|
||||
return_response=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_serviceregistry_no_return_values(hass: HomeAssistant) -> None:
|
||||
"""Test service call data when not asked for return values."""
|
||||
@pytest.mark.parametrize(
|
||||
("supports_response", "return_response", "expected_error"),
|
||||
[
|
||||
(SupportsResponse.NONE, True, "not support responses"),
|
||||
(SupportsResponse.ONLY, False, "caller did not ask for responses"),
|
||||
],
|
||||
)
|
||||
async def test_serviceregistry_return_response_arguments(
|
||||
hass: HomeAssistant,
|
||||
supports_response: SupportsResponse,
|
||||
return_response: bool,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test service call response data invalid arguments."""
|
||||
|
||||
def service_handler(call: ServiceCall) -> None:
|
||||
hass.services.async_register(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
"service_handler",
|
||||
supports_response=supports_response,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=expected_error):
|
||||
await hass.services.async_call(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
service_data={},
|
||||
blocking=True,
|
||||
return_response=return_response,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("return_response", "expected_response_data"),
|
||||
[
|
||||
(True, {"key": "value"}),
|
||||
(False, None),
|
||||
],
|
||||
)
|
||||
async def test_serviceregistry_return_response_optional(
|
||||
hass: HomeAssistant,
|
||||
return_response: bool,
|
||||
expected_response_data: Any,
|
||||
) -> None:
|
||||
"""Test optional service call response data."""
|
||||
|
||||
def service_handler(call: ServiceCall) -> ServiceResponse:
|
||||
"""Service handler coroutine."""
|
||||
assert not call.return_values
|
||||
return
|
||||
if call.return_response:
|
||||
return {"key": "value"}
|
||||
return None
|
||||
|
||||
hass.services.async_register(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
service_handler,
|
||||
supports_response=SupportsResponse.OPTIONAL,
|
||||
)
|
||||
result = await hass.services.async_call(
|
||||
response_data = await hass.services.async_call(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
service_data={},
|
||||
blocking=True,
|
||||
return_response=return_response,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert not result
|
||||
assert response_data == expected_response_data
|
||||
|
||||
|
||||
async def test_config_defaults() -> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user