diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index 57e8f8e5ba2..030e5dacfd5 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -967,6 +967,15 @@ async def async_validate_condition_config( return config +async def async_validate_conditions_config( + hass: HomeAssistant, conditions: list[ConfigType | Template] +) -> list[ConfigType | Template]: + """Validate config.""" + return await asyncio.gather( + *(async_validate_condition_config(hass, cond) for cond in conditions) + ) + + @callback def async_extract_entities(config: ConfigType | Template) -> set[str]: """Extract entities from a condition.""" diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 933b44d9ec9..d4d37e1b4ac 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -270,6 +270,16 @@ async def async_validate_action_config( ) elif action_type == cv.SCRIPT_ACTION_REPEAT: + if CONF_UNTIL in config[CONF_REPEAT]: + conditions = await condition.async_validate_conditions_config( + hass, config[CONF_REPEAT][CONF_UNTIL] + ) + config[CONF_REPEAT][CONF_UNTIL] = conditions + if CONF_WHILE in config[CONF_REPEAT]: + conditions = await condition.async_validate_conditions_config( + hass, config[CONF_REPEAT][CONF_WHILE] + ) + config[CONF_REPEAT][CONF_WHILE] = conditions config[CONF_REPEAT][CONF_SEQUENCE] = await async_validate_actions_config( hass, config[CONF_REPEAT][CONF_SEQUENCE] ) @@ -281,6 +291,10 @@ async def async_validate_action_config( ) for choose_conf in config[CONF_CHOOSE]: + conditions = await condition.async_validate_conditions_config( + hass, choose_conf[CONF_CONDITIONS] + ) + choose_conf[CONF_CONDITIONS] = conditions choose_conf[CONF_SEQUENCE] = await async_validate_actions_config( hass, choose_conf[CONF_SEQUENCE] ) diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 962fe4b1366..b0a93c85b2b 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -56,8 +56,11 @@ def compare_trigger_item(actual_trigger, expected_trigger): assert actual_trigger["description"] == expected_trigger["description"] -def compare_result_item(key, actual, expected): - """Compare an item in the result dict.""" +def compare_result_item(key, actual, expected, path): + """Compare an item in the result dict. + + Note: Unused variable 'path' is passed to get helpful errors from pytest. + """ if key == "wait" and (expected.get("trigger") is not None): assert "trigger" in actual expected_trigger = expected.pop("trigger") @@ -78,7 +81,7 @@ def assert_element(trace_element, expected_element, path): # The redundant set operation gives helpful errors from pytest assert not set(expected_result) - set(trace_element._result or {}) for result_key, result in expected_result.items(): - compare_result_item(result_key, trace_element._result[result_key], result) + compare_result_item(result_key, trace_element._result[result_key], result, path) assert trace_element._result[result_key] == result # Check for unexpected items in trace_element @@ -1819,6 +1822,126 @@ async def test_repeat_conditional(hass, condition, direct_template): assert event.data.get("index") == index + 1 +async def test_repeat_until_condition_validation(hass, caplog): + """Test if we can use conditions in repeat until conditions which validate late.""" + registry = er.async_get(hass) + entry = registry.async_get_or_create( + "test", "hue", "1234", suggested_object_id="entity" + ) + assert entry.entity_id == "test.entity" + event = "test_event" + events = async_capture_events(hass, event) + sequence = cv.SCRIPT_SCHEMA( + [ + { + "repeat": { + "sequence": [ + {"event": event}, + ], + "until": [ + { + "condition": "state", + "entity_id": entry.id, + "state": "hello", + } + ], + } + }, + ] + ) + hass.states.async_set("test.entity", "hello") + sequence = await script.async_validate_actions_config(hass, sequence) + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + await script_obj.async_run(context=Context()) + await hass.async_block_till_done() + + caplog.clear() + assert len(events) == 1 + + assert_action_trace( + { + "0": [{"result": {}}], + "0/repeat/sequence/0": [ + { + "result": {"event": "test_event", "event_data": {}}, + "variables": {"repeat": {"first": True, "index": 1}}, + } + ], + "0/repeat": [ + { + "result": {"result": True}, + "variables": {"repeat": {"first": True, "index": 1}}, + } + ], + "0/repeat/until/0": [{"result": {"result": True}}], + "0/repeat/until/0/entity_id/0": [ + {"result": {"result": True, "state": "hello", "wanted_state": "hello"}} + ], + } + ) + + +async def test_repeat_while_condition_validation(hass, caplog): + """Test if we can use conditions in repeat while conditions which validate late.""" + registry = er.async_get(hass) + entry = registry.async_get_or_create( + "test", "hue", "1234", suggested_object_id="entity" + ) + assert entry.entity_id == "test.entity" + event = "test_event" + events = async_capture_events(hass, event) + sequence = cv.SCRIPT_SCHEMA( + [ + { + "repeat": { + "sequence": [ + {"event": event}, + ], + "while": [ + { + "condition": "state", + "entity_id": entry.id, + "state": "hello", + } + ], + } + }, + ] + ) + hass.states.async_set("test.entity", "goodbye") + sequence = await script.async_validate_actions_config(hass, sequence) + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + await script_obj.async_run(context=Context()) + await hass.async_block_till_done() + + caplog.clear() + assert len(events) == 0 + + assert_action_trace( + { + "0": [{"result": {}}], + "0/repeat": [ + { + "result": {"result": False}, + "variables": {"repeat": {"first": True, "index": 1}}, + } + ], + "0/repeat/while/0": [{"result": {"result": False}}], + "0/repeat/while/0/entity_id/0": [ + { + "result": { + "result": False, + "state": "goodbye", + "wanted_state": "hello", + } + } + ], + } + ) + + @pytest.mark.parametrize("condition", ["while", "until"]) async def test_repeat_var_in_condition(hass, condition): """Test repeat action w/ while option.""" @@ -2182,6 +2305,88 @@ async def test_choose(hass, caplog, var, result): assert_action_trace(expected_trace) +async def test_choose_condition_validation(hass, caplog): + """Test if we can use conditions in choose actions which validate late.""" + registry = er.async_get(hass) + entry = registry.async_get_or_create( + "test", "hue", "1234", suggested_object_id="entity" + ) + assert entry.entity_id == "test.entity" + event = "test_event" + events = async_capture_events(hass, event) + sequence = cv.SCRIPT_SCHEMA( + [ + {"event": event}, + { + "choose": [ + { + "alias": "choice one", + "conditions": { + "condition": "state", + "entity_id": entry.id, + "state": "hello", + }, + "sequence": { + "alias": "sequence one", + "event": event, + "event_data": {"choice": "first"}, + }, + }, + ] + }, + ] + ) + sequence = await script.async_validate_actions_config(hass, sequence) + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + hass.states.async_set("test.entity", "hello") + await script_obj.async_run(context=Context()) + await hass.async_block_till_done() + + caplog.clear() + assert len(events) == 2 + + assert_action_trace( + { + "0": [{"result": {"event": "test_event", "event_data": {}}}], + "1": [{"result": {"choice": 0}}], + "1/choose/0": [{"result": {"result": True}}], + "1/choose/0/conditions/0": [{"result": {"result": True}}], + "1/choose/0/conditions/0/entity_id/0": [ + {"result": {"result": True, "state": "hello", "wanted_state": "hello"}} + ], + "1/choose/0/sequence/0": [ + {"result": {"event": "test_event", "event_data": {"choice": "first"}}} + ], + } + ) + + hass.states.async_set("test.entity", "goodbye") + + await script_obj.async_run(context=Context()) + await hass.async_block_till_done() + + assert len(events) == 3 + + assert_action_trace( + { + "0": [{"result": {"event": "test_event", "event_data": {}}}], + "1": [{"result": {}}], + "1/choose/0": [{"result": {"result": False}}], + "1/choose/0/conditions/0": [{"result": {"result": False}}], + "1/choose/0/conditions/0/entity_id/0": [ + { + "result": { + "result": False, + "state": "goodbye", + "wanted_state": "hello", + } + } + ], + }, + ) + + @pytest.mark.parametrize( "action", [ @@ -3132,7 +3337,9 @@ async def test_validate_action_config(hass): }, cv.SCRIPT_ACTION_FIRE_EVENT: {"event": "my_event"}, cv.SCRIPT_ACTION_CHECK_CONDITION: { - "condition": "{{ states.light.kitchen.state == 'on' }}" + "condition": "state", + "entity_id": "light.kitchen", + "state": "on", }, cv.SCRIPT_ACTION_DEVICE_AUTOMATION: templated_device_action("device"), cv.SCRIPT_ACTION_ACTIVATE_SCENE: {"scene": "scene.relax"}, @@ -3145,7 +3352,7 @@ async def test_validate_action_config(hass): cv.SCRIPT_ACTION_CHOOSE: { "choose": [ { - "condition": "{{ states.light.kitchen.state == 'on' }}", + "conditions": "{{ states.light.kitchen.state == 'on' }}", "sequence": [templated_device_action("choose_event")], } ], @@ -3182,8 +3389,9 @@ async def test_validate_action_config(hass): for action_type, config in configs.items(): assert cv.determine_script_action(config) == action_type try: + validated_config[action_type] = cv.ACTION_TYPE_SCHEMAS[action_type](config) validated_config[action_type] = await script.async_validate_action_config( - hass, config + hass, validated_config[action_type] ) except vol.Invalid as err: assert False, f"{action_type} config invalid: {err}"