Forbid extra fields in the vol schema to ensure generated output is correct

This commit is contained in:
Allen Porter 2025-07-04 00:35:03 +00:00
parent afa30be64b
commit 5ba71a4675
2 changed files with 20 additions and 11 deletions

View File

@ -59,19 +59,17 @@ STRUCTURE_FIELD_SCHEMA = vol.Schema(
)
def _validate_schema(value: dict[str, Any]) -> vol.Schema:
"""Validate the structure for the generate data task and convert to a vol Schema."""
def _validate_structure_fields(value: dict[str, Any]) -> vol.Schema:
"""Validate the structure fields as a voluptuous Schema."""
if not isinstance(value, dict):
raise vol.Invalid("Structure must be a dictionary")
fields = {}
for k, v in value.items():
if not isinstance(v, dict):
raise vol.Invalid(f"Structure field '{k}' must be a dictionary")
field_class = vol.Required if v.get(ATTR_REQUIRED, False) else vol.Optional
fields[field_class(k, description=v.get(CONF_DESCRIPTION))] = selector.selector(
v[CONF_SELECTOR]
)
return vol.Schema(fields)
return vol.Schema(fields, extra=vol.PREVENT_EXTRA)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
@ -92,7 +90,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
vol.Required(ATTR_INSTRUCTIONS): cv.string,
vol.Optional(ATTR_STRUCTURE): vol.All(
vol.Schema({str: STRUCTURE_FIELD_SCHEMA}),
_validate_schema,
_validate_structure_fields,
),
}
),

View File

@ -88,12 +88,12 @@ async def test_generate_data_service(
assert result["data"] == "Mock result"
async def test_generate_data_service_structure(
async def test_generate_data_service_structure_fields(
hass: HomeAssistant,
init_components: None,
mock_ai_task_entity: MockAITaskEntity,
) -> None:
"""Test the entity can generate structured data."""
"""Test the entity can generate structured data with a top level object schemea."""
result = await hass.services.async_call(
"ai_task",
"generate_data",
@ -186,9 +186,9 @@ async def test_generate_data_service_structure(
vol.Invalid,
r"required key not provided.*selector.*",
),
(12345, vol.Invalid, r"expected a dictionary.*"),
("name", vol.Invalid, r"expected a dictionary.*"),
(["name"], vol.Invalid, r"expected a dictionary.*"),
(12345, vol.Invalid, r"xpected a dictionary.*"),
("name", vol.Invalid, r"xpected a dictionary.*"),
(["name"], vol.Invalid, r"xpected a dictionary.*"),
(
{
"name": {
@ -200,6 +200,16 @@ async def test_generate_data_service_structure(
vol.Invalid,
r"extra keys not allowed .*",
),
(
{
"name": {
"description": "First and last name of the user such as Alice Smith",
"selector": "invalid-schema",
},
},
vol.Invalid,
r"xpected a dictionary for dictionary.",
),
],
ids=(
"invalid-selector",
@ -209,6 +219,7 @@ async def test_generate_data_service_structure(
"structure-is-str-not-object",
"structure-is-list-not-object",
"extra-fields",
"invalid-selector-schema",
),
)
async def test_generate_data_service_invalid_structure(