diff --git a/homeassistant/core.py b/homeassistant/core.py index 6405b0860e1..ad5fb44a514 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -131,7 +131,7 @@ DOMAIN = "homeassistant" # How long to wait to log tasks that are blocking BLOCK_LOG_TIMEOUT = 60 -ServiceResult = JsonObjectType | None +ServiceResponse = JsonObjectType | None class ConfigSource(StrEnum): @@ -1655,28 +1655,43 @@ class StateMachine: ) +class SupportsResponse(StrEnum): + """Service call response configuration.""" + + NONE = "none" + """The service does not support responses (the default).""" + + OPTIONAL = "optional" + """The service optionally returns response data when asked by the caller.""" + + ONLY = "only" + """The service is read-only and the caller must always ask for response data.""" + + class Service: """Representation of a callable service.""" - __slots__ = ["job", "schema", "domain", "service"] + __slots__ = ["job", "schema", "domain", "service", "supports_response"] def __init__( self, - func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResult] | None], + func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResponse] | None], schema: vol.Schema | None, domain: str, service: str, context: Context | None = None, + supports_response: SupportsResponse = SupportsResponse.NONE, ) -> None: """Initialize a service.""" self.job = HassJob(func, f"service {domain}.{service}") self.schema = schema + self.supports_response = supports_response class ServiceCall: """Representation of a call to a service.""" - __slots__ = ["domain", "service", "data", "context", "return_values"] + __slots__ = ["domain", "service", "data", "context", "return_response"] def __init__( self, @@ -1684,14 +1699,14 @@ class ServiceCall: service: str, data: dict[str, Any] | None = None, context: Context | None = None, - return_values: bool = False, + return_response: bool = False, ) -> None: """Initialize a service call.""" self.domain = domain.lower() self.service = service.lower() self.data = ReadOnlyDict(data or {}) self.context = context or Context() - self.return_values = return_values + self.return_response = return_response def __repr__(self) -> str: """Return the representation of the service.""" @@ -1738,7 +1753,7 @@ class ServiceRegistry: service: str, service_func: Callable[ [ServiceCall], - Coroutine[Any, Any, ServiceResult] | None, + Coroutine[Any, Any, ServiceResponse] | None, ], schema: vol.Schema | None = None, ) -> None: @@ -1756,9 +1771,10 @@ class ServiceRegistry: domain: str, service: str, service_func: Callable[ - [ServiceCall], Coroutine[Any, Any, ServiceResult] | None + [ServiceCall], Coroutine[Any, Any, ServiceResponse] | None ], schema: vol.Schema | None = None, + supports_response: SupportsResponse = SupportsResponse.NONE, ) -> None: """Register a service. @@ -1768,7 +1784,9 @@ class ServiceRegistry: """ domain = domain.lower() service = service.lower() - service_obj = Service(service_func, schema, domain, service) + service_obj = Service( + service_func, schema, domain, service, supports_response=supports_response + ) if domain in self._services: self._services[domain][service] = service_obj @@ -1815,8 +1833,8 @@ class ServiceRegistry: blocking: bool = False, context: Context | None = None, target: dict[str, Any] | None = None, - return_values: bool = False, - ) -> ServiceResult: + return_response: bool = False, + ) -> ServiceResponse: """Call a service. See description of async_call for details. @@ -1829,7 +1847,7 @@ class ServiceRegistry: blocking, context, target, - return_values, + return_response, ), self._hass.loop, ).result() @@ -1842,13 +1860,13 @@ class ServiceRegistry: blocking: bool = False, context: Context | None = None, target: dict[str, Any] | None = None, - return_values: bool = False, - ) -> ServiceResult: + return_response: bool = False, + ) -> ServiceResponse: """Call a service. Specify blocking=True to wait until service is executed. - If return_values=True, indicates that the caller can consume return values + If return_response=True, indicates that the caller can consume return values from the service, if any. Return values are a dict that can be returned by the standard JSON serialization process. Return values can only be used with blocking=True. @@ -1864,14 +1882,25 @@ class ServiceRegistry: context = context or Context() service_data = service_data or {} - if return_values and not blocking: - raise ValueError("Invalid argument return_values=True when blocking=False") - try: handler = self._services[domain][service] except KeyError: raise ServiceNotFound(domain, service) from None + if return_response: + if not blocking: + raise ValueError( + "Invalid argument return_response=True when blocking=False" + ) + if handler.supports_response == SupportsResponse.NONE: + raise ValueError( + "Invalid argument return_response=True when handler does not support responses" + ) + elif handler.supports_response == SupportsResponse.ONLY: + raise ValueError( + "Service call requires responses but caller did not ask for responses" + ) + if target: service_data.update(target) @@ -1890,7 +1919,7 @@ class ServiceRegistry: processed_data = service_data service_call = ServiceCall( - domain, service, processed_data, context, return_values + domain, service, processed_data, context, return_response ) self._hass.bus.async_fire( @@ -1909,7 +1938,7 @@ class ServiceRegistry: return None response_data = await coro - if not return_values: + if not return_response: return None if not isinstance(response_data, dict): raise HomeAssistantError( @@ -1945,19 +1974,19 @@ class ServiceRegistry: async def _execute_service( self, handler: Service, service_call: ServiceCall - ) -> ServiceResult: + ) -> ServiceResponse: """Execute a service.""" if handler.job.job_type == HassJobType.Coroutinefunction: return await cast( - Callable[[ServiceCall], Awaitable[ServiceResult]], + Callable[[ServiceCall], Awaitable[ServiceResponse]], handler.job.target, )(service_call) if handler.job.job_type == HassJobType.Callback: - return cast(Callable[[ServiceCall], ServiceResult], handler.job.target)( + return cast(Callable[[ServiceCall], ServiceResponse], handler.job.target)( service_call ) return await self._hass.async_add_executor_job( - cast(Callable[[ServiceCall], ServiceResult], handler.job.target), + cast(Callable[[ServiceCall], ServiceResponse], handler.job.target), service_call, ) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index b876affb9e6..ee4346ff388 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -674,7 +674,7 @@ class _ScriptRun: **params, blocking=True, context=self._context, - return_values=(response_variable is not None), + return_response=(response_variable is not None), ) ), ) diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 0868bb5a0cc..de13557024a 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -27,7 +27,7 @@ from homeassistant.core import ( CoreState, HomeAssistant, ServiceCall, - ServiceResult, + ServiceResponse, callback, ) from homeassistant.exceptions import ConditionError, HomeAssistantError, ServiceNotFound @@ -330,19 +330,19 @@ async def test_calling_service_template(hass: HomeAssistant) -> None: ) -async def test_calling_service_return_values( +async def test_calling_service_response_data( hass: HomeAssistant, caplog: pytest.LogCaptureFixture ) -> None: """Test the calling of a service with return values.""" context = Context() - def mock_service(call: ServiceCall) -> ServiceResult: + def mock_service(call: ServiceCall) -> ServiceResponse: """Mock service call.""" - if call.return_values: + if call.return_response: return {"data": "value-12345"} return None - hass.services.async_register("test", "script", mock_service) + hass.services.async_register("test", "script", mock_service, supports_response=True) sequence = cv.SCRIPT_SCHEMA( [ { diff --git a/tests/test_core.py b/tests/test_core.py index ebc5718c7cb..8b63eab7b42 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -33,7 +33,14 @@ from homeassistant.const import ( __version__, ) import homeassistant.core as ha -from homeassistant.core import HassJob, HomeAssistant, ServiceCall, ServiceResult, State +from homeassistant.core import ( + HassJob, + HomeAssistant, + ServiceCall, + ServiceResponse, + State, + SupportsResponse, +) from homeassistant.exceptions import ( HomeAssistantError, InvalidEntityFormatError, @@ -1083,58 +1090,44 @@ async def test_serviceregistry_callback_service_raise_exception( await hass.async_block_till_done() -async def test_serviceregistry_return_values(hass: HomeAssistant) -> None: - """Test service call for a service that has return values.""" +@pytest.mark.parametrize( + "supports_response", + [ + SupportsResponse.ONLY, + SupportsResponse.OPTIONAL, + ], +) +async def test_serviceregistry_async_return_response( + hass: HomeAssistant, supports_response: SupportsResponse +) -> None: + """Test service call for a service that returns response data.""" - def service_handler(call: ServiceCall) -> ServiceResult: + async def service_handler(call: ServiceCall) -> ServiceResponse: """Service handler coroutine.""" - assert call.return_values + assert call.return_response return {"test-reply": "test-value1"} hass.services.async_register( "test_domain", "test_service", service_handler, + supports_response=supports_response, ) result = await hass.services.async_call( "test_domain", "test_service", service_data={}, blocking=True, - return_values=True, + return_response=True, ) await hass.async_block_till_done() assert result == {"test-reply": "test-value1"} -async def test_serviceregistry_async_return_values(hass: HomeAssistant) -> None: - """Test service call for an async service that has return values.""" - - async def service_handler(call: ServiceCall) -> ServiceResult: - """Service handler coroutine.""" - assert call.return_values - return {"test-reply": "test-value1"} - - hass.services.async_register( - "test_domain", - "test_service", - service_handler, - ) - result = await hass.services.async_call( - "test_domain", - "test_service", - service_data={}, - blocking=True, - return_values=True, - ) - await hass.async_block_till_done() - assert result == {"test-reply": "test-value1"} - - -async def test_services_call_return_values_requires_blocking( +async def test_services_call_return_response_requires_blocking( hass: HomeAssistant, ) -> None: - """Test that non-blocking service calls cannot return values.""" + """Test that non-blocking service calls cannot ask for response data.""" async_mock_service(hass, "test_domain", "test_service") with pytest.raises(ValueError, match="when blocking=False"): await hass.services.async_call( @@ -1142,12 +1135,12 @@ async def test_services_call_return_values_requires_blocking( "test_service", service_data={}, blocking=False, - return_values=True, + return_response=True, ) @pytest.mark.parametrize( - ("return_value", "expected_error"), + ("response_data", "expected_error"), [ (True, "expected a dictionary"), (False, "expected a dictionary"), @@ -1156,20 +1149,21 @@ async def test_services_call_return_values_requires_blocking( (["some-list"], "expected a dictionary"), ], ) -async def test_serviceregistry_return_values_invalid( - hass: HomeAssistant, return_value: Any, expected_error: str +async def test_serviceregistry_return_response_invalid( + hass: HomeAssistant, response_data: Any, expected_error: str ) -> None: - """Test service call return values are not returned when there is no result schema.""" + """Test service call response data must be json serializable objects.""" - def service_handler(call: ServiceCall) -> ServiceResult: + def service_handler(call: ServiceCall) -> ServiceResponse: """Service handler coroutine.""" - assert call.return_values - return return_value + assert call.return_response + return response_data hass.services.async_register( "test_domain", "test_service", service_handler, + supports_response=SupportsResponse.ONLY, ) with pytest.raises(HomeAssistantError, match=expected_error): await hass.services.async_call( @@ -1177,32 +1171,78 @@ async def test_serviceregistry_return_values_invalid( "test_service", service_data={}, blocking=True, - return_values=True, + return_response=True, ) await hass.async_block_till_done() -async def test_serviceregistry_no_return_values(hass: HomeAssistant) -> None: - """Test service call data when not asked for return values.""" +@pytest.mark.parametrize( + ("supports_response", "return_response", "expected_error"), + [ + (SupportsResponse.NONE, True, "not support responses"), + (SupportsResponse.ONLY, False, "caller did not ask for responses"), + ], +) +async def test_serviceregistry_return_response_arguments( + hass: HomeAssistant, + supports_response: SupportsResponse, + return_response: bool, + expected_error: str, +) -> None: + """Test service call response data invalid arguments.""" - def service_handler(call: ServiceCall) -> None: + hass.services.async_register( + "test_domain", + "test_service", + "service_handler", + supports_response=supports_response, + ) + + with pytest.raises(ValueError, match=expected_error): + await hass.services.async_call( + "test_domain", + "test_service", + service_data={}, + blocking=True, + return_response=return_response, + ) + + +@pytest.mark.parametrize( + ("return_response", "expected_response_data"), + [ + (True, {"key": "value"}), + (False, None), + ], +) +async def test_serviceregistry_return_response_optional( + hass: HomeAssistant, + return_response: bool, + expected_response_data: Any, +) -> None: + """Test optional service call response data.""" + + def service_handler(call: ServiceCall) -> ServiceResponse: """Service handler coroutine.""" - assert not call.return_values - return + if call.return_response: + return {"key": "value"} + return None hass.services.async_register( "test_domain", "test_service", service_handler, + supports_response=SupportsResponse.OPTIONAL, ) - result = await hass.services.async_call( + response_data = await hass.services.async_call( "test_domain", "test_service", service_data={}, blocking=True, + return_response=return_response, ) await hass.async_block_till_done() - assert not result + assert response_data == expected_response_data async def test_config_defaults() -> None: