diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index cbcfb551dad..32d1c5ec276 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -860,7 +860,10 @@ def removed( def key_value_schemas( - key: str, value_schemas: dict[Hashable, vol.Schema] + key: str, + value_schemas: dict[Hashable, vol.Schema], + default_schema: vol.Schema | None = None, + default_description: str | None = None, ) -> Callable[[Any], dict[Hashable, Any]]: """Create a validator that validates based on a value for specific key. @@ -876,8 +879,15 @@ def key_value_schemas( if isinstance(key_value, Hashable) and key_value in value_schemas: return cast(Dict[Hashable, Any], value_schemas[key_value](value)) + if default_schema: + with contextlib.suppress(vol.Invalid): + return cast(Dict[Hashable, Any], default_schema(value)) + + alternatives = ", ".join(str(key) for key in value_schemas) + if default_description: + alternatives += ", " + default_description raise vol.Invalid( - f"Unexpected value for {key}: '{key_value}'. Expected {', '.join(str(key) for key in value_schemas)}" + f"Unexpected value for {key}: '{key_value}'. Expected {alternatives}" ) return key_value_validator @@ -1207,6 +1217,40 @@ CONDITION_SCHEMA: vol.Schema = vol.Schema( ) ) + +dynamic_template_condition_action = vol.All( + vol.Schema( + {**CONDITION_BASE_SCHEMA, vol.Required(CONF_CONDITION): dynamic_template} + ), + lambda config: { + **config, + CONF_VALUE_TEMPLATE: config[CONF_CONDITION], + CONF_CONDITION: "template", + }, +) + + +CONDITION_ACTION_SCHEMA: vol.Schema = vol.Schema( + key_value_schemas( + CONF_CONDITION, + { + "and": AND_CONDITION_SCHEMA, + "device": DEVICE_CONDITION_SCHEMA, + "not": NOT_CONDITION_SCHEMA, + "numeric_state": NUMERIC_STATE_CONDITION_SCHEMA, + "or": OR_CONDITION_SCHEMA, + "state": STATE_CONDITION_SCHEMA, + "sun": SUN_CONDITION_SCHEMA, + "template": TEMPLATE_CONDITION_SCHEMA, + "time": TIME_CONDITION_SCHEMA, + "trigger": TRIGGER_CONDITION_SCHEMA, + "zone": ZONE_CONDITION_SCHEMA, + }, + dynamic_template_condition_action, + "a valid template", + ) +) + TRIGGER_BASE_SCHEMA = vol.Schema( {vol.Required(CONF_PLATFORM): str, vol.Optional(CONF_ID): str} ) @@ -1352,7 +1396,7 @@ ACTION_TYPE_SCHEMAS: dict[str, Callable[[Any], dict]] = { SCRIPT_ACTION_DELAY: _SCRIPT_DELAY_SCHEMA, SCRIPT_ACTION_WAIT_TEMPLATE: _SCRIPT_WAIT_TEMPLATE_SCHEMA, SCRIPT_ACTION_FIRE_EVENT: EVENT_SCHEMA, - SCRIPT_ACTION_CHECK_CONDITION: CONDITION_SCHEMA, + SCRIPT_ACTION_CHECK_CONDITION: CONDITION_ACTION_SCHEMA, SCRIPT_ACTION_DEVICE_AUTOMATION: DEVICE_ACTION_SCHEMA, SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA, SCRIPT_ACTION_REPEAT: _SCRIPT_REPEAT_SCHEMA, diff --git a/tests/helpers/test_config_validation.py b/tests/helpers/test_config_validation.py index 8327eb2e320..0c32ae5eddf 100644 --- a/tests/helpers/test_config_validation.py +++ b/tests/helpers/test_config_validation.py @@ -1184,6 +1184,45 @@ def test_key_value_schemas(): schema({"mode": mode, "data": data}) +def test_key_value_schemas_with_default(): + """Test key value schemas.""" + schema = vol.Schema( + cv.key_value_schemas( + "mode", + { + "number": vol.Schema({"mode": "number", "data": int}), + "string": vol.Schema({"mode": "string", "data": str}), + }, + vol.Schema({"mode": cv.dynamic_template}), + "a cool template", + ) + ) + + with pytest.raises(vol.Invalid) as excinfo: + schema(True) + assert str(excinfo.value) == "Expected a dictionary" + + for mode in None, {"a": "dict"}, "invalid": + with pytest.raises(vol.Invalid) as excinfo: + schema({"mode": mode}) + assert ( + str(excinfo.value) + == f"Unexpected value for mode: '{mode}'. Expected number, string, a cool template" + ) + + with pytest.raises(vol.Invalid) as excinfo: + schema({"mode": "number", "data": "string-value"}) + assert str(excinfo.value) == "expected int for dictionary value @ data['data']" + + with pytest.raises(vol.Invalid) as excinfo: + schema({"mode": "string", "data": 1}) + assert str(excinfo.value) == "expected str for dictionary value @ data['data']" + + for mode, data in (("number", 1), ("string", "hello")): + schema({"mode": mode, "data": data}) + schema({"mode": "{{ 1 + 1}}"}) + + def test_script(caplog): """Test script validation is user friendly.""" for data, msg in ( diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index b0a93c85b2b..2ee4213a688 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -1501,6 +1501,61 @@ async def test_condition_basic(hass, caplog): ) +async def test_shorthand_template_condition(hass, caplog): + """Test if we can use shorthand template conditions in a script.""" + event = "test_event" + events = async_capture_events(hass, event) + alias = "condition step" + sequence = cv.SCRIPT_SCHEMA( + [ + {"event": event}, + { + "alias": alias, + "condition": "{{ states.test.entity.state == 'hello' }}", + }, + {"event": event}, + ] + ) + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + hass.states.async_set("test.entity", "hello") + await script_obj.async_run(context=Context()) + await hass.async_block_till_done() + + assert f"Test condition {alias}: True" in caplog.text + caplog.clear() + assert len(events) == 2 + + assert_action_trace( + { + "0": [{"result": {"event": "test_event", "event_data": {}}}], + "1": [{"result": {"entities": ["test.entity"], "result": True}}], + "2": [{"result": {"event": "test_event", "event_data": {}}}], + } + ) + + hass.states.async_set("test.entity", "goodbye") + + await script_obj.async_run(context=Context()) + await hass.async_block_till_done() + + assert f"Test condition {alias}: False" in caplog.text + assert len(events) == 3 + + assert_action_trace( + { + "0": [{"result": {"event": "test_event", "event_data": {}}}], + "1": [ + { + "error_type": script._StopScript, + "result": {"entities": ["test.entity"], "result": False}, + } + ], + }, + expected_script_execution="aborted", + ) + + async def test_condition_validation(hass, caplog): """Test if we can use conditions which validate late in a script.""" registry = er.async_get(hass)