From d704d4f85357dcd511d67d7f822baaf72eb20480 Mon Sep 17 00:00:00 2001 From: Franck Nijhof Date: Wed, 13 Apr 2022 22:07:44 +0200 Subject: [PATCH] Add parallel automation/script actions (#69903) --- homeassistant/const.py | 1 + homeassistant/helpers/config_validation.py | 32 ++++ homeassistant/helpers/script.py | 76 +++++++++ tests/helpers/test_script.py | 173 +++++++++++++++++++++ 4 files changed, 282 insertions(+) diff --git a/homeassistant/const.py b/homeassistant/const.py index 6f0ba1d6ec6..6897a9c9b9f 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -189,6 +189,7 @@ CONF_NAME: Final = "name" CONF_OFFSET: Final = "offset" CONF_OPTIMISTIC: Final = "optimistic" CONF_PACKAGES: Final = "packages" +CONF_PARALLEL: Final = "parallel" CONF_PARAMS: Final = "params" CONF_PASSWORD: Final = "password" CONF_PATH: Final = "path" diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index ea8264dad78..3ad178483fd 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -53,6 +53,7 @@ from homeassistant.const import ( CONF_ID, CONF_IF, CONF_MATCH, + CONF_PARALLEL, CONF_PLATFORM, CONF_REPEAT, CONF_SCAN_INTERVAL, @@ -1455,6 +1456,32 @@ _SCRIPT_ERROR_SCHEMA = vol.Schema( } ) + +_SCRIPT_PARALLEL_SEQUENCE = vol.Schema( + { + **SCRIPT_ACTION_BASE_SCHEMA, + vol.Required(CONF_SEQUENCE): SCRIPT_SCHEMA, + } +) + +_parallel_sequence_action = vol.All( + # Wrap a shorthand sequences in a parallel action + SCRIPT_SCHEMA, + lambda config: { + CONF_SEQUENCE: config, + }, +) + +_SCRIPT_PARALLEL_SCHEMA = vol.Schema( + { + **SCRIPT_ACTION_BASE_SCHEMA, + vol.Required(CONF_PARALLEL): vol.All( + ensure_list, [vol.Any(_SCRIPT_PARALLEL_SEQUENCE, _parallel_sequence_action)] + ), + } +) + + SCRIPT_ACTION_DELAY = "delay" SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template" SCRIPT_ACTION_CHECK_CONDITION = "condition" @@ -1469,6 +1496,7 @@ SCRIPT_ACTION_VARIABLES = "variables" SCRIPT_ACTION_STOP = "stop" SCRIPT_ACTION_ERROR = "error" SCRIPT_ACTION_IF = "if" +SCRIPT_ACTION_PARALLEL = "parallel" def determine_script_action(action: dict[str, Any]) -> str: @@ -1515,6 +1543,9 @@ def determine_script_action(action: dict[str, Any]) -> str: if CONF_ERROR in action: return SCRIPT_ACTION_ERROR + if CONF_PARALLEL in action: + return SCRIPT_ACTION_PARALLEL + raise ValueError("Unable to determine action") @@ -1533,6 +1564,7 @@ ACTION_TYPE_SCHEMAS: dict[str, Callable[[Any], dict]] = { SCRIPT_ACTION_STOP: _SCRIPT_STOP_SCHEMA, SCRIPT_ACTION_ERROR: _SCRIPT_ERROR_SCHEMA, SCRIPT_ACTION_IF: _SCRIPT_IF_SCHEMA, + SCRIPT_ACTION_PARALLEL: _SCRIPT_PARALLEL_SCHEMA, } diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index caff3d19f4b..f723b351b9c 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -5,6 +5,7 @@ import asyncio from collections.abc import Callable, Sequence from contextlib import asynccontextmanager, suppress from contextvars import ContextVar +from copy import copy from datetime import datetime, timedelta from functools import partial import itertools @@ -40,6 +41,7 @@ from homeassistant.const import ( CONF_EVENT_DATA_TEMPLATE, CONF_IF, CONF_MODE, + CONF_PARALLEL, CONF_REPEAT, CONF_SCENE, CONF_SEQUENCE, @@ -79,6 +81,7 @@ from .trace import ( trace_id_get, trace_path, trace_path_get, + trace_path_stack_cv, trace_set_result, trace_stack_cv, trace_stack_pop, @@ -307,6 +310,13 @@ async def async_validate_action_config( config[CONF_ELSE] = await async_validate_actions_config( hass, config[CONF_ELSE] ) + + elif action_type == cv.SCRIPT_ACTION_PARALLEL: + for parallel_conf in config[CONF_PARALLEL]: + parallel_conf[CONF_SEQUENCE] = await async_validate_actions_config( + hass, parallel_conf[CONF_SEQUENCE] + ) + else: raise ValueError(f"No validation for {action_type}") @@ -896,6 +906,26 @@ class _ScriptRun: trace_set_result(error=error) raise _AbortScript(error) + @async_trace_path("parallel") + async def _async_parallel_step(self) -> None: + """Run a sequence in parallel.""" + # pylint: disable=protected-access + scripts = await self._script._async_get_parallel_scripts(self._step) + + async def async_run_with_trace(idx: int, script: Script) -> None: + """Run a script with a trace path.""" + trace_path_stack_cv.set(copy(trace_path_stack_cv.get())) + with trace_path([str(idx), "sequence"]): + await self._async_run_script(script) + + results = await asyncio.gather( + *(async_run_with_trace(idx, script) for idx, script in enumerate(scripts)), + return_exceptions=True, + ) + for result in results: + if isinstance(result, Exception): + raise result + async def _async_run_script(self, script: Script) -> None: """Execute a script.""" await self._async_run_long_action( @@ -1075,6 +1105,7 @@ class Script: self._repeat_script: dict[int, Script] = {} self._choose_data: dict[int, _ChooseData] = {} self._if_data: dict[int, _IfData] = {} + self._parallel_scripts: dict[int, list[Script]] = {} self._referenced_entities: set[str] | None = None self._referenced_devices: set[str] | None = None self._referenced_areas: set[str] | None = None @@ -1109,6 +1140,9 @@ class Script: self._set_logger(logger) for script in self._repeat_script.values(): script.update_logger(self._logger) + for parallel_scripts in self._parallel_scripts.values(): + for parallel_script in parallel_scripts: + parallel_script.update_logger(self._logger) for choose_data in self._choose_data.values(): for _, script in choose_data["choices"]: script.update_logger(self._logger) @@ -1178,6 +1212,10 @@ class Script: if CONF_ELSE in step: Script._find_referenced_areas(referenced, step[CONF_ELSE]) + elif action == cv.SCRIPT_ACTION_PARALLEL: + for script in step[CONF_PARALLEL]: + Script._find_referenced_areas(referenced, script[CONF_SEQUENCE]) + @property def referenced_devices(self): """Return a set of referenced devices.""" @@ -1222,6 +1260,10 @@ class Script: if CONF_ELSE in step: Script._find_referenced_devices(referenced, step[CONF_ELSE]) + elif action == cv.SCRIPT_ACTION_PARALLEL: + for script in step[CONF_PARALLEL]: + Script._find_referenced_devices(referenced, script[CONF_SEQUENCE]) + @property def referenced_entities(self): """Return a set of referenced entities.""" @@ -1267,6 +1309,10 @@ class Script: if CONF_ELSE in step: Script._find_referenced_entities(referenced, step[CONF_ELSE]) + elif action == cv.SCRIPT_ACTION_PARALLEL: + for script in step[CONF_PARALLEL]: + Script._find_referenced_entities(referenced, script[CONF_SEQUENCE]) + def run( self, variables: _VarsType | None = None, context: Context | None = None ) -> None: @@ -1530,6 +1576,36 @@ class Script: self._if_data[step] = if_data return if_data + async def _async_prep_parallel_scripts(self, step: int) -> list[Script]: + action = self.sequence[step] + step_name = action.get(CONF_ALIAS, f"Parallel action at step {step+1}") + parallel_scripts: list[Script] = [] + for idx, parallel_script in enumerate(action[CONF_PARALLEL], start=1): + parallel_name = parallel_script.get(CONF_ALIAS, f"parallel {idx}") + parallel_script = Script( + self._hass, + parallel_script[CONF_SEQUENCE], + f"{self.name}: {step_name}: {parallel_name}", + self.domain, + running_description=self.running_description, + script_mode=SCRIPT_MODE_PARALLEL, + max_runs=self.max_runs, + logger=self._logger, + top_level=False, + ) + parallel_script.change_listener = partial( + self._chain_change_listener, parallel_script + ) + parallel_scripts.append(parallel_script) + + return parallel_scripts + + async def _async_get_parallel_scripts(self, step: int) -> list[Script]: + if not (parallel_scripts := self._parallel_scripts.get(step)): + parallel_scripts = await self._async_prep_parallel_scripts(step) + self._parallel_scripts[step] = parallel_scripts + return parallel_scripts + def _log( self, msg: str, *args: Any, level: int = logging.INFO, **kwargs: Any ) -> None: diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 5a6849df3ef..1fbf2b2625e 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -2682,6 +2682,148 @@ async def test_if_condition_validation( ) +async def test_parallel(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -> None: + """Test parallel action.""" + events = async_capture_events(hass, "test_event") + hass.states.async_set("switch.trigger", "off") + + sequence = cv.SCRIPT_SCHEMA( + { + "parallel": [ + { + "alias": "Sequential group", + "sequence": [ + { + "alias": "Waiting for trigger", + "wait_for_trigger": { + "platform": "state", + "entity_id": "switch.trigger", + "to": "on", + }, + }, + { + "event": "test_event", + "event_data": { + "hello": "from action 1", + "what": "{{ what }}", + }, + }, + ], + }, + { + "alias": "Don't wait at all", + "event": "test_event", + "event_data": {"hello": "from action 2", "what": "{{ what }}"}, + }, + ] + } + ) + + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + wait_started_flag = async_watch_for_action(script_obj, "Waiting for trigger") + hass.async_create_task( + script_obj.async_run(MappingProxyType({"what": "world"}), Context()) + ) + await asyncio.wait_for(wait_started_flag.wait(), 1) + + assert script_obj.is_running + + hass.states.async_set("switch.trigger", "on") + await hass.async_block_till_done() + + assert len(events) == 2 + assert events[0].data["hello"] == "from action 2" + assert events[0].data["what"] == "world" + assert events[1].data["hello"] == "from action 1" + assert events[1].data["what"] == "world" + + assert ( + "Test Name: Parallel action at step 1: Sequential group: Executing step Waiting for trigger" + in caplog.text + ) + assert ( + "Parallel action at step 1: parallel 2: Executing step Don't wait at all" + in caplog.text + ) + + expected_trace = { + "0": [{"result": {}}], + "0/parallel/0/sequence/0": [ + { + "result": { + "wait": { + "remaining": None, + "trigger": { + "entity_id": "switch.trigger", + "description": "state of switch.trigger", + }, + } + } + } + ], + "0/parallel/1/sequence/0": [ + { + "variables": {"wait": {"remaining": None}}, + "result": { + "event": "test_event", + "event_data": {"hello": "from action 2", "what": "world"}, + }, + } + ], + "0/parallel/0/sequence/1": [ + { + "variables": {"wait": {"remaining": None}}, + "result": { + "event": "test_event", + "event_data": {"hello": "from action 1", "what": "world"}, + }, + } + ], + } + assert_action_trace(expected_trace) + + +async def test_parallel_error( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test parallel action failure handling.""" + events = async_capture_events(hass, "test_event") + sequence = cv.SCRIPT_SCHEMA( + { + "parallel": [ + {"service": "epic.failure"}, + ] + } + ) + + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + with pytest.raises(exceptions.ServiceNotFound): + await script_obj.async_run(context=Context()) + assert len(events) == 0 + + expected_trace = { + "0": [{"error_type": ServiceNotFound, "result": {}}], + "0/parallel/0/sequence/0": [ + { + "error_type": ServiceNotFound, + "result": { + "limit": 10, + "params": { + "domain": "epic", + "service": "failure", + "service_data": {}, + "target": {}, + }, + "running_script": False, + }, + } + ], + } + assert_action_trace(expected_trace, expected_script_execution="error") + + async def test_last_triggered(hass): """Test the last_triggered.""" event = "test_event" @@ -2881,6 +3023,14 @@ async def test_referenced_areas(hass): } ], }, + { + "parallel": [ + { + "service": "test.script", + "data": {"area_id": "area_parallel"}, + } + ], + }, ] ), "Test Name", @@ -2896,6 +3046,7 @@ async def test_referenced_areas(hass): "area_service_not_list", "area_if_then", "area_if_else", + "area_parallel", # 'area_service_template', # no area extraction from template } # Test we cache results. @@ -2988,6 +3139,14 @@ async def test_referenced_entities(hass): } ], }, + { + "parallel": [ + { + "service": "test.script", + "data": {"entity_id": "light.parallel"}, + } + ], + }, ] ), "Test Name", @@ -3006,6 +3165,7 @@ async def test_referenced_entities(hass): "light.service_not_list", "light.if_then", "light.if_else", + "light.parallel", # "light.service_template", # no entity extraction from template "scene.hello", "sensor.condition", @@ -3093,6 +3253,14 @@ async def test_referenced_devices(hass): } ], }, + { + "parallel": [ + { + "service": "test.script", + "target": {"device_id": "parallel-device"}, + } + ], + }, ] ), "Test Name", @@ -3113,6 +3281,7 @@ async def test_referenced_devices(hass): "target-string-id", "if-then", "if-else", + "parallel-device", } # Test we cache results. assert script_obj.referenced_devices is script_obj.referenced_devices @@ -3744,6 +3913,9 @@ async def test_validate_action_config(hass): "then": [templated_device_action("if_then_event")], "else": [templated_device_action("if_else_event")], }, + cv.SCRIPT_ACTION_PARALLEL: { + "parallel": [templated_device_action("parallel_event")], + }, } expected_templates = { cv.SCRIPT_ACTION_CHECK_CONDITION: None, @@ -3752,6 +3924,7 @@ async def test_validate_action_config(hass): cv.SCRIPT_ACTION_CHOOSE: [["choose", 0, "sequence", 0], ["default", 0]], cv.SCRIPT_ACTION_WAIT_FOR_TRIGGER: None, cv.SCRIPT_ACTION_IF: None, + cv.SCRIPT_ACTION_PARALLEL: None, } for key in cv.ACTION_TYPE_SCHEMAS: