From 06c9719cd6bbdadd4e591bbce385927702128a63 Mon Sep 17 00:00:00 2001 From: Kevin Stillhammer Date: Fri, 3 Nov 2023 02:37:35 +0100 Subject: [PATCH] Support multiple responses for service calls (#96370) * add supports_response to platform entity services * support multiple entities in entity_service_call * support legacy response format for service calls * revert changes to script/shell_command * add back test for multiple responses for legacy service * remove SupportsResponse.ONLY_LEGACY * Apply suggestion Co-authored-by: Allen Porter * test for entity_id remove None * revert Apply suggestion * return EntityServiceResponse from _handle_entity_call * Use asyncio.gather * EntityServiceResponse not Optional * styling --------- Co-authored-by: Allen Porter --- homeassistant/components/calendar/__init__.py | 2 +- homeassistant/components/weather/__init__.py | 2 +- homeassistant/core.py | 9 +- homeassistant/helpers/entity_component.py | 39 +++++- homeassistant/helpers/entity_platform.py | 9 +- homeassistant/helpers/service.py | 42 +++--- tests/helpers/test_entity_component.py | 87 +++++++++++- tests/helpers/test_entity_platform.py | 124 +++++++++++++++++- 8 files changed, 277 insertions(+), 37 deletions(-) diff --git a/homeassistant/components/calendar/__init__.py b/homeassistant/components/calendar/__init__.py index 65a61e71d3a..2be0bd9a04b 100644 --- a/homeassistant/components/calendar/__init__.py +++ b/homeassistant/components/calendar/__init__.py @@ -300,7 +300,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async_create_event, required_features=[CalendarEntityFeature.CREATE_EVENT], ) - component.async_register_entity_service( + component.async_register_legacy_entity_service( SERVICE_LIST_EVENTS, SERVICE_LIST_EVENTS_SCHEMA, async_list_events_service, diff --git a/homeassistant/components/weather/__init__.py b/homeassistant/components/weather/__init__.py index 648201f16d2..d04daf2b160 100644 --- a/homeassistant/components/weather/__init__.py +++ b/homeassistant/components/weather/__init__.py @@ -210,7 +210,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: component = hass.data[DOMAIN] = EntityComponent[WeatherEntity]( _LOGGER, DOMAIN, hass, SCAN_INTERVAL ) - component.async_register_entity_service( + component.async_register_legacy_entity_service( SERVICE_GET_FORECAST, {vol.Required("type"): vol.In(("daily", "hourly", "twice_daily"))}, async_get_forecast_service, diff --git a/homeassistant/core.py b/homeassistant/core.py index 01a3dd7fbe6..40e9da376d5 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -134,6 +134,7 @@ DOMAIN = "homeassistant" BLOCK_LOG_TIMEOUT = 60 ServiceResponse = JsonObjectType | None +EntityServiceResponse = dict[str, ServiceResponse] class ConfigSource(enum.StrEnum): @@ -1773,7 +1774,10 @@ class Service: def __init__( self, - func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResponse] | None], + func: Callable[ + [ServiceCall], + Coroutine[Any, Any, ServiceResponse | EntityServiceResponse] | None, + ], schema: vol.Schema | None, domain: str, service: str, @@ -1882,7 +1886,8 @@ class ServiceRegistry: domain: str, service: str, service_func: Callable[ - [ServiceCall], Coroutine[Any, Any, ServiceResponse] | None + [ServiceCall], + Coroutine[Any, Any, ServiceResponse | EntityServiceResponse] | None, ], schema: vol.Schema | None = None, supports_response: SupportsResponse = SupportsResponse.NONE, diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index af1b87ec0fa..ddd46759259 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -20,6 +20,7 @@ from homeassistant.const import ( EVENT_HOMEASSISTANT_STOP, ) from homeassistant.core import ( + EntityServiceResponse, Event, HomeAssistant, ServiceCall, @@ -217,6 +218,40 @@ class EntityComponent(Generic[_EntityT]): self.hass, self.entities, service_call, expand_group ) + @callback + def async_register_legacy_entity_service( + self, + name: str, + schema: dict[str | vol.Marker, Any] | vol.Schema, + func: str | Callable[..., Any], + required_features: list[int] | None = None, + supports_response: SupportsResponse = SupportsResponse.NONE, + ) -> None: + """Register an entity service with a legacy response format.""" + if isinstance(schema, dict): + schema = cv.make_entity_service_schema(schema) + + async def handle_service( + call: ServiceCall, + ) -> ServiceResponse: + """Handle the service.""" + + result = await service.entity_service_call( + self.hass, self._platforms.values(), func, call, required_features + ) + + if result: + if len(result) > 1: + raise HomeAssistantError( + "Deprecated service call matched more than one entity" + ) + return result.popitem()[1] + return None + + self.hass.services.async_register( + self.domain, name, handle_service, schema, supports_response + ) + @callback def async_register_entity_service( self, @@ -230,7 +265,9 @@ class EntityComponent(Generic[_EntityT]): if isinstance(schema, dict): schema = cv.make_entity_service_schema(schema) - async def handle_service(call: ServiceCall) -> ServiceResponse: + async def handle_service( + call: ServiceCall, + ) -> EntityServiceResponse | None: """Handle the service.""" return await service.entity_service_call( self.hass, self._platforms.values(), func, call, required_features diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index c164e3b1052..388c00bd177 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -20,8 +20,10 @@ from homeassistant.core import ( CALLBACK_TYPE, DOMAIN as HOMEASSISTANT_DOMAIN, CoreState, + EntityServiceResponse, HomeAssistant, ServiceCall, + SupportsResponse, callback, split_entity_id, valid_entity_id, @@ -814,6 +816,7 @@ class EntityPlatform: schema: dict[str, Any] | vol.Schema, func: str | Callable[..., Any], required_features: Iterable[int] | None = None, + supports_response: SupportsResponse = SupportsResponse.NONE, ) -> None: """Register an entity service. @@ -825,9 +828,9 @@ class EntityPlatform: if isinstance(schema, dict): schema = cv.make_entity_service_schema(schema) - async def handle_service(call: ServiceCall) -> None: + async def handle_service(call: ServiceCall) -> EntityServiceResponse | None: """Handle the service.""" - await service.entity_service_call( + return await service.entity_service_call( self.hass, [ plf @@ -840,7 +843,7 @@ class EntityPlatform: ) self.hass.services.async_register( - self.platform_name, name, handle_service, schema + self.platform_name, name, handle_service, schema, supports_response ) async def _update_entity_states(self, now: datetime) -> None: diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 4532e1a00ae..4cb8852414b 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -28,6 +28,7 @@ from homeassistant.const import ( ) from homeassistant.core import ( Context, + EntityServiceResponse, HomeAssistant, ServiceCall, ServiceResponse, @@ -790,7 +791,7 @@ async def entity_service_call( func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]], call: ServiceCall, required_features: Iterable[int] | None = None, -) -> ServiceResponse | None: +) -> EntityServiceResponse | None: """Handle an entity service call. Calls all platforms simultaneously. @@ -870,10 +871,9 @@ async def entity_service_call( return None if len(entities) == 1: - # Single entity case avoids creating tasks and allows returning - # ServiceResponse + # Single entity case avoids creating task entity = entities[0] - response_data = await _handle_entity_call( + single_response = await _handle_entity_call( hass, entity, func, data, call.context ) if entity.should_poll: @@ -881,27 +881,25 @@ async def entity_service_call( # Set context again so it's there when we update entity.async_set_context(call.context) await entity.async_update_ha_state(True) - return response_data if return_response else None + return {entity.entity_id: single_response} if return_response else None - if return_response: - raise HomeAssistantError( - "Service call requested response data but matched more than one entity" - ) - - done, pending = await asyncio.wait( - [ - asyncio.create_task( - entity.async_request_call( - _handle_entity_call(hass, entity, func, data, call.context) - ) + # Use asyncio.gather here to ensure the returned results + # are in the same order as the entities list + results: list[ServiceResponse] = await asyncio.gather( + *[ + entity.async_request_call( + _handle_entity_call(hass, entity, func, data, call.context) ) for entity in entities - ] + ], + return_exceptions=True, ) - assert not pending - for task in done: - task.result() # pop exception if have + response_data: EntityServiceResponse = {} + for entity, result in zip(entities, results): + if isinstance(result, Exception): + raise result + response_data[entity.entity_id] = result tasks: list[asyncio.Task[None]] = [] @@ -920,7 +918,7 @@ async def entity_service_call( for future in done: future.result() # pop exception if have - return None + return response_data if return_response and response_data else None async def _handle_entity_call( @@ -943,7 +941,7 @@ async def _handle_entity_call( # Guard because callback functions do not return a task when passed to # async_run_job. - result: ServiceResponse | None = None + result: ServiceResponse = None if task is not None: result = await task diff --git a/tests/helpers/test_entity_component.py b/tests/helpers/test_entity_component.py index 4119ccc6e85..b5cda6770c5 100644 --- a/tests/helpers/test_entity_component.py +++ b/tests/helpers/test_entity_component.py @@ -531,7 +531,7 @@ async def test_register_entity_service(hass: HomeAssistant) -> None: async def test_register_entity_service_response_data(hass: HomeAssistant) -> None: - """Test an enttiy service that does not support response data.""" + """Test an entity service that does support response data.""" entity = MockEntity(entity_id=f"{DOMAIN}.entity") async def generate_response( @@ -554,24 +554,25 @@ async def test_register_entity_service_response_data(hass: HomeAssistant) -> Non response_data = await hass.services.async_call( DOMAIN, "hello", - service_data={"entity_id": entity.entity_id, "some": "data"}, + service_data={"some": "data"}, + target={"entity_id": [entity.entity_id]}, blocking=True, return_response=True, ) - assert response_data == {"response-key": "response-value"} + assert response_data == {f"{DOMAIN}.entity": {"response-key": "response-value"}} async def test_register_entity_service_response_data_multiple_matches( hass: HomeAssistant, ) -> None: - """Test asking for service response data but matching many entities.""" + """Test asking for service response data and matching many entities.""" entity1 = MockEntity(entity_id=f"{DOMAIN}.entity1") entity2 = MockEntity(entity_id=f"{DOMAIN}.entity2") async def generate_response( target: MockEntity, call: ServiceCall ) -> ServiceResponse: - raise ValueError("Should not be invoked") + return {"response-key": f"response-value-{target.entity_id}"} component = EntityComponent(_LOGGER, DOMAIN, hass) await component.async_setup({}) @@ -579,7 +580,80 @@ async def test_register_entity_service_response_data_multiple_matches( component.async_register_entity_service( "hello", - {}, + {"some": str}, + generate_response, + supports_response=SupportsResponse.ONLY, + ) + + response_data = await hass.services.async_call( + DOMAIN, + "hello", + service_data={"some": "data"}, + target={"entity_id": [entity1.entity_id, entity2.entity_id]}, + blocking=True, + return_response=True, + ) + assert response_data == { + f"{DOMAIN}.entity1": {"response-key": f"response-value-{DOMAIN}.entity1"}, + f"{DOMAIN}.entity2": {"response-key": f"response-value-{DOMAIN}.entity2"}, + } + + +async def test_register_entity_service_response_data_multiple_matches_raises( + hass: HomeAssistant, +) -> None: + """Test asking for service response data and matching many entities raises exceptions.""" + entity1 = MockEntity(entity_id=f"{DOMAIN}.entity1") + entity2 = MockEntity(entity_id=f"{DOMAIN}.entity2") + + async def generate_response( + target: MockEntity, call: ServiceCall + ) -> ServiceResponse: + if target.entity_id == f"{DOMAIN}.entity1": + raise RuntimeError("Something went wrong") + return {"response-key": f"response-value-{target.entity_id}"} + + component = EntityComponent(_LOGGER, DOMAIN, hass) + await component.async_setup({}) + await component.async_add_entities([entity1, entity2]) + + component.async_register_entity_service( + "hello", + {"some": str}, + generate_response, + supports_response=SupportsResponse.ONLY, + ) + + with pytest.raises(RuntimeError, match="Something went wrong"): + await hass.services.async_call( + DOMAIN, + "hello", + service_data={"some": "data"}, + target={"entity_id": [entity1.entity_id, entity2.entity_id]}, + blocking=True, + return_response=True, + ) + + +async def test_legacy_register_entity_service_response_data_multiple_matches( + hass: HomeAssistant, +) -> None: + """Test asking for legacy service response data but matching many entities.""" + entity1 = MockEntity(entity_id=f"{DOMAIN}.entity1") + entity2 = MockEntity(entity_id=f"{DOMAIN}.entity2") + + async def generate_response( + target: MockEntity, call: ServiceCall + ) -> ServiceResponse: + return {"response-key": "response-value"} + + component = EntityComponent(_LOGGER, DOMAIN, hass) + await component.async_setup({}) + await component.async_add_entities([entity1, entity2]) + + component.async_register_legacy_entity_service( + "hello", + {"some": str}, generate_response, supports_response=SupportsResponse.ONLY, ) @@ -588,6 +662,7 @@ async def test_register_entity_service_response_data_multiple_matches( await hass.services.async_call( DOMAIN, "hello", + service_data={"some": "data"}, target={"entity_id": [entity1.entity_id, entity2.entity_id]}, blocking=True, return_response=True, diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index 0bbfedb8926..7ccbd5e0f28 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -9,7 +9,14 @@ from unittest.mock import ANY, Mock, patch import pytest from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, PERCENTAGE -from homeassistant.core import CoreState, HomeAssistant, callback +from homeassistant.core import ( + CoreState, + HomeAssistant, + ServiceCall, + ServiceResponse, + SupportsResponse, + callback, +) from homeassistant.exceptions import HomeAssistantError, PlatformNotReady from homeassistant.helpers import ( device_registry as dr, @@ -1491,6 +1498,121 @@ async def test_platforms_sharing_services(hass: HomeAssistant) -> None: assert entity2 in entities +async def test_register_entity_service_response_data(hass: HomeAssistant) -> None: + """Test an entity service that does supports response data.""" + + async def generate_response( + target: MockEntity, call: ServiceCall + ) -> ServiceResponse: + assert call.return_response + return {"response-key": "response-value"} + + entity_platform = MockEntityPlatform( + hass, domain="mock_integration", platform_name="mock_platform", platform=None + ) + entity = MockEntity(entity_id="mock_integration.entity") + await entity_platform.async_add_entities([entity]) + + entity_platform.async_register_entity_service( + "hello", + {"some": str}, + generate_response, + supports_response=SupportsResponse.ONLY, + ) + + response_data = await hass.services.async_call( + "mock_platform", + "hello", + service_data={"some": "data"}, + target={"entity_id": [entity.entity_id]}, + blocking=True, + return_response=True, + ) + assert response_data == { + "mock_integration.entity": {"response-key": "response-value"} + } + + +async def test_register_entity_service_response_data_multiple_matches( + hass: HomeAssistant, +) -> None: + """Test an entity service that does supports response data and matching many entities.""" + + async def generate_response( + target: MockEntity, call: ServiceCall + ) -> ServiceResponse: + assert call.return_response + return {"response-key": f"response-value-{target.entity_id}"} + + entity_platform = MockEntityPlatform( + hass, domain="mock_integration", platform_name="mock_platform", platform=None + ) + entity1 = MockEntity(entity_id="mock_integration.entity1") + entity2 = MockEntity(entity_id="mock_integration.entity2") + await entity_platform.async_add_entities([entity1, entity2]) + + entity_platform.async_register_entity_service( + "hello", + {"some": str}, + generate_response, + supports_response=SupportsResponse.ONLY, + ) + + response_data = await hass.services.async_call( + "mock_platform", + "hello", + service_data={"some": "data"}, + target={"entity_id": [entity1.entity_id, entity2.entity_id]}, + blocking=True, + return_response=True, + ) + assert response_data == { + "mock_integration.entity1": { + "response-key": "response-value-mock_integration.entity1" + }, + "mock_integration.entity2": { + "response-key": "response-value-mock_integration.entity2" + }, + } + + +async def test_register_entity_service_response_data_multiple_matches_raises( + hass: HomeAssistant, +) -> None: + """Test entity service response matching many entities raises.""" + + async def generate_response( + target: MockEntity, call: ServiceCall + ) -> ServiceResponse: + assert call.return_response + if target.entity_id == "mock_integration.entity1": + raise RuntimeError("Something went wrong") + return {"response-key": f"response-value-{target.entity_id}"} + + entity_platform = MockEntityPlatform( + hass, domain="mock_integration", platform_name="mock_platform", platform=None + ) + entity1 = MockEntity(entity_id="mock_integration.entity1") + entity2 = MockEntity(entity_id="mock_integration.entity2") + await entity_platform.async_add_entities([entity1, entity2]) + + entity_platform.async_register_entity_service( + "hello", + {"some": str}, + generate_response, + supports_response=SupportsResponse.ONLY, + ) + with pytest.raises(RuntimeError, match="Something went wrong"): + await hass.services.async_call( + "mock_platform", + "hello", + service_data={"some": "data"}, + target={"entity_id": [entity1.entity_id, entity2.entity_id]}, + blocking=True, + return_response=True, + ) + + async def test_invalid_entity_id(hass: HomeAssistant) -> None: """Test specifying an invalid entity id.""" platform = MockEntityPlatform(hass)