diff --git a/homeassistant/const.py b/homeassistant/const.py index 75d0aa9b6bd..b2654da367a 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -158,6 +158,7 @@ CONF_EXTERNAL_URL: Final = "external_url" CONF_FILENAME: Final = "filename" CONF_FILE_PATH: Final = "file_path" CONF_FOR: Final = "for" +CONF_FOR_EACH: Final = "for_each" CONF_FORCE_UPDATE: Final = "force_update" CONF_FRIENDLY_NAME: Final = "friendly_name" CONF_FRIENDLY_NAME_TEMPLATE: Final = "friendly_name_template" diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index c681b5a8284..e423be718ea 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -52,6 +52,7 @@ from homeassistant.const import ( CONF_EVENT_DATA, CONF_EVENT_DATA_TEMPLATE, CONF_FOR, + CONF_FOR_EACH, CONF_ID, CONF_IF, CONF_MATCH, @@ -1395,6 +1396,9 @@ _SCRIPT_REPEAT_SCHEMA = vol.Schema( vol.Required(CONF_REPEAT): vol.All( { vol.Exclusive(CONF_COUNT, "repeat"): vol.Any(vol.Coerce(int), template), + vol.Exclusive(CONF_FOR_EACH, "repeat"): vol.Any( + dynamic_template, vol.All(list, template_complex) + ), vol.Exclusive(CONF_WHILE, "repeat"): vol.All( ensure_list, [CONDITION_SCHEMA] ), @@ -1403,7 +1407,7 @@ _SCRIPT_REPEAT_SCHEMA = vol.Schema( ), vol.Required(CONF_SEQUENCE): SCRIPT_SCHEMA, }, - has_at_least_one_key(CONF_COUNT, CONF_WHILE, CONF_UNTIL), + has_at_least_one_key(CONF_COUNT, CONF_FOR_EACH, CONF_WHILE, CONF_UNTIL), ), } ) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index e314cdc5c9c..d5aa76a98a1 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -41,6 +41,7 @@ from homeassistant.const import ( CONF_EVENT, CONF_EVENT_DATA, CONF_EVENT_DATA_TEMPLATE, + CONF_FOR_EACH, CONF_IF, CONF_MODE, CONF_PARALLEL, @@ -744,17 +745,21 @@ class _ScriptRun: return result @async_trace_path("repeat") - async def _async_repeat_step(self): + async def _async_repeat_step(self): # noqa: C901 """Repeat a sequence.""" description = self._action.get(CONF_ALIAS, "sequence") repeat = self._action[CONF_REPEAT] saved_repeat_vars = self._variables.get("repeat") - def set_repeat_var(iteration, count=None): + def set_repeat_var( + iteration: int, count: int | None = None, item: Any = None + ) -> None: repeat_vars = {"first": iteration == 1, "index": iteration} if count: repeat_vars["last"] = iteration == count + if item is not None: + repeat_vars["item"] = item self._variables["repeat"] = repeat_vars # pylint: disable=protected-access @@ -785,6 +790,35 @@ class _ScriptRun: if self._stop.is_set(): break + elif CONF_FOR_EACH in repeat: + try: + items = template.render_complex(repeat[CONF_FOR_EACH], self._variables) + except (exceptions.TemplateError, ValueError) as ex: + self._log( + "Error rendering %s repeat for each items template: %s", + self._script.name, + ex, + level=logging.ERROR, + ) + raise _AbortScript from ex + + if not isinstance(items, list): + self._log( + "Repeat 'for_each' must be a list of items in %s, got: %s", + self._script.name, + items, + level=logging.ERROR, + ) + raise _AbortScript("Repeat 'for_each' must be a list of items") + + count = len(items) + for iteration, item in enumerate(items, 1): + set_repeat_var(iteration, count, item) + extra_msg = f" of {count} with item: {repr(item)}" + if self._stop.is_set(): + break + await async_run_sequence(iteration, extra_msg) + elif CONF_WHILE in repeat: conditions = [ await self._async_get_condition(config) for config in repeat[CONF_WHILE] diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 9a1f643bc06..b1c65cfe971 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -1787,6 +1787,232 @@ async def test_repeat_count_0(hass, caplog): ) +async def test_repeat_for_each( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test repeat action using for each.""" + events = async_capture_events(hass, "test_event") + sequence = cv.SCRIPT_SCHEMA( + { + "alias": "For each!", + "repeat": { + "for_each": ["one", "two", "{{ 'thr' + 'ee' }}"], + "sequence": { + "event": "test_event", + "event_data": { + "first": "{{ repeat.first }}", + "index": "{{ repeat.index }}", + "last": "{{ repeat.last }}", + "item": "{{ repeat.item }}", + }, + }, + }, + } + ) + + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + await script_obj.async_run(context=Context()) + await hass.async_block_till_done() + + assert len(events) == 3 + assert "Repeating For each!: Iteration 1 of 3 with item: 'one'" in caplog.text + assert "Repeating For each!: Iteration 2 of 3 with item: 'two'" in caplog.text + assert "Repeating For each!: Iteration 3 of 3 with item: 'three'" in caplog.text + + assert_action_trace( + { + "0": [{}], + "0/repeat/sequence/0": [ + { + "result": { + "event": "test_event", + "event_data": { + "first": True, + "index": 1, + "last": False, + "item": "one", + }, + }, + "variables": { + "repeat": { + "first": True, + "index": 1, + "last": False, + "item": "one", + } + }, + }, + { + "result": { + "event": "test_event", + "event_data": { + "first": False, + "index": 2, + "last": False, + "item": "two", + }, + }, + "variables": { + "repeat": { + "first": False, + "index": 2, + "last": False, + "item": "two", + } + }, + }, + { + "result": { + "event": "test_event", + "event_data": { + "first": False, + "index": 3, + "last": True, + "item": "three", + }, + }, + "variables": { + "repeat": { + "first": False, + "index": 3, + "last": True, + "item": "three", + } + }, + }, + ], + } + ) + + +async def test_repeat_for_each_template(hass: HomeAssistant) -> None: + """Test repeat action using for each template.""" + events = async_capture_events(hass, "test_event") + sequence = cv.SCRIPT_SCHEMA( + { + "alias": "", + "repeat": { + "for_each": ( + "{% set var = ['light.bulb_one', 'light.bulb_two'] %} {{ var }}" + ), + "sequence": { + "event": "test_event", + }, + }, + } + ) + + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + await script_obj.async_run(context=Context()) + await hass.async_block_till_done() + + assert len(events) == 2 + + assert_action_trace( + { + "0": [{}], + "0/repeat/sequence/0": [ + { + "result": { + "event": "test_event", + "event_data": {}, + }, + "variables": { + "repeat": { + "first": True, + "index": 1, + "last": False, + "item": "light.bulb_one", + } + }, + }, + { + "result": { + "event": "test_event", + "event_data": {}, + }, + "variables": { + "repeat": { + "first": False, + "index": 2, + "last": True, + "item": "light.bulb_two", + } + }, + }, + ], + } + ) + + +async def test_repeat_for_each_non_list_template(hass: HomeAssistant) -> None: + """Test repeat action using for each with a template not resulting in a list.""" + events = async_capture_events(hass, "test_event") + sequence = cv.SCRIPT_SCHEMA( + { + "repeat": { + "for_each": "{{ 'Not a list' }}", + "sequence": { + "event": "test_event", + }, + }, + } + ) + + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + await script_obj.async_run(context=Context()) + await hass.async_block_till_done() + + assert len(events) == 0 + + assert_action_trace( + { + "0": [ + { + "error_type": script._AbortScript, + } + ], + }, + expected_script_execution="aborted", + ) + + +async def test_repeat_for_each_invalid_template(hass: HomeAssistant, caplog) -> None: + """Test repeat action using for each with an invalid template.""" + events = async_capture_events(hass, "test_event") + sequence = cv.SCRIPT_SCHEMA( + { + "repeat": { + "for_each": "{{ Muhaha }}", + "sequence": { + "event": "test_event", + }, + }, + } + ) + + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + await script_obj.async_run(context=Context()) + await hass.async_block_till_done() + + assert ( + "Test Name: Repeat 'for_each' must be a list of items in Test Name, got" + in caplog.text + ) + assert len(events) == 0 + + assert_action_trace( + { + "0": [{"error_type": script._AbortScript}], + }, + expected_script_execution="aborted", + ) + + @pytest.mark.parametrize("condition", ["while", "until"]) async def test_repeat_condition_warning(hass, caplog, condition): """Test warning on repeat conditions."""