Add service response support to admin services (#144837)

This commit is contained in:
Abílio Costa 2025-05-13 21:57:15 +01:00 committed by GitHub
parent de2cbb7f5c
commit 6d809b0b5a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 50 additions and 7 deletions

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Awaitable, Callable, Coroutine, Iterable from collections.abc import Callable, Coroutine, Iterable
import dataclasses import dataclasses
from enum import Enum from enum import Enum
from functools import cache, partial from functools import cache, partial
@ -1094,9 +1094,15 @@ async def _handle_entity_call(
async def _async_admin_handler( async def _async_admin_handler(
hass: HomeAssistant, hass: HomeAssistant,
service_job: HassJob[[ServiceCall], Awaitable[None] | None], service_job: HassJob[
[ServiceCall],
Coroutine[Any, Any, ServiceResponse | EntityServiceResponse]
| ServiceResponse
| EntityServiceResponse
| None,
],
call: ServiceCall, call: ServiceCall,
) -> None: ) -> ServiceResponse | EntityServiceResponse | None:
"""Run an admin service.""" """Run an admin service."""
if call.context.user_id: if call.context.user_id:
user = await hass.auth.async_get_user(call.context.user_id) user = await hass.auth.async_get_user(call.context.user_id)
@ -1105,9 +1111,10 @@ async def _async_admin_handler(
if not user.is_admin: if not user.is_admin:
raise Unauthorized(context=call.context) raise Unauthorized(context=call.context)
result = hass.async_run_hass_job(service_job, call) task = hass.async_run_hass_job(service_job, call)
if result is not None: if task is not None:
await result return await task
return None
@bind_hass @bind_hass
@ -1116,8 +1123,15 @@ def async_register_admin_service(
hass: HomeAssistant, hass: HomeAssistant,
domain: str, domain: str,
service: str, service: str,
service_func: Callable[[ServiceCall], Awaitable[None] | None], service_func: Callable[
[ServiceCall],
Coroutine[Any, Any, ServiceResponse | EntityServiceResponse]
| ServiceResponse
| EntityServiceResponse
| None,
],
schema: VolSchemaType = vol.Schema({}, extra=vol.PREVENT_EXTRA), schema: VolSchemaType = vol.Schema({}, extra=vol.PREVENT_EXTRA),
supports_response: SupportsResponse = SupportsResponse.NONE,
) -> None: ) -> None:
"""Register a service that requires admin access.""" """Register a service that requires admin access."""
hass.services.async_register( hass.services.async_register(
@ -1129,6 +1143,7 @@ def async_register_admin_service(
HassJob(service_func, f"admin service {domain}.{service}"), HassJob(service_func, f"admin service {domain}.{service}"),
), ),
schema, schema,
supports_response,
) )

View File

@ -32,6 +32,7 @@ from homeassistant.core import (
HassJob, HassJob,
HomeAssistant, HomeAssistant,
ServiceCall, ServiceCall,
ServiceResponse,
SupportsResponse, SupportsResponse,
) )
from homeassistant.helpers import ( from homeassistant.helpers import (
@ -1648,6 +1649,33 @@ async def test_register_admin_service(
assert calls[0].context.user_id == hass_admin_user.id assert calls[0].context.user_id == hass_admin_user.id
@pytest.mark.parametrize(
"supports_response",
[SupportsResponse.ONLY, SupportsResponse.OPTIONAL],
)
async def test_register_admin_service_return_response(
hass: HomeAssistant, supports_response: SupportsResponse
) -> None:
"""Test the register admin service for a service that returns response data."""
async def mock_service(call: ServiceCall) -> ServiceResponse:
"""Service handler coroutine."""
assert call.return_response
return {"test-reply": "test-value1"}
service.async_register_admin_service(
hass, "test", "test", mock_service, supports_response=supports_response
)
result = await hass.services.async_call(
"test",
"test",
service_data={},
blocking=True,
return_response=True,
)
assert result == {"test-reply": "test-value1"}
async def test_domain_control_not_async(hass: HomeAssistant, mock_entities) -> None: async def test_domain_control_not_async(hass: HomeAssistant, mock_entities) -> None:
"""Test domain verification in a service call with an unknown user.""" """Test domain verification in a service call with an unknown user."""
calls = [] calls = []