From 2f87da8aa935ee085dc4b014778171a924316f12 Mon Sep 17 00:00:00 2001 From: Phil Bruckner Date: Fri, 24 Jul 2020 01:11:21 -0500 Subject: [PATCH] Fix script repeat variable lifetime (#38124) --- homeassistant/helpers/script.py | 44 +++++++----- tests/helpers/test_script.py | 116 ++++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+), 18 deletions(-) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 475beb02690..bc6e4bdbd36 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -140,7 +140,7 @@ class _ScriptRun: ) -> None: self._hass = hass self._script = script - self._variables = variables + self._variables = variables or {} self._context = context self._log_exceptions = log_exceptions self._step = -1 @@ -431,22 +431,23 @@ class _ScriptRun: async def _async_repeat_step(self): """Repeat a sequence.""" - description = self._action.get(CONF_ALIAS, "sequence") repeat = self._action[CONF_REPEAT] - async def async_run_sequence(iteration, extra_msg="", extra_vars=None): + saved_repeat_vars = self._variables.get("repeat") + + def set_repeat_var(iteration, count=None): + repeat_vars = {"first": iteration == 1, "index": iteration} + if count: + repeat_vars["last"] = iteration == count + self._variables["repeat"] = repeat_vars + + # pylint: disable=protected-access + script = self._script._get_repeat_script(self._step) + + async def async_run_sequence(iteration, extra_msg=""): self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg) - repeat_vars = {"repeat": {"first": iteration == 1, "index": iteration}} - if extra_vars: - repeat_vars["repeat"].update(extra_vars) - # pylint: disable=protected-access - await self._async_run_script( - self._script._get_repeat_script(self._step), - # Add repeat to variables. Override if it already exists in case of - # nested calls. - {**(self._variables or {}), **repeat_vars}, - ) + await self._async_run_script(script) if CONF_COUNT in repeat: count = repeat[CONF_COUNT] @@ -461,10 +462,10 @@ class _ScriptRun: level=logging.ERROR, ) raise _StopScript + extra_msg = f" of {count}" for iteration in range(1, count + 1): - await async_run_sequence( - iteration, f" of {count}", {"last": iteration == count} - ) + set_repeat_var(iteration, count) + await async_run_sequence(iteration, extra_msg) if self._stop.is_set(): break @@ -473,6 +474,7 @@ class _ScriptRun: await self._async_get_condition(config) for config in repeat[CONF_WHILE] ] for iteration in itertools.count(1): + set_repeat_var(iteration) if self._stop.is_set() or not all( cond(self._hass, self._variables) for cond in conditions ): @@ -484,12 +486,18 @@ class _ScriptRun: await self._async_get_condition(config) for config in repeat[CONF_UNTIL] ] for iteration in itertools.count(1): + set_repeat_var(iteration) await async_run_sequence(iteration) if self._stop.is_set() or all( cond(self._hass, self._variables) for cond in conditions ): break + if saved_repeat_vars: + self._variables["repeat"] = saved_repeat_vars + else: + del self._variables["repeat"] + async def _async_choose_step(self): """Choose a sequence.""" # pylint: disable=protected-access @@ -503,11 +511,11 @@ class _ScriptRun: if choose_data["default"]: await self._async_run_script(choose_data["default"]) - async def _async_run_script(self, script, variables=None): + async def _async_run_script(self, script): """Execute a script.""" await self._async_run_long_action( self._hass.async_create_task( - script.async_run(variables or self._variables, self._context) + script.async_run(self._variables, self._context) ) ) diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 4001b6a3215..ab30b0457c5 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -854,6 +854,122 @@ async def test_repeat_conditional(hass, condition): assert event.data.get("index") == str(index + 1) +@pytest.mark.parametrize("condition", ["while", "until"]) +async def test_repeat_var_in_condition(hass, condition): + """Test repeat action w/ while option.""" + event = "test_event" + events = async_capture_events(hass, event) + + sequence = {"repeat": {"sequence": {"event": event}}} + if condition == "while": + sequence["repeat"]["while"] = { + "condition": "template", + "value_template": "{{ repeat.index <= 2 }}", + } + else: + sequence["repeat"]["until"] = { + "condition": "template", + "value_template": "{{ repeat.index == 2 }}", + } + script_obj = script.Script(hass, cv.SCRIPT_SCHEMA(sequence)) + + with mock.patch( + "homeassistant.helpers.condition._LOGGER.error", + side_effect=AssertionError("Template Error"), + ): + await script_obj.async_run() + + assert len(events) == 2 + + +async def test_repeat_nested(hass): + """Test nested repeats.""" + event = "test_event" + events = async_capture_events(hass, event) + + sequence = cv.SCRIPT_SCHEMA( + [ + { + "event": event, + "event_data_template": { + "repeat": "{{ None if repeat is not defined else repeat }}" + }, + }, + { + "repeat": { + "count": 2, + "sequence": [ + { + "event": event, + "event_data_template": { + "first": "{{ repeat.first }}", + "index": "{{ repeat.index }}", + "last": "{{ repeat.last }}", + }, + }, + { + "repeat": { + "count": 2, + "sequence": { + "event": event, + "event_data_template": { + "first": "{{ repeat.first }}", + "index": "{{ repeat.index }}", + "last": "{{ repeat.last }}", + }, + }, + } + }, + { + "event": event, + "event_data_template": { + "first": "{{ repeat.first }}", + "index": "{{ repeat.index }}", + "last": "{{ repeat.last }}", + }, + }, + ], + } + }, + { + "event": event, + "event_data_template": { + "repeat": "{{ None if repeat is not defined else repeat }}" + }, + }, + ] + ) + script_obj = script.Script(hass, sequence, "test script") + + with mock.patch( + "homeassistant.helpers.condition._LOGGER.error", + side_effect=AssertionError("Template Error"), + ): + await script_obj.async_run() + + assert len(events) == 10 + assert events[0].data == {"repeat": "None"} + assert events[-1].data == {"repeat": "None"} + for index, result in enumerate( + ( + ("True", "1", "False"), + ("True", "1", "False"), + ("False", "2", "True"), + ("True", "1", "False"), + ("False", "2", "True"), + ("True", "1", "False"), + ("False", "2", "True"), + ("False", "2", "True"), + ), + 1, + ): + assert events[index].data == { + "first": result[0], + "index": result[1], + "last": result[2], + } + + @pytest.mark.parametrize("var,result", [(1, "first"), (2, "second"), (3, "default")]) async def test_choose(hass, var, result): """Test choose action."""