Correct validation of conditions in scripts and automations (#60890)

* Correct validation of conditions in scripts and automations

* Fix test
This commit is contained in:
Erik Montnemery 2021-12-03 18:08:28 +01:00 committed by GitHub
parent f57d42a9e8
commit 17dc609363
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 237 additions and 6 deletions

View File

@ -967,6 +967,15 @@ async def async_validate_condition_config(
return config
async def async_validate_conditions_config(
hass: HomeAssistant, conditions: list[ConfigType | Template]
) -> list[ConfigType | Template]:
"""Validate config."""
return await asyncio.gather(
*(async_validate_condition_config(hass, cond) for cond in conditions)
)
@callback
def async_extract_entities(config: ConfigType | Template) -> set[str]:
"""Extract entities from a condition."""

View File

@ -270,6 +270,16 @@ async def async_validate_action_config(
)
elif action_type == cv.SCRIPT_ACTION_REPEAT:
if CONF_UNTIL in config[CONF_REPEAT]:
conditions = await condition.async_validate_conditions_config(
hass, config[CONF_REPEAT][CONF_UNTIL]
)
config[CONF_REPEAT][CONF_UNTIL] = conditions
if CONF_WHILE in config[CONF_REPEAT]:
conditions = await condition.async_validate_conditions_config(
hass, config[CONF_REPEAT][CONF_WHILE]
)
config[CONF_REPEAT][CONF_WHILE] = conditions
config[CONF_REPEAT][CONF_SEQUENCE] = await async_validate_actions_config(
hass, config[CONF_REPEAT][CONF_SEQUENCE]
)
@ -281,6 +291,10 @@ async def async_validate_action_config(
)
for choose_conf in config[CONF_CHOOSE]:
conditions = await condition.async_validate_conditions_config(
hass, choose_conf[CONF_CONDITIONS]
)
choose_conf[CONF_CONDITIONS] = conditions
choose_conf[CONF_SEQUENCE] = await async_validate_actions_config(
hass, choose_conf[CONF_SEQUENCE]
)

View File

@ -56,8 +56,11 @@ def compare_trigger_item(actual_trigger, expected_trigger):
assert actual_trigger["description"] == expected_trigger["description"]
def compare_result_item(key, actual, expected):
"""Compare an item in the result dict."""
def compare_result_item(key, actual, expected, path):
"""Compare an item in the result dict.
Note: Unused variable 'path' is passed to get helpful errors from pytest.
"""
if key == "wait" and (expected.get("trigger") is not None):
assert "trigger" in actual
expected_trigger = expected.pop("trigger")
@ -78,7 +81,7 @@ def assert_element(trace_element, expected_element, path):
# The redundant set operation gives helpful errors from pytest
assert not set(expected_result) - set(trace_element._result or {})
for result_key, result in expected_result.items():
compare_result_item(result_key, trace_element._result[result_key], result)
compare_result_item(result_key, trace_element._result[result_key], result, path)
assert trace_element._result[result_key] == result
# Check for unexpected items in trace_element
@ -1819,6 +1822,126 @@ async def test_repeat_conditional(hass, condition, direct_template):
assert event.data.get("index") == index + 1
async def test_repeat_until_condition_validation(hass, caplog):
"""Test if we can use conditions in repeat until conditions which validate late."""
registry = er.async_get(hass)
entry = registry.async_get_or_create(
"test", "hue", "1234", suggested_object_id="entity"
)
assert entry.entity_id == "test.entity"
event = "test_event"
events = async_capture_events(hass, event)
sequence = cv.SCRIPT_SCHEMA(
[
{
"repeat": {
"sequence": [
{"event": event},
],
"until": [
{
"condition": "state",
"entity_id": entry.id,
"state": "hello",
}
],
}
},
]
)
hass.states.async_set("test.entity", "hello")
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
await script_obj.async_run(context=Context())
await hass.async_block_till_done()
caplog.clear()
assert len(events) == 1
assert_action_trace(
{
"0": [{"result": {}}],
"0/repeat/sequence/0": [
{
"result": {"event": "test_event", "event_data": {}},
"variables": {"repeat": {"first": True, "index": 1}},
}
],
"0/repeat": [
{
"result": {"result": True},
"variables": {"repeat": {"first": True, "index": 1}},
}
],
"0/repeat/until/0": [{"result": {"result": True}}],
"0/repeat/until/0/entity_id/0": [
{"result": {"result": True, "state": "hello", "wanted_state": "hello"}}
],
}
)
async def test_repeat_while_condition_validation(hass, caplog):
"""Test if we can use conditions in repeat while conditions which validate late."""
registry = er.async_get(hass)
entry = registry.async_get_or_create(
"test", "hue", "1234", suggested_object_id="entity"
)
assert entry.entity_id == "test.entity"
event = "test_event"
events = async_capture_events(hass, event)
sequence = cv.SCRIPT_SCHEMA(
[
{
"repeat": {
"sequence": [
{"event": event},
],
"while": [
{
"condition": "state",
"entity_id": entry.id,
"state": "hello",
}
],
}
},
]
)
hass.states.async_set("test.entity", "goodbye")
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
await script_obj.async_run(context=Context())
await hass.async_block_till_done()
caplog.clear()
assert len(events) == 0
assert_action_trace(
{
"0": [{"result": {}}],
"0/repeat": [
{
"result": {"result": False},
"variables": {"repeat": {"first": True, "index": 1}},
}
],
"0/repeat/while/0": [{"result": {"result": False}}],
"0/repeat/while/0/entity_id/0": [
{
"result": {
"result": False,
"state": "goodbye",
"wanted_state": "hello",
}
}
],
}
)
@pytest.mark.parametrize("condition", ["while", "until"])
async def test_repeat_var_in_condition(hass, condition):
"""Test repeat action w/ while option."""
@ -2182,6 +2305,88 @@ async def test_choose(hass, caplog, var, result):
assert_action_trace(expected_trace)
async def test_choose_condition_validation(hass, caplog):
"""Test if we can use conditions in choose actions which validate late."""
registry = er.async_get(hass)
entry = registry.async_get_or_create(
"test", "hue", "1234", suggested_object_id="entity"
)
assert entry.entity_id == "test.entity"
event = "test_event"
events = async_capture_events(hass, event)
sequence = cv.SCRIPT_SCHEMA(
[
{"event": event},
{
"choose": [
{
"alias": "choice one",
"conditions": {
"condition": "state",
"entity_id": entry.id,
"state": "hello",
},
"sequence": {
"alias": "sequence one",
"event": event,
"event_data": {"choice": "first"},
},
},
]
},
]
)
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
hass.states.async_set("test.entity", "hello")
await script_obj.async_run(context=Context())
await hass.async_block_till_done()
caplog.clear()
assert len(events) == 2
assert_action_trace(
{
"0": [{"result": {"event": "test_event", "event_data": {}}}],
"1": [{"result": {"choice": 0}}],
"1/choose/0": [{"result": {"result": True}}],
"1/choose/0/conditions/0": [{"result": {"result": True}}],
"1/choose/0/conditions/0/entity_id/0": [
{"result": {"result": True, "state": "hello", "wanted_state": "hello"}}
],
"1/choose/0/sequence/0": [
{"result": {"event": "test_event", "event_data": {"choice": "first"}}}
],
}
)
hass.states.async_set("test.entity", "goodbye")
await script_obj.async_run(context=Context())
await hass.async_block_till_done()
assert len(events) == 3
assert_action_trace(
{
"0": [{"result": {"event": "test_event", "event_data": {}}}],
"1": [{"result": {}}],
"1/choose/0": [{"result": {"result": False}}],
"1/choose/0/conditions/0": [{"result": {"result": False}}],
"1/choose/0/conditions/0/entity_id/0": [
{
"result": {
"result": False,
"state": "goodbye",
"wanted_state": "hello",
}
}
],
},
)
@pytest.mark.parametrize(
"action",
[
@ -3132,7 +3337,9 @@ async def test_validate_action_config(hass):
},
cv.SCRIPT_ACTION_FIRE_EVENT: {"event": "my_event"},
cv.SCRIPT_ACTION_CHECK_CONDITION: {
"condition": "{{ states.light.kitchen.state == 'on' }}"
"condition": "state",
"entity_id": "light.kitchen",
"state": "on",
},
cv.SCRIPT_ACTION_DEVICE_AUTOMATION: templated_device_action("device"),
cv.SCRIPT_ACTION_ACTIVATE_SCENE: {"scene": "scene.relax"},
@ -3145,7 +3352,7 @@ async def test_validate_action_config(hass):
cv.SCRIPT_ACTION_CHOOSE: {
"choose": [
{
"condition": "{{ states.light.kitchen.state == 'on' }}",
"conditions": "{{ states.light.kitchen.state == 'on' }}",
"sequence": [templated_device_action("choose_event")],
}
],
@ -3182,8 +3389,9 @@ async def test_validate_action_config(hass):
for action_type, config in configs.items():
assert cv.determine_script_action(config) == action_type
try:
validated_config[action_type] = cv.ACTION_TYPE_SCHEMAS[action_type](config)
validated_config[action_type] = await script.async_validate_action_config(
hass, config
hass, validated_config[action_type]
)
except vol.Invalid as err:
assert False, f"{action_type} config invalid: {err}"