From 4b493c5ab951bf978b691fc0ce5e9bae2a739208 Mon Sep 17 00:00:00 2001 From: Bram Kragten Date: Wed, 10 Feb 2021 12:42:28 +0100 Subject: [PATCH] Add target to service call API (#45898) * Add target to service call API * Fix _async_call_service_step * CONF_SERVICE_ENTITY_ID overrules target * Move merging up before processing schema * Restore services.yaml * Add test --- homeassistant/components/api/__init__.py | 2 +- .../components/websocket_api/commands.py | 2 + homeassistant/core.py | 9 +++- homeassistant/helpers/script.py | 12 +++--- homeassistant/helpers/service.py | 32 +++++++++++---- .../components/websocket_api/test_commands.py | 41 +++++++++++++++++++ 6 files changed, 82 insertions(+), 16 deletions(-) diff --git a/homeassistant/components/api/__init__.py b/homeassistant/components/api/__init__.py index e7bac8532ee..a82309094e3 100644 --- a/homeassistant/components/api/__init__.py +++ b/homeassistant/components/api/__init__.py @@ -378,7 +378,7 @@ class APIDomainServicesView(HomeAssistantView): with AsyncTrackStates(hass) as changed_states: try: await hass.services.async_call( - domain, service, data, True, self.context(request) + domain, service, data, blocking=True, context=self.context(request) ) except (vol.Invalid, ServiceNotFound) as ex: raise HTTPBadRequest() from ex diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 77521c1ed98..ddd7548cd68 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -121,6 +121,7 @@ def handle_unsubscribe_events(hass, connection, msg): vol.Required("type"): "call_service", vol.Required("domain"): str, vol.Required("service"): str, + vol.Optional("target"): cv.ENTITY_SERVICE_FIELDS, vol.Optional("service_data"): dict, } ) @@ -139,6 +140,7 @@ async def handle_call_service(hass, connection, msg): msg.get("service_data"), blocking, context, + target=msg.get("target"), ) connection.send_message( messages.result_message(msg["id"], {"context": context}) diff --git a/homeassistant/core.py b/homeassistant/core.py index 4294eb530a7..13f8b153047 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -1358,6 +1358,7 @@ class ServiceRegistry: blocking: bool = False, context: Optional[Context] = None, limit: Optional[float] = SERVICE_CALL_LIMIT, + target: Optional[Dict] = None, ) -> Optional[bool]: """ Call a service. @@ -1365,7 +1366,9 @@ class ServiceRegistry: See description of async_call for details. """ return asyncio.run_coroutine_threadsafe( - self.async_call(domain, service, service_data, blocking, context, limit), + self.async_call( + domain, service, service_data, blocking, context, limit, target + ), self._hass.loop, ).result() @@ -1377,6 +1380,7 @@ class ServiceRegistry: blocking: bool = False, context: Optional[Context] = None, limit: Optional[float] = SERVICE_CALL_LIMIT, + target: Optional[Dict] = None, ) -> Optional[bool]: """ Call a service. @@ -1404,6 +1408,9 @@ class ServiceRegistry: except KeyError: raise ServiceNotFound(domain, service) from None + if target: + service_data.update(target) + if handler.schema: try: processed_data = handler.schema(service_data) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 56accf9cf49..a0e8311048e 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -433,14 +433,14 @@ class _ScriptRun: self._script.last_action = self._action.get(CONF_ALIAS, "call service") self._log("Executing step %s", self._script.last_action) - domain, service_name, service_data = service.async_prepare_call_from_config( + params = service.async_prepare_call_from_config( self._hass, self._action, self._variables ) running_script = ( - domain == "automation" - and service_name == "trigger" - or domain in ("python_script", "script") + params["domain"] == "automation" + and params["service_name"] == "trigger" + or params["domain"] in ("python_script", "script") ) # If this might start a script then disable the call timeout. # Otherwise use the normal service call limit. @@ -451,9 +451,7 @@ class _ScriptRun: service_task = self._hass.async_create_task( self._hass.services.async_call( - domain, - service_name, - service_data, + **params, blocking=True, context=self._context, limit=limit, diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 13dcd779b25..a13b866a418 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -14,6 +14,7 @@ from typing import ( Optional, Set, Tuple, + TypedDict, Union, cast, ) @@ -70,6 +71,15 @@ _LOGGER = logging.getLogger(__name__) SERVICE_DESCRIPTION_CACHE = "service_description_cache" +class ServiceParams(TypedDict): + """Type for service call parameters.""" + + domain: str + service: str + service_data: Dict[str, Any] + target: Optional[Dict] + + @dataclasses.dataclass class SelectedEntities: """Class to hold the selected entities.""" @@ -136,7 +146,7 @@ async def async_call_from_config( raise _LOGGER.error(ex) else: - await hass.services.async_call(*params, blocking, context) + await hass.services.async_call(**params, blocking=blocking, context=context) @ha.callback @@ -146,7 +156,7 @@ def async_prepare_call_from_config( config: ConfigType, variables: TemplateVarsType = None, validate_config: bool = False, -) -> Tuple[str, str, Dict[str, Any]]: +) -> ServiceParams: """Prepare to call a service based on a config hash.""" if validate_config: try: @@ -177,10 +187,9 @@ def async_prepare_call_from_config( domain, service = domain_service.split(".", 1) - service_data = {} + target = config.get(CONF_TARGET) - if CONF_TARGET in config: - service_data.update(config[CONF_TARGET]) + service_data = {} for conf in [CONF_SERVICE_DATA, CONF_SERVICE_DATA_TEMPLATE]: if conf not in config: @@ -192,9 +201,17 @@ def async_prepare_call_from_config( raise HomeAssistantError(f"Error rendering data template: {ex}") from ex if CONF_SERVICE_ENTITY_ID in config: - service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID] + if target: + target[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID] + else: + target = {ATTR_ENTITY_ID: config[CONF_SERVICE_ENTITY_ID]} - return domain, service, service_data + return { + "domain": domain, + "service": service, + "service_data": service_data, + "target": target, + } @bind_hass @@ -431,6 +448,7 @@ async def async_get_all_descriptions( description = descriptions_cache[cache_key] = { "description": yaml_description.get("description", ""), + "target": yaml_description.get("target"), "fields": yaml_description.get("fields", {}), } diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index a7aa17db6d3..1f7abc42c4e 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -52,6 +52,47 @@ async def test_call_service(hass, websocket_client): assert call.data == {"hello": "world"} +async def test_call_service_target(hass, websocket_client): + """Test call service command with target.""" + calls = [] + + @callback + def service_call(call): + calls.append(call) + + hass.services.async_register("domain_test", "test_service", service_call) + + await websocket_client.send_json( + { + "id": 5, + "type": "call_service", + "domain": "domain_test", + "service": "test_service", + "service_data": {"hello": "world"}, + "target": { + "entity_id": ["entity.one", "entity.two"], + "device_id": "deviceid", + }, + } + ) + + msg = await websocket_client.receive_json() + assert msg["id"] == 5 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + + assert len(calls) == 1 + call = calls[0] + + assert call.domain == "domain_test" + assert call.service == "test_service" + assert call.data == { + "hello": "world", + "entity_id": ["entity.one", "entity.two"], + "device_id": ["deviceid"], + } + + async def test_call_service_not_found(hass, websocket_client): """Test call service command.""" await websocket_client.send_json(