From c4284c07b6d32c2445713f723518411735b140b9 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Fri, 16 Jun 2023 19:59:44 -0700 Subject: [PATCH] Allow scripts to capture service response data in variables (#94757) * Allow scripts service actions to save return values * Simplify script service response data * Rename result_variable to response_variable based on feedback --- homeassistant/const.py | 1 + homeassistant/helpers/config_validation.py | 2 + homeassistant/helpers/script.py | 23 ++++--- tests/helpers/test_script.py | 75 ++++++++++++++++++++++ 4 files changed, 93 insertions(+), 8 deletions(-) diff --git a/homeassistant/const.py b/homeassistant/const.py index 5d4b0c2b515..94c932b1fd1 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -222,6 +222,7 @@ CONF_REPEAT: Final = "repeat" CONF_RESOURCE: Final = "resource" CONF_RESOURCES: Final = "resources" CONF_RESOURCE_TEMPLATE: Final = "resource_template" +CONF_RESPONSE_VARIABLE: Final = "response_variable" CONF_RGB: Final = "rgb" CONF_ROOM: Final = "room" CONF_SCAN_INTERVAL: Final = "scan_interval" diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 27e4bc2c41f..db6a2fc5a8d 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -59,6 +59,7 @@ from homeassistant.const import ( CONF_PARALLEL, CONF_PLATFORM, CONF_REPEAT, + CONF_RESPONSE_VARIABLE, CONF_SCAN_INTERVAL, CONF_SCENE, CONF_SEQUENCE, @@ -1265,6 +1266,7 @@ SERVICE_SCHEMA = vol.All( ), vol.Optional(CONF_ENTITY_ID): comp_entity_ids, vol.Optional(CONF_TARGET): vol.Any(TARGET_SERVICE_FIELDS, dynamic_template), + vol.Optional(CONF_RESPONSE_VARIABLE): str, # The frontend stores data here. Don't use in core. vol.Remove("metadata"): dict, } diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 3bbc4ddd4ea..b876affb9e6 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -11,7 +11,7 @@ from functools import partial import itertools import logging from types import MappingProxyType -from typing import Any, TypedDict, cast +from typing import Any, TypedDict, TypeVar, cast import async_timeout import voluptuous as vol @@ -46,6 +46,7 @@ from homeassistant.const import ( CONF_MODE, CONF_PARALLEL, CONF_REPEAT, + CONF_RESPONSE_VARIABLE, CONF_SCENE, CONF_SEQUENCE, CONF_SERVICE, @@ -99,6 +100,8 @@ from .typing import ConfigType # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs +_T = TypeVar("_T") + SCRIPT_MODE_PARALLEL = "parallel" SCRIPT_MODE_QUEUED = "queued" SCRIPT_MODE_RESTART = "restart" @@ -617,7 +620,7 @@ class _ScriptRun: task.cancel() unsub() - async def _async_run_long_action(self, long_task: asyncio.Task) -> None: + async def _async_run_long_action(self, long_task: asyncio.Task[_T]) -> _T | None: """Run a long task while monitoring for stop request.""" async def async_cancel_long_task() -> None: @@ -645,10 +648,10 @@ class _ScriptRun: raise asyncio.CancelledError if long_task.done(): # Propagate any exceptions that occurred. - long_task.result() - else: - # Stopped before long task completed, so cancel it. - await async_cancel_long_task() + return long_task.result() + # Stopped before long task completed, so cancel it. + await async_cancel_long_task() + return None async def _async_call_service_step(self): """Call the service specified in the action.""" @@ -663,16 +666,20 @@ class _ScriptRun: and params[CONF_SERVICE] == "trigger" or params[CONF_DOMAIN] in ("python_script", "script") ) + response_variable = self._action.get(CONF_RESPONSE_VARIABLE) trace_set_result(params=params, running_script=running_script) - await self._async_run_long_action( + response_data = await self._async_run_long_action( self._hass.async_create_task( self._hass.services.async_call( **params, blocking=True, context=self._context, + return_values=(response_variable is not None), ) - ) + ), ) + if response_variable: + self._variables[response_variable] = response_data async def _async_device_step(self): """Perform the device automation specified in the action.""" diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index de16dcac053..0868bb5a0cc 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -27,6 +27,7 @@ from homeassistant.core import ( CoreState, HomeAssistant, ServiceCall, + ServiceResult, callback, ) from homeassistant.exceptions import ConditionError, HomeAssistantError, ServiceNotFound @@ -329,6 +330,80 @@ async def test_calling_service_template(hass: HomeAssistant) -> None: ) +async def test_calling_service_return_values( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test the calling of a service with return values.""" + context = Context() + + def mock_service(call: ServiceCall) -> ServiceResult: + """Mock service call.""" + if call.return_values: + return {"data": "value-12345"} + return None + + hass.services.async_register("test", "script", mock_service) + sequence = cv.SCRIPT_SCHEMA( + [ + { + "alias": "service step1", + "service": "test.script", + # Store the result of the service call as a variable + "response_variable": "my_response", + }, + { + "alias": "service step2", + "service": "test.script", + "data_template": { + # Result of previous service call + "key": "{{ my_response.data }}" + }, + }, + ] + ) + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + await script_obj.async_run(context=context) + await hass.async_block_till_done() + + assert "Executing step service step1" in caplog.text + assert "Executing step service step2" in caplog.text + + assert_action_trace( + { + "0": [ + { + "result": { + "params": { + "domain": "test", + "service": "script", + "service_data": {}, + "target": {}, + }, + "running_script": False, + } + } + ], + "1": [ + { + "result": { + "params": { + "domain": "test", + "service": "script", + "service_data": {"key": "value-12345"}, + "target": {}, + }, + "running_script": False, + }, + "variables": { + "my_response": {"data": "value-12345"}, + }, + } + ], + } + ) + + async def test_data_template_with_templated_key(hass: HomeAssistant) -> None: """Test the calling of a service with a data_template with a templated key.""" context = Context()