Allow scripts to capture service response data in variables (#94757)

* Allow scripts service actions to save return values

* Simplify script service response data

* Rename result_variable to response_variable based on feedback
This commit is contained in:
Allen Porter 2023-06-16 19:59:44 -07:00 committed by GitHub
parent 4f669b326f
commit c4284c07b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 93 additions and 8 deletions

View File

@ -222,6 +222,7 @@ CONF_REPEAT: Final = "repeat"
CONF_RESOURCE: Final = "resource" CONF_RESOURCE: Final = "resource"
CONF_RESOURCES: Final = "resources" CONF_RESOURCES: Final = "resources"
CONF_RESOURCE_TEMPLATE: Final = "resource_template" CONF_RESOURCE_TEMPLATE: Final = "resource_template"
CONF_RESPONSE_VARIABLE: Final = "response_variable"
CONF_RGB: Final = "rgb" CONF_RGB: Final = "rgb"
CONF_ROOM: Final = "room" CONF_ROOM: Final = "room"
CONF_SCAN_INTERVAL: Final = "scan_interval" CONF_SCAN_INTERVAL: Final = "scan_interval"

View File

@ -59,6 +59,7 @@ from homeassistant.const import (
CONF_PARALLEL, CONF_PARALLEL,
CONF_PLATFORM, CONF_PLATFORM,
CONF_REPEAT, CONF_REPEAT,
CONF_RESPONSE_VARIABLE,
CONF_SCAN_INTERVAL, CONF_SCAN_INTERVAL,
CONF_SCENE, CONF_SCENE,
CONF_SEQUENCE, CONF_SEQUENCE,
@ -1265,6 +1266,7 @@ SERVICE_SCHEMA = vol.All(
), ),
vol.Optional(CONF_ENTITY_ID): comp_entity_ids, vol.Optional(CONF_ENTITY_ID): comp_entity_ids,
vol.Optional(CONF_TARGET): vol.Any(TARGET_SERVICE_FIELDS, dynamic_template), vol.Optional(CONF_TARGET): vol.Any(TARGET_SERVICE_FIELDS, dynamic_template),
vol.Optional(CONF_RESPONSE_VARIABLE): str,
# The frontend stores data here. Don't use in core. # The frontend stores data here. Don't use in core.
vol.Remove("metadata"): dict, vol.Remove("metadata"): dict,
} }

View File

@ -11,7 +11,7 @@ from functools import partial
import itertools import itertools
import logging import logging
from types import MappingProxyType from types import MappingProxyType
from typing import Any, TypedDict, cast from typing import Any, TypedDict, TypeVar, cast
import async_timeout import async_timeout
import voluptuous as vol import voluptuous as vol
@ -46,6 +46,7 @@ from homeassistant.const import (
CONF_MODE, CONF_MODE,
CONF_PARALLEL, CONF_PARALLEL,
CONF_REPEAT, CONF_REPEAT,
CONF_RESPONSE_VARIABLE,
CONF_SCENE, CONF_SCENE,
CONF_SEQUENCE, CONF_SEQUENCE,
CONF_SERVICE, CONF_SERVICE,
@ -99,6 +100,8 @@ from .typing import ConfigType
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
_T = TypeVar("_T")
SCRIPT_MODE_PARALLEL = "parallel" SCRIPT_MODE_PARALLEL = "parallel"
SCRIPT_MODE_QUEUED = "queued" SCRIPT_MODE_QUEUED = "queued"
SCRIPT_MODE_RESTART = "restart" SCRIPT_MODE_RESTART = "restart"
@ -617,7 +620,7 @@ class _ScriptRun:
task.cancel() task.cancel()
unsub() unsub()
async def _async_run_long_action(self, long_task: asyncio.Task) -> None: async def _async_run_long_action(self, long_task: asyncio.Task[_T]) -> _T | None:
"""Run a long task while monitoring for stop request.""" """Run a long task while monitoring for stop request."""
async def async_cancel_long_task() -> None: async def async_cancel_long_task() -> None:
@ -645,10 +648,10 @@ class _ScriptRun:
raise asyncio.CancelledError raise asyncio.CancelledError
if long_task.done(): if long_task.done():
# Propagate any exceptions that occurred. # Propagate any exceptions that occurred.
long_task.result() return long_task.result()
else: # Stopped before long task completed, so cancel it.
# Stopped before long task completed, so cancel it. await async_cancel_long_task()
await async_cancel_long_task() return None
async def _async_call_service_step(self): async def _async_call_service_step(self):
"""Call the service specified in the action.""" """Call the service specified in the action."""
@ -663,16 +666,20 @@ class _ScriptRun:
and params[CONF_SERVICE] == "trigger" and params[CONF_SERVICE] == "trigger"
or params[CONF_DOMAIN] in ("python_script", "script") or params[CONF_DOMAIN] in ("python_script", "script")
) )
response_variable = self._action.get(CONF_RESPONSE_VARIABLE)
trace_set_result(params=params, running_script=running_script) trace_set_result(params=params, running_script=running_script)
await self._async_run_long_action( response_data = await self._async_run_long_action(
self._hass.async_create_task( self._hass.async_create_task(
self._hass.services.async_call( self._hass.services.async_call(
**params, **params,
blocking=True, blocking=True,
context=self._context, context=self._context,
return_values=(response_variable is not None),
) )
) ),
) )
if response_variable:
self._variables[response_variable] = response_data
async def _async_device_step(self): async def _async_device_step(self):
"""Perform the device automation specified in the action.""" """Perform the device automation specified in the action."""

View File

@ -27,6 +27,7 @@ from homeassistant.core import (
CoreState, CoreState,
HomeAssistant, HomeAssistant,
ServiceCall, ServiceCall,
ServiceResult,
callback, callback,
) )
from homeassistant.exceptions import ConditionError, HomeAssistantError, ServiceNotFound from homeassistant.exceptions import ConditionError, HomeAssistantError, ServiceNotFound
@ -329,6 +330,80 @@ async def test_calling_service_template(hass: HomeAssistant) -> None:
) )
async def test_calling_service_return_values(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test the calling of a service with return values."""
context = Context()
def mock_service(call: ServiceCall) -> ServiceResult:
"""Mock service call."""
if call.return_values:
return {"data": "value-12345"}
return None
hass.services.async_register("test", "script", mock_service)
sequence = cv.SCRIPT_SCHEMA(
[
{
"alias": "service step1",
"service": "test.script",
# Store the result of the service call as a variable
"response_variable": "my_response",
},
{
"alias": "service step2",
"service": "test.script",
"data_template": {
# Result of previous service call
"key": "{{ my_response.data }}"
},
},
]
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
await script_obj.async_run(context=context)
await hass.async_block_till_done()
assert "Executing step service step1" in caplog.text
assert "Executing step service step2" in caplog.text
assert_action_trace(
{
"0": [
{
"result": {
"params": {
"domain": "test",
"service": "script",
"service_data": {},
"target": {},
},
"running_script": False,
}
}
],
"1": [
{
"result": {
"params": {
"domain": "test",
"service": "script",
"service_data": {"key": "value-12345"},
"target": {},
},
"running_script": False,
},
"variables": {
"my_response": {"data": "value-12345"},
},
}
],
}
)
async def test_data_template_with_templated_key(hass: HomeAssistant) -> None: async def test_data_template_with_templated_key(hass: HomeAssistant) -> None:
"""Test the calling of a service with a data_template with a templated key.""" """Test the calling of a service with a data_template with a templated key."""
context = Context() context = Context()