From 5ba71a467530c11e9a740b992c7b6330e3b536fe Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Fri, 4 Jul 2025 00:35:03 +0000 Subject: [PATCH] Forbid extra fields in the vol schema to ensure generated output is correct --- homeassistant/components/ai_task/__init__.py | 10 ++++------ tests/components/ai_task/test_init.py | 21 +++++++++++++++----- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/homeassistant/components/ai_task/__init__.py b/homeassistant/components/ai_task/__init__.py index 9bdb5d40d31..95c080cc472 100644 --- a/homeassistant/components/ai_task/__init__.py +++ b/homeassistant/components/ai_task/__init__.py @@ -59,19 +59,17 @@ STRUCTURE_FIELD_SCHEMA = vol.Schema( ) -def _validate_schema(value: dict[str, Any]) -> vol.Schema: - """Validate the structure for the generate data task and convert to a vol Schema.""" +def _validate_structure_fields(value: dict[str, Any]) -> vol.Schema: + """Validate the structure fields as a voluptuous Schema.""" if not isinstance(value, dict): raise vol.Invalid("Structure must be a dictionary") fields = {} for k, v in value.items(): - if not isinstance(v, dict): - raise vol.Invalid(f"Structure field '{k}' must be a dictionary") field_class = vol.Required if v.get(ATTR_REQUIRED, False) else vol.Optional fields[field_class(k, description=v.get(CONF_DESCRIPTION))] = selector.selector( v[CONF_SELECTOR] ) - return vol.Schema(fields) + return vol.Schema(fields, extra=vol.PREVENT_EXTRA) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: @@ -92,7 +90,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: vol.Required(ATTR_INSTRUCTIONS): cv.string, vol.Optional(ATTR_STRUCTURE): vol.All( vol.Schema({str: STRUCTURE_FIELD_SCHEMA}), - _validate_schema, + _validate_structure_fields, ), } ), diff --git a/tests/components/ai_task/test_init.py b/tests/components/ai_task/test_init.py index dcc7290c5ae..02980cd699d 100644 --- a/tests/components/ai_task/test_init.py +++ b/tests/components/ai_task/test_init.py @@ -88,12 +88,12 @@ async def test_generate_data_service( assert result["data"] == "Mock result" -async def test_generate_data_service_structure( +async def test_generate_data_service_structure_fields( hass: HomeAssistant, init_components: None, mock_ai_task_entity: MockAITaskEntity, ) -> None: - """Test the entity can generate structured data.""" + """Test the entity can generate structured data with a top level object schemea.""" result = await hass.services.async_call( "ai_task", "generate_data", @@ -186,9 +186,9 @@ async def test_generate_data_service_structure( vol.Invalid, r"required key not provided.*selector.*", ), - (12345, vol.Invalid, r"expected a dictionary.*"), - ("name", vol.Invalid, r"expected a dictionary.*"), - (["name"], vol.Invalid, r"expected a dictionary.*"), + (12345, vol.Invalid, r"xpected a dictionary.*"), + ("name", vol.Invalid, r"xpected a dictionary.*"), + (["name"], vol.Invalid, r"xpected a dictionary.*"), ( { "name": { @@ -200,6 +200,16 @@ async def test_generate_data_service_structure( vol.Invalid, r"extra keys not allowed .*", ), + ( + { + "name": { + "description": "First and last name of the user such as Alice Smith", + "selector": "invalid-schema", + }, + }, + vol.Invalid, + r"xpected a dictionary for dictionary.", + ), ], ids=( "invalid-selector", @@ -209,6 +219,7 @@ async def test_generate_data_service_structure( "structure-is-str-not-object", "structure-is-list-not-object", "extra-fields", + "invalid-selector-schema", ), ) async def test_generate_data_service_invalid_structure(