Improve condition schema validation (#144793)

This commit is contained in:
Erik Montnemery
2025-09-16 10:44:26 +02:00
committed by GitHub
parent 4bba167ab3
commit 0f372f4b47
3 changed files with 105 additions and 10 deletions

View File

@@ -1108,11 +1108,21 @@ def key_value_schemas(
value_schemas: ValueSchemas, value_schemas: ValueSchemas,
default_schema: VolSchemaType | Callable[[Any], dict[str, Any]] | None = None, default_schema: VolSchemaType | Callable[[Any], dict[str, Any]] | None = None,
default_description: str | None = None, default_description: str | None = None,
list_alternatives: bool = True,
) -> Callable[[Any], dict[Hashable, Any]]: ) -> Callable[[Any], dict[Hashable, Any]]:
"""Create a validator that validates based on a value for specific key. """Create a validator that validates based on a value for specific key.
This gives better error messages. This gives better error messages.
default_schema: An optional schema to use if the key value is not in value_schemas.
default_description: A description of what is expected by the default schema, this
will be added to the error message.
list_alternatives: If True, list the keys in `value_schemas` in the error message.
""" """
if not list_alternatives and not default_description:
raise ValueError(
"default_description must be provided if list_alternatives is False"
)
def key_value_validator(value: Any) -> dict[Hashable, Any]: def key_value_validator(value: Any) -> dict[Hashable, Any]:
if not isinstance(value, dict): if not isinstance(value, dict):
@@ -1127,9 +1137,13 @@ def key_value_schemas(
with contextlib.suppress(vol.Invalid): with contextlib.suppress(vol.Invalid):
return cast(dict[Hashable, Any], default_schema(value)) return cast(dict[Hashable, Any], default_schema(value))
alternatives = ", ".join(str(alternative) for alternative in value_schemas) if list_alternatives:
if default_description: alternatives = ", ".join(str(alternative) for alternative in value_schemas)
alternatives = f"{alternatives}, {default_description}" if default_description:
alternatives = f"{alternatives}, {default_description}"
else:
# mypy does not understand that default_description is not None here
alternatives = default_description # type: ignore[assignment]
raise vol.Invalid( raise vol.Invalid(
f"Unexpected value for {key}: '{key_value}'. Expected {alternatives}" f"Unexpected value for {key}: '{key_value}'. Expected {alternatives}"
) )
@@ -1753,7 +1767,7 @@ def _base_condition_validator(value: Any) -> Any:
vol.Schema( vol.Schema(
{ {
**CONDITION_BASE_SCHEMA, **CONDITION_BASE_SCHEMA,
CONF_CONDITION: vol.NotIn(BUILT_IN_CONDITIONS), CONF_CONDITION: vol.All(str, vol.NotIn(BUILT_IN_CONDITIONS)),
}, },
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
)(value) )(value)
@@ -1768,6 +1782,8 @@ CONDITION_SCHEMA: vol.Schema = vol.Schema(
CONF_CONDITION, CONF_CONDITION,
BUILT_IN_CONDITIONS, BUILT_IN_CONDITIONS,
_base_condition_validator, _base_condition_validator,
"a condition, a list of conditions or a valid template",
list_alternatives=False,
), ),
), ),
dynamic_template_condition, dynamic_template_condition,
@@ -1799,7 +1815,8 @@ CONDITION_ACTION_SCHEMA: vol.Schema = vol.Schema(
dynamic_template_condition_action, dynamic_template_condition_action,
_base_condition_validator, _base_condition_validator,
), ),
"a list of conditions or a valid template", "a condition, a list of conditions or a valid template",
list_alternatives=False,
), ),
) )
) )

View File

@@ -82,11 +82,26 @@ def assert_condition_trace(expected):
assert_element(condition_trace[key][index], element, path) assert_element(condition_trace[key][index], element, path)
async def test_invalid_condition(hass: HomeAssistant) -> None: @pytest.mark.parametrize(
"""Test if invalid condition raises.""" ("config", "error"),
with pytest.raises(HomeAssistantError): [
await condition.async_from_config( (
hass, {"condition": 123},
"Unexpected value for condition: '123'. Expected a condition, "
"a list of conditions or a valid template",
)
],
)
async def test_invalid_condition(hass: HomeAssistant, config: dict, error: str) -> None:
"""Test if validating an invalid condition raises."""
with pytest.raises(vol.Invalid, match=error):
cv.CONDITION_SCHEMA(config)
@pytest.mark.parametrize(
("config", "error"),
[
(
{ {
"condition": "invalid", "condition": "invalid",
"conditions": [ "conditions": [
@@ -97,7 +112,15 @@ async def test_invalid_condition(hass: HomeAssistant) -> None:
}, },
], ],
}, },
'Invalid condition "invalid" specified',
) )
],
)
async def test_unknown_condition(hass: HomeAssistant, config: dict, error: str) -> None:
"""Test if creating an unknown condition raises."""
config = cv.CONDITION_SCHEMA(config)
with pytest.raises(HomeAssistantError, match=error):
await condition.async_from_config(hass, config)
async def test_and_condition(hass: HomeAssistant) -> None: async def test_and_condition(hass: HomeAssistant) -> None:

View File

@@ -1455,6 +1455,56 @@ def test_key_value_schemas_with_default() -> None:
schema({"mode": "{{ 1 + 1}}"}) schema({"mode": "{{ 1 + 1}}"})
@pytest.mark.usefixtures("hass")
def test_key_value_schemas_with_default_no_list_alternatives() -> None:
"""Test key value schemas."""
schema = vol.Schema(
cv.key_value_schemas(
"mode",
{
"number": vol.Schema({"mode": "number", "data": int}),
"string": vol.Schema({"mode": "string", "data": str}),
},
vol.Schema({"mode": cv.dynamic_template}),
"a cool template",
list_alternatives=False,
)
)
with pytest.raises(vol.Invalid) as excinfo:
schema(True)
assert str(excinfo.value) == "Expected a dictionary"
for mode in None, {"a": "dict"}, "invalid":
with pytest.raises(vol.Invalid) as excinfo:
schema({"mode": mode})
assert (
str(excinfo.value)
== f"Unexpected value for mode: '{mode}'. Expected a cool template"
)
@pytest.mark.usefixtures("hass")
def test_key_value_schemas_without_default_no_list_alternatives() -> None:
"""Test key value schemas."""
with pytest.raises(ValueError) as excinfo:
vol.Schema(
cv.key_value_schemas(
"mode",
{
"number": vol.Schema({"mode": "number", "data": int}),
"string": vol.Schema({"mode": "string", "data": str}),
},
vol.Schema({"mode": cv.dynamic_template}),
list_alternatives=False,
)
)
assert (
str(excinfo.value)
== "default_description must be provided if list_alternatives is False"
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("config", "error"), ("config", "error"),
[ [
@@ -1462,6 +1512,11 @@ def test_key_value_schemas_with_default() -> None:
({"wait_template": "{{ invalid"}, "invalid template"), ({"wait_template": "{{ invalid"}, "invalid template"),
# The validation error message could be improved to explain that this is not # The validation error message could be improved to explain that this is not
# a valid shorthand template # a valid shorthand template
(
{"condition": 123},
"Unexpected value for condition: '123'. Expected a condition, a list of "
"conditions or a valid template",
),
( (
{"condition": "not", "conditions": "not a dynamic template"}, {"condition": "not", "conditions": "not a dynamic template"},
"Expected a dictionary", "Expected a dictionary",