diff --git a/homeassistant/components/ai_task/__init__.py b/homeassistant/components/ai_task/__init__.py index 692e5d410ae..95c080cc472 100644 --- a/homeassistant/components/ai_task/__init__.py +++ b/homeassistant/components/ai_task/__init__.py @@ -1,11 +1,12 @@ """Integration to offer AI tasks to Home Assistant.""" import logging +from typing import Any import voluptuous as vol from homeassistant.config_entries import ConfigEntry -from homeassistant.const import ATTR_ENTITY_ID +from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR from homeassistant.core import ( HassJobType, HomeAssistant, @@ -14,12 +15,14 @@ from homeassistant.core import ( SupportsResponse, callback, ) -from homeassistant.helpers import config_validation as cv, storage +from homeassistant.helpers import config_validation as cv, selector, storage from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType from .const import ( ATTR_INSTRUCTIONS, + ATTR_REQUIRED, + ATTR_STRUCTURE, ATTR_TASK_NAME, DATA_COMPONENT, DATA_PREFERENCES, @@ -47,6 +50,27 @@ _LOGGER = logging.getLogger(__name__) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) +STRUCTURE_FIELD_SCHEMA = vol.Schema( + { + vol.Optional(CONF_DESCRIPTION): str, + vol.Optional(ATTR_REQUIRED): bool, + vol.Required(CONF_SELECTOR): selector.validate_selector, + } +) + + +def _validate_structure_fields(value: dict[str, Any]) -> vol.Schema: + """Validate the structure fields as a voluptuous Schema.""" + if not isinstance(value, dict): + raise vol.Invalid("Structure must be a dictionary") + fields = {} + for k, v in value.items(): + field_class = vol.Required if v.get(ATTR_REQUIRED, False) else vol.Optional + fields[field_class(k, description=v.get(CONF_DESCRIPTION))] = selector.selector( + v[CONF_SELECTOR] + ) + return vol.Schema(fields, extra=vol.PREVENT_EXTRA) + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Register the process service.""" @@ -64,6 +88,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: vol.Required(ATTR_TASK_NAME): cv.string, vol.Optional(ATTR_ENTITY_ID): cv.entity_id, vol.Required(ATTR_INSTRUCTIONS): cv.string, + vol.Optional(ATTR_STRUCTURE): vol.All( + vol.Schema({str: STRUCTURE_FIELD_SCHEMA}), + _validate_structure_fields, + ), } ), supports_response=SupportsResponse.ONLY, diff --git a/homeassistant/components/ai_task/const.py b/homeassistant/components/ai_task/const.py index 8b612e90560..fa8702ed69e 100644 --- a/homeassistant/components/ai_task/const.py +++ b/homeassistant/components/ai_task/const.py @@ -21,6 +21,8 @@ SERVICE_GENERATE_DATA = "generate_data" ATTR_INSTRUCTIONS: Final = "instructions" ATTR_TASK_NAME: Final = "task_name" +ATTR_STRUCTURE: Final = "structure" +ATTR_REQUIRED: Final = "required" DEFAULT_SYSTEM_PROMPT = ( "You are a Home Assistant expert and help users with their tasks." diff --git a/homeassistant/components/ai_task/services.yaml b/homeassistant/components/ai_task/services.yaml index a531ca599b1..d55b0e60fac 100644 --- a/homeassistant/components/ai_task/services.yaml +++ b/homeassistant/components/ai_task/services.yaml @@ -17,3 +17,9 @@ generate_data: domain: ai_task supported_features: - ai_task.AITaskEntityFeature.GENERATE_DATA + structure: + advanced: true + required: false + example: '{ "name": { "selector": { "text": }, "description": "Name of the user", "required": "True" } } }, "age": { "selector": { "number": }, "description": "Age of the user" } }' + selector: + object: diff --git a/homeassistant/components/ai_task/strings.json b/homeassistant/components/ai_task/strings.json index 877174de681..92106c3baca 100644 --- a/homeassistant/components/ai_task/strings.json +++ b/homeassistant/components/ai_task/strings.json @@ -15,6 +15,10 @@ "entity_id": { "name": "Entity ID", "description": "Entity ID to run the task on. If not provided, the preferred entity will be used." + }, + "structure": { + "name": "Structured output", + "description": "When set, the AI Task will output fields with this in structure. The structure is a dictionary where the keys are the field names and the values contain a 'description', a 'selector', and an optional 'required' field." } } } diff --git a/homeassistant/components/ai_task/task.py b/homeassistant/components/ai_task/task.py index 2e546897602..b6defbfad31 100644 --- a/homeassistant/components/ai_task/task.py +++ b/homeassistant/components/ai_task/task.py @@ -5,6 +5,8 @@ from __future__ import annotations from dataclasses import dataclass from typing import Any +import voluptuous as vol + from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError @@ -17,6 +19,7 @@ async def async_generate_data( task_name: str, entity_id: str | None = None, instructions: str, + structure: vol.Schema | None = None, ) -> GenDataTaskResult: """Run a task in the AI Task integration.""" if entity_id is None: @@ -38,6 +41,7 @@ async def async_generate_data( GenDataTask( name=task_name, instructions=instructions, + structure=structure, ) ) @@ -52,6 +56,9 @@ class GenDataTask: instructions: str """Instructions on what needs to be done.""" + structure: vol.Schema | None = None + """Optional structure for the data to be generated.""" + def __str__(self) -> str: """Return task as a string.""" return f"" diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 51d9c97ceeb..c7d4a26c86e 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -86,6 +86,7 @@ ALL_SERVICE_DESCRIPTIONS_CACHE: HassKey[ def _base_components() -> dict[str, ModuleType]: """Return a cached lookup of base components.""" from homeassistant.components import ( # noqa: PLC0415 + ai_task, alarm_control_panel, assist_satellite, calendar, @@ -107,6 +108,7 @@ def _base_components() -> dict[str, ModuleType]: ) return { + "ai_task": ai_task, "alarm_control_panel": alarm_control_panel, "assist_satellite": assist_satellite, "calendar": calendar, diff --git a/tests/components/ai_task/conftest.py b/tests/components/ai_task/conftest.py index 7efbd1ffcdb..e80e70ddaed 100644 --- a/tests/components/ai_task/conftest.py +++ b/tests/components/ai_task/conftest.py @@ -1,5 +1,7 @@ """Test helpers for AI Task integration.""" +import json + import pytest from homeassistant.components.ai_task import ( @@ -45,12 +47,18 @@ class MockAITaskEntity(AITaskEntity): ) -> GenDataTaskResult: """Mock handling of generate data task.""" self.mock_generate_data_tasks.append(task) + if task.structure is not None: + data = {"name": "Tracy Chen", "age": 30} + data_chat_log = json.dumps(data) + else: + data = "Mock result" + data_chat_log = data chat_log.async_add_assistant_content_without_tools( - AssistantContent(self.entity_id, "Mock result") + AssistantContent(self.entity_id, data_chat_log) ) return GenDataTaskResult( conversation_id=chat_log.conversation_id, - data="Mock result", + data=data, ) diff --git a/tests/components/ai_task/test_entity.py b/tests/components/ai_task/test_entity.py index 3ed1c393588..08f1bb42836 100644 --- a/tests/components/ai_task/test_entity.py +++ b/tests/components/ai_task/test_entity.py @@ -1,10 +1,12 @@ """Tests for the AI Task entity model.""" from freezegun import freeze_time +import voluptuous as vol from homeassistant.components.ai_task import async_generate_data from homeassistant.const import STATE_UNKNOWN from homeassistant.core import HomeAssistant +from homeassistant.helpers import selector from .conftest import TEST_ENTITY_ID, MockAITaskEntity @@ -37,3 +39,40 @@ async def test_state_generate_data( assert mock_ai_task_entity.mock_generate_data_tasks task = mock_ai_task_entity.mock_generate_data_tasks[0] assert task.instructions == "Test prompt" + + +async def test_generate_structured_data( + hass: HomeAssistant, + init_components: None, + mock_config_entry: MockConfigEntry, + mock_ai_task_entity: MockAITaskEntity, +) -> None: + """Test the entity can generate structured data.""" + result = await async_generate_data( + hass, + task_name="Test task", + entity_id=TEST_ENTITY_ID, + instructions="Please generate a profile for a new user", + structure=vol.Schema( + { + vol.Required("name"): selector.TextSelector(), + vol.Optional("age"): selector.NumberSelector( + config=selector.NumberSelectorConfig( + min=0, + max=120, + ) + ), + } + ), + ) + # Arbitrary data returned by the mock entity (not determined by above schema in test) + assert result.data == { + "name": "Tracy Chen", + "age": 30, + } + + assert mock_ai_task_entity.mock_generate_data_tasks + task = mock_ai_task_entity.mock_generate_data_tasks[0] + assert task.instructions == "Please generate a profile for a new user" + assert task.structure + assert isinstance(task.structure, vol.Schema) diff --git a/tests/components/ai_task/test_init.py b/tests/components/ai_task/test_init.py index fdfaaccd0a4..d32b09adec5 100644 --- a/tests/components/ai_task/test_init.py +++ b/tests/components/ai_task/test_init.py @@ -1,13 +1,17 @@ """Test initialization of the AI Task component.""" +from typing import Any + from freezegun.api import FrozenDateTimeFactory import pytest +import voluptuous as vol from homeassistant.components.ai_task import AITaskPreferences from homeassistant.components.ai_task.const import DATA_PREFERENCES from homeassistant.core import HomeAssistant +from homeassistant.helpers import selector -from .conftest import TEST_ENTITY_ID +from .conftest import TEST_ENTITY_ID, MockAITaskEntity from tests.common import flush_store @@ -82,3 +86,160 @@ async def test_generate_data_service( ) assert result["data"] == "Mock result" + + +async def test_generate_data_service_structure_fields( + hass: HomeAssistant, + init_components: None, + mock_ai_task_entity: MockAITaskEntity, +) -> None: + """Test the entity can generate structured data with a top level object schema.""" + result = await hass.services.async_call( + "ai_task", + "generate_data", + { + "task_name": "Profile Generation", + "instructions": "Please generate a profile for a new user", + "entity_id": TEST_ENTITY_ID, + "structure": { + "name": { + "description": "First and last name of the user such as Alice Smith", + "required": True, + "selector": {"text": {}}, + }, + "age": { + "description": "Age of the user", + "selector": { + "number": { + "min": 0, + "max": 120, + } + }, + }, + }, + }, + blocking=True, + return_response=True, + ) + # Arbitrary data returned by the mock entity (not determined by above schema in test) + assert result["data"] == { + "name": "Tracy Chen", + "age": 30, + } + + assert mock_ai_task_entity.mock_generate_data_tasks + task = mock_ai_task_entity.mock_generate_data_tasks[0] + assert task.instructions == "Please generate a profile for a new user" + assert task.structure + assert isinstance(task.structure, vol.Schema) + schema = list(task.structure.schema.items()) + assert len(schema) == 2 + + name_key, name_value = schema[0] + assert name_key == "name" + assert isinstance(name_key, vol.Required) + assert name_key.description == "First and last name of the user such as Alice Smith" + assert isinstance(name_value, selector.TextSelector) + + age_key, age_value = schema[1] + assert age_key == "age" + assert isinstance(age_key, vol.Optional) + assert age_key.description == "Age of the user" + assert isinstance(age_value, selector.NumberSelector) + assert age_value.config["min"] == 0 + assert age_value.config["max"] == 120 + + +@pytest.mark.parametrize( + ("structure", "expected_exception", "expected_error"), + [ + ( + { + "name": { + "description": "First and last name of the user such as Alice Smith", + "selector": {"invalid-selector": {}}, + }, + }, + vol.Invalid, + r"Unknown selector type invalid-selector.*", + ), + ( + { + "name": { + "description": "First and last name of the user such as Alice Smith", + "selector": { + "text": { + "extra-config": False, + } + }, + }, + }, + vol.Invalid, + r"extra keys not allowed.*", + ), + ( + { + "name": { + "description": "First and last name of the user such as Alice Smith", + }, + }, + vol.Invalid, + r"required key not provided.*selector.*", + ), + (12345, vol.Invalid, r"xpected a dictionary.*"), + ("name", vol.Invalid, r"xpected a dictionary.*"), + (["name"], vol.Invalid, r"xpected a dictionary.*"), + ( + { + "name": { + "description": "First and last name of the user such as Alice Smith", + "selector": {"text": {}}, + "extra-fields": "Some extra fields", + }, + }, + vol.Invalid, + r"extra keys not allowed .*", + ), + ( + { + "name": { + "description": "First and last name of the user such as Alice Smith", + "selector": "invalid-schema", + }, + }, + vol.Invalid, + r"xpected a dictionary for dictionary.", + ), + ], + ids=( + "invalid-selector", + "invalid-selector-config", + "missing-selector", + "structure-is-int-not-object", + "structure-is-str-not-object", + "structure-is-list-not-object", + "extra-fields", + "invalid-selector-schema", + ), +) +async def test_generate_data_service_invalid_structure( + hass: HomeAssistant, + init_components: None, + structure: Any, + expected_exception: Exception, + expected_error: str, +) -> None: + """Test the entity can generate structured data.""" + with pytest.raises(expected_exception, match=expected_error): + await hass.services.async_call( + "ai_task", + "generate_data", + { + "task_name": "Profile Generation", + "instructions": "Please generate a profile for a new user", + "entity_id": TEST_ENTITY_ID, + "structure": structure, + }, + blocking=True, + return_response=True, + )