mirror of
https://github.com/home-assistant/core.git
synced 2025-08-01 17:48:26 +00:00
Forbid extra fields in the vol schema to ensure generated output is correct
This commit is contained in:
parent
afa30be64b
commit
5ba71a4675
@ -59,19 +59,17 @@ STRUCTURE_FIELD_SCHEMA = vol.Schema(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _validate_schema(value: dict[str, Any]) -> vol.Schema:
|
def _validate_structure_fields(value: dict[str, Any]) -> vol.Schema:
|
||||||
"""Validate the structure for the generate data task and convert to a vol Schema."""
|
"""Validate the structure fields as a voluptuous Schema."""
|
||||||
if not isinstance(value, dict):
|
if not isinstance(value, dict):
|
||||||
raise vol.Invalid("Structure must be a dictionary")
|
raise vol.Invalid("Structure must be a dictionary")
|
||||||
fields = {}
|
fields = {}
|
||||||
for k, v in value.items():
|
for k, v in value.items():
|
||||||
if not isinstance(v, dict):
|
|
||||||
raise vol.Invalid(f"Structure field '{k}' must be a dictionary")
|
|
||||||
field_class = vol.Required if v.get(ATTR_REQUIRED, False) else vol.Optional
|
field_class = vol.Required if v.get(ATTR_REQUIRED, False) else vol.Optional
|
||||||
fields[field_class(k, description=v.get(CONF_DESCRIPTION))] = selector.selector(
|
fields[field_class(k, description=v.get(CONF_DESCRIPTION))] = selector.selector(
|
||||||
v[CONF_SELECTOR]
|
v[CONF_SELECTOR]
|
||||||
)
|
)
|
||||||
return vol.Schema(fields)
|
return vol.Schema(fields, extra=vol.PREVENT_EXTRA)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
@ -92,7 +90,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
vol.Required(ATTR_INSTRUCTIONS): cv.string,
|
vol.Required(ATTR_INSTRUCTIONS): cv.string,
|
||||||
vol.Optional(ATTR_STRUCTURE): vol.All(
|
vol.Optional(ATTR_STRUCTURE): vol.All(
|
||||||
vol.Schema({str: STRUCTURE_FIELD_SCHEMA}),
|
vol.Schema({str: STRUCTURE_FIELD_SCHEMA}),
|
||||||
_validate_schema,
|
_validate_structure_fields,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
|
@ -88,12 +88,12 @@ async def test_generate_data_service(
|
|||||||
assert result["data"] == "Mock result"
|
assert result["data"] == "Mock result"
|
||||||
|
|
||||||
|
|
||||||
async def test_generate_data_service_structure(
|
async def test_generate_data_service_structure_fields(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
init_components: None,
|
init_components: None,
|
||||||
mock_ai_task_entity: MockAITaskEntity,
|
mock_ai_task_entity: MockAITaskEntity,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the entity can generate structured data."""
|
"""Test the entity can generate structured data with a top level object schemea."""
|
||||||
result = await hass.services.async_call(
|
result = await hass.services.async_call(
|
||||||
"ai_task",
|
"ai_task",
|
||||||
"generate_data",
|
"generate_data",
|
||||||
@ -186,9 +186,9 @@ async def test_generate_data_service_structure(
|
|||||||
vol.Invalid,
|
vol.Invalid,
|
||||||
r"required key not provided.*selector.*",
|
r"required key not provided.*selector.*",
|
||||||
),
|
),
|
||||||
(12345, vol.Invalid, r"expected a dictionary.*"),
|
(12345, vol.Invalid, r"xpected a dictionary.*"),
|
||||||
("name", vol.Invalid, r"expected a dictionary.*"),
|
("name", vol.Invalid, r"xpected a dictionary.*"),
|
||||||
(["name"], vol.Invalid, r"expected a dictionary.*"),
|
(["name"], vol.Invalid, r"xpected a dictionary.*"),
|
||||||
(
|
(
|
||||||
{
|
{
|
||||||
"name": {
|
"name": {
|
||||||
@ -200,6 +200,16 @@ async def test_generate_data_service_structure(
|
|||||||
vol.Invalid,
|
vol.Invalid,
|
||||||
r"extra keys not allowed .*",
|
r"extra keys not allowed .*",
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"name": {
|
||||||
|
"description": "First and last name of the user such as Alice Smith",
|
||||||
|
"selector": "invalid-schema",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
vol.Invalid,
|
||||||
|
r"xpected a dictionary for dictionary.",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
ids=(
|
ids=(
|
||||||
"invalid-selector",
|
"invalid-selector",
|
||||||
@ -209,6 +219,7 @@ async def test_generate_data_service_structure(
|
|||||||
"structure-is-str-not-object",
|
"structure-is-str-not-object",
|
||||||
"structure-is-list-not-object",
|
"structure-is-list-not-object",
|
||||||
"extra-fields",
|
"extra-fields",
|
||||||
|
"invalid-selector-schema",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
async def test_generate_data_service_invalid_structure(
|
async def test_generate_data_service_invalid_structure(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user