diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 1ff2644fa58..8e4454751bf 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -704,6 +704,30 @@ def deprecated( return validator +def key_value_schemas( + key: str, value_schemas: Dict[str, vol.Schema] +) -> Callable[[Any], Dict[str, Any]]: + """Create a validator that validates based on a value for specific key. + + This gives better error messages. + """ + + def key_value_validator(value: Any) -> Dict[str, Any]: + if not isinstance(value, dict): + raise vol.Invalid("Expected a dictionary") + + key_value = value.get(key) + + if key_value not in value_schemas: + raise vol.Invalid( + f"Unexpected key {key_value}. Expected {', '.join(value_schemas)}" + ) + + return cast(Dict[str, Any], value_schemas[key_value](value)) + + return key_value_validator + + # Validator helpers @@ -899,16 +923,19 @@ DEVICE_CONDITION_BASE_SCHEMA = vol.Schema( DEVICE_CONDITION_SCHEMA = DEVICE_CONDITION_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA) -CONDITION_SCHEMA: vol.Schema = vol.Any( - NUMERIC_STATE_CONDITION_SCHEMA, - STATE_CONDITION_SCHEMA, - SUN_CONDITION_SCHEMA, - TEMPLATE_CONDITION_SCHEMA, - TIME_CONDITION_SCHEMA, - ZONE_CONDITION_SCHEMA, - AND_CONDITION_SCHEMA, - OR_CONDITION_SCHEMA, - DEVICE_CONDITION_SCHEMA, +CONDITION_SCHEMA: vol.Schema = key_value_schemas( + CONF_CONDITION, + { + "numeric_state": NUMERIC_STATE_CONDITION_SCHEMA, + "state": STATE_CONDITION_SCHEMA, + "sun": SUN_CONDITION_SCHEMA, + "template": TEMPLATE_CONDITION_SCHEMA, + "time": TIME_CONDITION_SCHEMA, + "zone": ZONE_CONDITION_SCHEMA, + "and": AND_CONDITION_SCHEMA, + "or": OR_CONDITION_SCHEMA, + "device": DEVICE_CONDITION_SCHEMA, + }, ) _SCRIPT_DELAY_SCHEMA = vol.Schema( diff --git a/tests/helpers/test_config_validation.py b/tests/helpers/test_config_validation.py index 9b6aa6b812d..e94fa202ce6 100644 --- a/tests/helpers/test_config_validation.py +++ b/tests/helpers/test_config_validation.py @@ -987,3 +987,36 @@ def test_uuid4_hex(caplog): _hex = uuid.uuid4().hex assert schema(_hex) == _hex assert schema(_hex.upper()) == _hex + + +def test_key_value_schemas(): + """Test key value schemas.""" + schema = vol.Schema( + cv.key_value_schemas( + "mode", + { + "number": vol.Schema({"mode": "number", "data": int}), + "string": vol.Schema({"mode": "string", "data": str}), + }, + ) + ) + + with pytest.raises(vol.Invalid) as excinfo: + schema(True) + assert str(excinfo.value) == "Expected a dictionary" + + for mode in None, "invalid": + with pytest.raises(vol.Invalid) as excinfo: + schema({"mode": mode}) + assert str(excinfo.value) == f"Unexpected key {mode}. Expected number, string" + + with pytest.raises(vol.Invalid) as excinfo: + schema({"mode": "number", "data": "string-value"}) + assert str(excinfo.value) == "expected int for dictionary value @ data['data']" + + with pytest.raises(vol.Invalid) as excinfo: + schema({"mode": "string", "data": 1}) + assert str(excinfo.value) == "expected str for dictionary value @ data['data']" + + for mode, data in (("number", 1), ("string", "hello")): + schema({"mode": mode, "data": data})