Optionally return response data when calling services through the API (#115046)

Co-authored-by: J. Nick Koston <nick@koston.org>
Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
Jack Gaino 2024-07-31 15:00:04 -04:00 committed by GitHub
parent 17f34b452e
commit 2910369647
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 90 additions and 1 deletions

View File

@ -390,6 +390,27 @@ class APIDomainServicesView(HomeAssistantView):
) )
context = self.context(request) context = self.context(request)
if not hass.services.has_service(domain, service):
raise HTTPBadRequest from ServiceNotFound(domain, service)
if response_requested := "return_response" in request.query:
if (
hass.services.supports_response(domain, service)
is ha.SupportsResponse.NONE
):
return self.json_message(
"Service does not support responses. Remove return_response from request.",
HTTPStatus.BAD_REQUEST,
)
elif (
hass.services.supports_response(domain, service) is ha.SupportsResponse.ONLY
):
return self.json_message(
"Service call requires responses but caller did not ask for responses. "
"Add ?return_response to query parameters.",
HTTPStatus.BAD_REQUEST,
)
changed_states: list[json_fragment] = [] changed_states: list[json_fragment] = []
@ha.callback @ha.callback
@ -406,13 +427,14 @@ class APIDomainServicesView(HomeAssistantView):
try: try:
# shield the service call from cancellation on connection drop # shield the service call from cancellation on connection drop
await shield( response = await shield(
hass.services.async_call( hass.services.async_call(
domain, domain,
service, service,
data, # type: ignore[arg-type] data, # type: ignore[arg-type]
blocking=True, blocking=True,
context=context, context=context,
return_response=response_requested,
) )
) )
except (vol.Invalid, ServiceNotFound) as ex: except (vol.Invalid, ServiceNotFound) as ex:
@ -420,6 +442,11 @@ class APIDomainServicesView(HomeAssistantView):
finally: finally:
cancel_listen() cancel_listen()
if response_requested:
return self.json(
{"changed_states": changed_states, "service_response": response}
)
return self.json(changed_states) return self.json(changed_states)

View File

@ -3,6 +3,7 @@
import asyncio import asyncio
from http import HTTPStatus from http import HTTPStatus
import json import json
from typing import Any
from unittest.mock import patch from unittest.mock import patch
from aiohttp import ServerDisconnectedError, web from aiohttp import ServerDisconnectedError, web
@ -355,6 +356,67 @@ async def test_api_call_service_with_data(
assert state["attributes"] == {"data": 1} assert state["attributes"] == {"data": 1}
SERVICE_DICT = {"changed_states": [], "service_response": {"foo": "bar"}}
RESP_REQUIRED = {
"message": (
"Service call requires responses but caller did not ask for "
"responses. Add ?return_response to query parameters."
)
}
RESP_UNSUPPORTED = {
"message": "Service does not support responses. Remove return_response from request."
}
@pytest.mark.parametrize(
(
"supports_response",
"requested_response",
"expected_number_of_service_calls",
"expected_status",
"expected_response",
),
[
(ha.SupportsResponse.ONLY, True, 1, HTTPStatus.OK, SERVICE_DICT),
(ha.SupportsResponse.ONLY, False, 0, HTTPStatus.BAD_REQUEST, RESP_REQUIRED),
(ha.SupportsResponse.OPTIONAL, True, 1, HTTPStatus.OK, SERVICE_DICT),
(ha.SupportsResponse.OPTIONAL, False, 1, HTTPStatus.OK, []),
(ha.SupportsResponse.NONE, True, 0, HTTPStatus.BAD_REQUEST, RESP_UNSUPPORTED),
(ha.SupportsResponse.NONE, False, 1, HTTPStatus.OK, []),
],
)
async def test_api_call_service_returns_response_requested_response(
hass: HomeAssistant,
mock_api_client: TestClient,
supports_response: ha.SupportsResponse,
requested_response: bool,
expected_number_of_service_calls: int,
expected_status: int,
expected_response: Any,
) -> None:
"""Test if the API allows us to call a service."""
test_value = []
@ha.callback
def listener(service_call):
"""Record that our service got called."""
test_value.append(1)
return {"foo": "bar"}
hass.services.async_register(
"test_domain", "test_service", listener, supports_response=supports_response
)
resp = await mock_api_client.post(
"/api/services/test_domain/test_service"
+ ("?return_response" if requested_response else "")
)
assert resp.status == expected_status
await hass.async_block_till_done()
assert len(test_value) == expected_number_of_service_calls
assert await resp.json() == expected_response
async def test_api_call_service_client_closed( async def test_api_call_service_client_closed(
hass: HomeAssistant, mock_api_client: TestClient hass: HomeAssistant, mock_api_client: TestClient
) -> None: ) -> None: