Support shorthand templates in condition actions (#61177)

* Support shorthand templates in condition actions

* Fix validation message

* Fix tests
This commit is contained in:
Erik Montnemery 2021-12-21 12:19:31 +01:00 committed by GitHub
parent 4b30c9631f
commit e2fca2e305
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 141 additions and 3 deletions

View File

@ -860,7 +860,10 @@ def removed(
def key_value_schemas(
key: str, value_schemas: dict[Hashable, vol.Schema]
key: str,
value_schemas: dict[Hashable, vol.Schema],
default_schema: vol.Schema | None = None,
default_description: str | None = None,
) -> Callable[[Any], dict[Hashable, Any]]:
"""Create a validator that validates based on a value for specific key.
@ -876,8 +879,15 @@ def key_value_schemas(
if isinstance(key_value, Hashable) and key_value in value_schemas:
return cast(Dict[Hashable, Any], value_schemas[key_value](value))
if default_schema:
with contextlib.suppress(vol.Invalid):
return cast(Dict[Hashable, Any], default_schema(value))
alternatives = ", ".join(str(key) for key in value_schemas)
if default_description:
alternatives += ", " + default_description
raise vol.Invalid(
f"Unexpected value for {key}: '{key_value}'. Expected {', '.join(str(key) for key in value_schemas)}"
f"Unexpected value for {key}: '{key_value}'. Expected {alternatives}"
)
return key_value_validator
@ -1207,6 +1217,40 @@ CONDITION_SCHEMA: vol.Schema = vol.Schema(
)
)
dynamic_template_condition_action = vol.All(
vol.Schema(
{**CONDITION_BASE_SCHEMA, vol.Required(CONF_CONDITION): dynamic_template}
),
lambda config: {
**config,
CONF_VALUE_TEMPLATE: config[CONF_CONDITION],
CONF_CONDITION: "template",
},
)
CONDITION_ACTION_SCHEMA: vol.Schema = vol.Schema(
key_value_schemas(
CONF_CONDITION,
{
"and": AND_CONDITION_SCHEMA,
"device": DEVICE_CONDITION_SCHEMA,
"not": NOT_CONDITION_SCHEMA,
"numeric_state": NUMERIC_STATE_CONDITION_SCHEMA,
"or": OR_CONDITION_SCHEMA,
"state": STATE_CONDITION_SCHEMA,
"sun": SUN_CONDITION_SCHEMA,
"template": TEMPLATE_CONDITION_SCHEMA,
"time": TIME_CONDITION_SCHEMA,
"trigger": TRIGGER_CONDITION_SCHEMA,
"zone": ZONE_CONDITION_SCHEMA,
},
dynamic_template_condition_action,
"a valid template",
)
)
TRIGGER_BASE_SCHEMA = vol.Schema(
{vol.Required(CONF_PLATFORM): str, vol.Optional(CONF_ID): str}
)
@ -1352,7 +1396,7 @@ ACTION_TYPE_SCHEMAS: dict[str, Callable[[Any], dict]] = {
SCRIPT_ACTION_DELAY: _SCRIPT_DELAY_SCHEMA,
SCRIPT_ACTION_WAIT_TEMPLATE: _SCRIPT_WAIT_TEMPLATE_SCHEMA,
SCRIPT_ACTION_FIRE_EVENT: EVENT_SCHEMA,
SCRIPT_ACTION_CHECK_CONDITION: CONDITION_SCHEMA,
SCRIPT_ACTION_CHECK_CONDITION: CONDITION_ACTION_SCHEMA,
SCRIPT_ACTION_DEVICE_AUTOMATION: DEVICE_ACTION_SCHEMA,
SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA,
SCRIPT_ACTION_REPEAT: _SCRIPT_REPEAT_SCHEMA,

View File

@ -1184,6 +1184,45 @@ def test_key_value_schemas():
schema({"mode": mode, "data": data})
def test_key_value_schemas_with_default():
"""Test key value schemas."""
schema = vol.Schema(
cv.key_value_schemas(
"mode",
{
"number": vol.Schema({"mode": "number", "data": int}),
"string": vol.Schema({"mode": "string", "data": str}),
},
vol.Schema({"mode": cv.dynamic_template}),
"a cool template",
)
)
with pytest.raises(vol.Invalid) as excinfo:
schema(True)
assert str(excinfo.value) == "Expected a dictionary"
for mode in None, {"a": "dict"}, "invalid":
with pytest.raises(vol.Invalid) as excinfo:
schema({"mode": mode})
assert (
str(excinfo.value)
== f"Unexpected value for mode: '{mode}'. Expected number, string, a cool template"
)
with pytest.raises(vol.Invalid) as excinfo:
schema({"mode": "number", "data": "string-value"})
assert str(excinfo.value) == "expected int for dictionary value @ data['data']"
with pytest.raises(vol.Invalid) as excinfo:
schema({"mode": "string", "data": 1})
assert str(excinfo.value) == "expected str for dictionary value @ data['data']"
for mode, data in (("number", 1), ("string", "hello")):
schema({"mode": mode, "data": data})
schema({"mode": "{{ 1 + 1}}"})
def test_script(caplog):
"""Test script validation is user friendly."""
for data, msg in (

View File

@ -1501,6 +1501,61 @@ async def test_condition_basic(hass, caplog):
)
async def test_shorthand_template_condition(hass, caplog):
"""Test if we can use shorthand template conditions in a script."""
event = "test_event"
events = async_capture_events(hass, event)
alias = "condition step"
sequence = cv.SCRIPT_SCHEMA(
[
{"event": event},
{
"alias": alias,
"condition": "{{ states.test.entity.state == 'hello' }}",
},
{"event": event},
]
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
hass.states.async_set("test.entity", "hello")
await script_obj.async_run(context=Context())
await hass.async_block_till_done()
assert f"Test condition {alias}: True" in caplog.text
caplog.clear()
assert len(events) == 2
assert_action_trace(
{
"0": [{"result": {"event": "test_event", "event_data": {}}}],
"1": [{"result": {"entities": ["test.entity"], "result": True}}],
"2": [{"result": {"event": "test_event", "event_data": {}}}],
}
)
hass.states.async_set("test.entity", "goodbye")
await script_obj.async_run(context=Context())
await hass.async_block_till_done()
assert f"Test condition {alias}: False" in caplog.text
assert len(events) == 3
assert_action_trace(
{
"0": [{"result": {"event": "test_event", "event_data": {}}}],
"1": [
{
"error_type": script._StopScript,
"result": {"entities": ["test.entity"], "result": False},
}
],
},
expected_script_execution="aborted",
)
async def test_condition_validation(hass, caplog):
"""Test if we can use conditions which validate late in a script."""
registry = er.async_get(hass)