Extract attribute names out of vol.Optional when validating entity service schema (#55157)

This commit is contained in:
Børge Nordli 2021-08-25 13:00:11 +02:00 committed by GitHub
parent bd407f3ff4
commit 0d654fa6b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 6 deletions

View File

@ -120,7 +120,7 @@ def path(value: Any) -> str:
# Adapted from:
# https://github.com/alecthomas/voluptuous/issues/115#issuecomment-144464666
def has_at_least_one_key(*keys: str) -> Callable:
def has_at_least_one_key(*keys: Any) -> Callable[[dict], dict]:
"""Validate that at least one key exists."""
def validate(obj: dict) -> dict:
@ -131,12 +131,13 @@ def has_at_least_one_key(*keys: str) -> Callable:
for k in obj:
if k in keys:
return obj
raise vol.Invalid("must contain at least one of {}.".format(", ".join(keys)))
expected = ", ".join(str(k) for k in keys)
raise vol.Invalid(f"must contain at least one of {expected}.")
return validate
def has_at_most_one_key(*keys: str) -> Callable[[dict], dict]:
def has_at_most_one_key(*keys: Any) -> Callable[[dict], dict]:
"""Validate that zero keys exist or one key exists."""
def validate(obj: dict) -> dict:
@ -145,7 +146,8 @@ def has_at_most_one_key(*keys: str) -> Callable[[dict], dict]:
raise vol.Invalid("expected dictionary")
if len(set(keys) & set(obj)) > 1:
raise vol.Invalid("must contain at most one of {}.".format(", ".join(keys)))
expected = ", ".join(str(k) for k in keys)
raise vol.Invalid(f"must contain at most one of {expected}.")
return obj
return validate

View File

@ -388,6 +388,34 @@ def test_service_schema():
cv.SERVICE_SCHEMA(value)
def test_entity_service_schema():
"""Test make_entity_service_schema validation."""
schema = cv.make_entity_service_schema(
{vol.Required("required"): cv.positive_int, vol.Optional("optional"): cv.string}
)
options = (
{},
None,
{"entity_id": "light.kitchen"},
{"optional": "value", "entity_id": "light.kitchen"},
{"required": 1},
{"required": 2, "area_id": "kitchen", "foo": "bar"},
{"required": "str", "area_id": "kitchen"},
)
for value in options:
with pytest.raises(vol.MultipleInvalid):
cv.SERVICE_SCHEMA(value)
options = (
{"required": 1, "entity_id": "light.kitchen"},
{"required": 2, "optional": "value", "device_id": "a_device"},
{"required": 3, "area_id": "kitchen"},
)
for value in options:
schema(value)
def test_slug():
"""Test slug validation."""
schema = vol.Schema(cv.slug)
@ -912,7 +940,7 @@ def test_has_at_most_one_key():
with pytest.raises(vol.MultipleInvalid):
schema(value)
for value in ({}, {"beer": None}, {"soda": None}):
for value in ({}, {"beer": None}, {"soda": None}, {vol.Optional("soda"): None}):
schema(value)
@ -924,7 +952,7 @@ def test_has_at_least_one_key():
with pytest.raises(vol.MultipleInvalid):
schema(value)
for value in ({"beer": None}, {"soda": None}):
for value in ({"beer": None}, {"soda": None}, {vol.Required("soda"): None}):
schema(value)