diff --git a/homeassistant/components/group/notify.py b/homeassistant/components/group/notify.py index 378a7852343..2747ba55ee1 100644 --- a/homeassistant/components/group/notify.py +++ b/homeassistant/components/group/notify.py @@ -66,7 +66,7 @@ class GroupNotifyPlatform(BaseNotificationService): payload: dict[str, Any] = {ATTR_MESSAGE: message} payload.update({key: val for key, val in kwargs.items() if val}) - tasks: list[asyncio.Task[bool | None]] = [] + tasks: list[asyncio.Task[Any]] = [] for entity in self.entities: sending_payload = deepcopy(payload.copy()) if (default_data := entity.get(ATTR_DATA)) is not None: @@ -74,7 +74,7 @@ class GroupNotifyPlatform(BaseNotificationService): tasks.append( asyncio.create_task( self.hass.services.async_call( - DOMAIN, entity[ATTR_SERVICE], sending_payload + DOMAIN, entity[ATTR_SERVICE], sending_payload, blocking=True ) ) ) diff --git a/homeassistant/core.py b/homeassistant/core.py index 333f9b82cd2..6405b0860e1 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -88,6 +88,7 @@ from .util.async_ import ( run_callback_threadsafe, shutdown_run_callback_threadsafe, ) +from .util.json import JsonObjectType from .util.read_only_dict import ReadOnlyDict from .util.timeout import TimeoutManager from .util.unit_system import ( @@ -130,6 +131,8 @@ DOMAIN = "homeassistant" # How long to wait to log tasks that are blocking BLOCK_LOG_TIMEOUT = 60 +ServiceResult = JsonObjectType | None + class ConfigSource(StrEnum): """Source of core configuration.""" @@ -1659,7 +1662,7 @@ class Service: def __init__( self, - func: Callable[[ServiceCall], Coroutine[Any, Any, None] | None], + func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResult] | None], schema: vol.Schema | None, domain: str, service: str, @@ -1673,7 +1676,7 @@ class Service: class ServiceCall: """Representation of a call to a service.""" - __slots__ = ["domain", "service", "data", "context"] + __slots__ = ["domain", "service", "data", "context", "return_values"] def __init__( self, @@ -1681,12 +1684,14 @@ class ServiceCall: service: str, data: dict[str, Any] | None = None, context: Context | None = None, + return_values: bool = False, ) -> None: """Initialize a service call.""" self.domain = domain.lower() self.service = service.lower() self.data = ReadOnlyDict(data or {}) self.context = context or Context() + self.return_values = return_values def __repr__(self) -> str: """Return the representation of the service.""" @@ -1731,7 +1736,10 @@ class ServiceRegistry: self, domain: str, service: str, - service_func: Callable[[ServiceCall], Coroutine[Any, Any, None] | None], + service_func: Callable[ + [ServiceCall], + Coroutine[Any, Any, ServiceResult] | None, + ], schema: vol.Schema | None = None, ) -> None: """Register a service. @@ -1747,7 +1755,9 @@ class ServiceRegistry: self, domain: str, service: str, - service_func: Callable[[ServiceCall], Coroutine[Any, Any, None] | None], + service_func: Callable[ + [ServiceCall], Coroutine[Any, Any, ServiceResult] | None + ], schema: vol.Schema | None = None, ) -> None: """Register a service. @@ -1805,13 +1815,22 @@ class ServiceRegistry: blocking: bool = False, context: Context | None = None, target: dict[str, Any] | None = None, - ) -> bool | None: + return_values: bool = False, + ) -> ServiceResult: """Call a service. See description of async_call for details. """ return asyncio.run_coroutine_threadsafe( - self.async_call(domain, service, service_data, blocking, context, target), + self.async_call( + domain, + service, + service_data, + blocking, + context, + target, + return_values, + ), self._hass.loop, ).result() @@ -1823,11 +1842,16 @@ class ServiceRegistry: blocking: bool = False, context: Context | None = None, target: dict[str, Any] | None = None, - ) -> None: + return_values: bool = False, + ) -> ServiceResult: """Call a service. Specify blocking=True to wait until service is executed. + If return_values=True, indicates that the caller can consume return values + from the service, if any. Return values are a dict that can be returned by the + standard JSON serialization process. Return values can only be used with blocking=True. + This method will fire an event to indicate the service has been called. Because the service is sent as an event you are not allowed to use @@ -1840,6 +1864,9 @@ class ServiceRegistry: context = context or Context() service_data = service_data or {} + if return_values and not blocking: + raise ValueError("Invalid argument return_values=True when blocking=False") + try: handler = self._services[domain][service] except KeyError: @@ -1862,7 +1889,9 @@ class ServiceRegistry: else: processed_data = service_data - service_call = ServiceCall(domain, service, processed_data, context) + service_call = ServiceCall( + domain, service, processed_data, context, return_values + ) self._hass.bus.async_fire( EVENT_CALL_SERVICE, @@ -1877,13 +1906,20 @@ class ServiceRegistry: coro = self._execute_service(handler, service_call) if not blocking: self._run_service_in_background(coro, service_call) - return + return None - await coro + response_data = await coro + if not return_values: + return None + if not isinstance(response_data, dict): + raise HomeAssistantError( + f"Service response data expected a dictionary, was {type(response_data)}" + ) + return response_data def _run_service_in_background( self, - coro_or_task: Coroutine[Any, Any, None] | asyncio.Task[None], + coro_or_task: Coroutine[Any, Any, Any] | asyncio.Task[Any], service_call: ServiceCall, ) -> None: """Run service call in background, catching and logging any exceptions.""" @@ -1909,18 +1945,21 @@ class ServiceRegistry: async def _execute_service( self, handler: Service, service_call: ServiceCall - ) -> None: + ) -> ServiceResult: """Execute a service.""" if handler.job.job_type == HassJobType.Coroutinefunction: - await cast(Callable[[ServiceCall], Awaitable[None]], handler.job.target)( + return await cast( + Callable[[ServiceCall], Awaitable[ServiceResult]], + handler.job.target, + )(service_call) + if handler.job.job_type == HassJobType.Callback: + return cast(Callable[[ServiceCall], ServiceResult], handler.job.target)( service_call ) - elif handler.job.job_type == HassJobType.Callback: - cast(Callable[[ServiceCall], None], handler.job.target)(service_call) - else: - await self._hass.async_add_executor_job( - cast(Callable[[ServiceCall], None], handler.job.target), service_call - ) + return await self._hass.async_add_executor_job( + cast(Callable[[ServiceCall], ServiceResult], handler.job.target), + service_call, + ) class Config: diff --git a/tests/test_core.py b/tests/test_core.py index 2759ca751b5..ebc5718c7cb 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -33,8 +33,9 @@ from homeassistant.const import ( __version__, ) import homeassistant.core as ha -from homeassistant.core import HassJob, HomeAssistant, State +from homeassistant.core import HassJob, HomeAssistant, ServiceCall, ServiceResult, State from homeassistant.exceptions import ( + HomeAssistantError, InvalidEntityFormatError, InvalidStateError, MaxLengthExceeded, @@ -1082,6 +1083,128 @@ async def test_serviceregistry_callback_service_raise_exception( await hass.async_block_till_done() +async def test_serviceregistry_return_values(hass: HomeAssistant) -> None: + """Test service call for a service that has return values.""" + + def service_handler(call: ServiceCall) -> ServiceResult: + """Service handler coroutine.""" + assert call.return_values + return {"test-reply": "test-value1"} + + hass.services.async_register( + "test_domain", + "test_service", + service_handler, + ) + result = await hass.services.async_call( + "test_domain", + "test_service", + service_data={}, + blocking=True, + return_values=True, + ) + await hass.async_block_till_done() + assert result == {"test-reply": "test-value1"} + + +async def test_serviceregistry_async_return_values(hass: HomeAssistant) -> None: + """Test service call for an async service that has return values.""" + + async def service_handler(call: ServiceCall) -> ServiceResult: + """Service handler coroutine.""" + assert call.return_values + return {"test-reply": "test-value1"} + + hass.services.async_register( + "test_domain", + "test_service", + service_handler, + ) + result = await hass.services.async_call( + "test_domain", + "test_service", + service_data={}, + blocking=True, + return_values=True, + ) + await hass.async_block_till_done() + assert result == {"test-reply": "test-value1"} + + +async def test_services_call_return_values_requires_blocking( + hass: HomeAssistant, +) -> None: + """Test that non-blocking service calls cannot return values.""" + async_mock_service(hass, "test_domain", "test_service") + with pytest.raises(ValueError, match="when blocking=False"): + await hass.services.async_call( + "test_domain", + "test_service", + service_data={}, + blocking=False, + return_values=True, + ) + + +@pytest.mark.parametrize( + ("return_value", "expected_error"), + [ + (True, "expected a dictionary"), + (False, "expected a dictionary"), + (None, "expected a dictionary"), + ("some-value", "expected a dictionary"), + (["some-list"], "expected a dictionary"), + ], +) +async def test_serviceregistry_return_values_invalid( + hass: HomeAssistant, return_value: Any, expected_error: str +) -> None: + """Test service call return values are not returned when there is no result schema.""" + + def service_handler(call: ServiceCall) -> ServiceResult: + """Service handler coroutine.""" + assert call.return_values + return return_value + + hass.services.async_register( + "test_domain", + "test_service", + service_handler, + ) + with pytest.raises(HomeAssistantError, match=expected_error): + await hass.services.async_call( + "test_domain", + "test_service", + service_data={}, + blocking=True, + return_values=True, + ) + await hass.async_block_till_done() + + +async def test_serviceregistry_no_return_values(hass: HomeAssistant) -> None: + """Test service call data when not asked for return values.""" + + def service_handler(call: ServiceCall) -> None: + """Service handler coroutine.""" + assert not call.return_values + return + + hass.services.async_register( + "test_domain", + "test_service", + service_handler, + ) + result = await hass.services.async_call( + "test_domain", + "test_service", + service_data={}, + blocking=True, + ) + await hass.async_block_till_done() + assert not result + + async def test_config_defaults() -> None: """Test config defaults.""" hass = Mock()