Add AI task structured output (#148083)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
Allen Porter 2025-07-04 06:03:34 -07:00 committed by GitHub
parent 99d63c49bb
commit b3d9908cd9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 262 additions and 5 deletions

View File

@ -1,11 +1,12 @@
"""Integration to offer AI tasks to Home Assistant.""" """Integration to offer AI tasks to Home Assistant."""
import logging import logging
from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR
from homeassistant.core import ( from homeassistant.core import (
HassJobType, HassJobType,
HomeAssistant, HomeAssistant,
@ -14,12 +15,14 @@ from homeassistant.core import (
SupportsResponse, SupportsResponse,
callback, callback,
) )
from homeassistant.helpers import config_validation as cv, storage from homeassistant.helpers import config_validation as cv, selector, storage
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
from .const import ( from .const import (
ATTR_INSTRUCTIONS, ATTR_INSTRUCTIONS,
ATTR_REQUIRED,
ATTR_STRUCTURE,
ATTR_TASK_NAME, ATTR_TASK_NAME,
DATA_COMPONENT, DATA_COMPONENT,
DATA_PREFERENCES, DATA_PREFERENCES,
@ -47,6 +50,27 @@ _LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
STRUCTURE_FIELD_SCHEMA = vol.Schema(
{
vol.Optional(CONF_DESCRIPTION): str,
vol.Optional(ATTR_REQUIRED): bool,
vol.Required(CONF_SELECTOR): selector.validate_selector,
}
)
def _validate_structure_fields(value: dict[str, Any]) -> vol.Schema:
"""Validate the structure fields as a voluptuous Schema."""
if not isinstance(value, dict):
raise vol.Invalid("Structure must be a dictionary")
fields = {}
for k, v in value.items():
field_class = vol.Required if v.get(ATTR_REQUIRED, False) else vol.Optional
fields[field_class(k, description=v.get(CONF_DESCRIPTION))] = selector.selector(
v[CONF_SELECTOR]
)
return vol.Schema(fields, extra=vol.PREVENT_EXTRA)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service.""" """Register the process service."""
@ -64,6 +88,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
vol.Required(ATTR_TASK_NAME): cv.string, vol.Required(ATTR_TASK_NAME): cv.string,
vol.Optional(ATTR_ENTITY_ID): cv.entity_id, vol.Optional(ATTR_ENTITY_ID): cv.entity_id,
vol.Required(ATTR_INSTRUCTIONS): cv.string, vol.Required(ATTR_INSTRUCTIONS): cv.string,
vol.Optional(ATTR_STRUCTURE): vol.All(
vol.Schema({str: STRUCTURE_FIELD_SCHEMA}),
_validate_structure_fields,
),
} }
), ),
supports_response=SupportsResponse.ONLY, supports_response=SupportsResponse.ONLY,

View File

@ -21,6 +21,8 @@ SERVICE_GENERATE_DATA = "generate_data"
ATTR_INSTRUCTIONS: Final = "instructions" ATTR_INSTRUCTIONS: Final = "instructions"
ATTR_TASK_NAME: Final = "task_name" ATTR_TASK_NAME: Final = "task_name"
ATTR_STRUCTURE: Final = "structure"
ATTR_REQUIRED: Final = "required"
DEFAULT_SYSTEM_PROMPT = ( DEFAULT_SYSTEM_PROMPT = (
"You are a Home Assistant expert and help users with their tasks." "You are a Home Assistant expert and help users with their tasks."

View File

@ -17,3 +17,9 @@ generate_data:
domain: ai_task domain: ai_task
supported_features: supported_features:
- ai_task.AITaskEntityFeature.GENERATE_DATA - ai_task.AITaskEntityFeature.GENERATE_DATA
structure:
advanced: true
required: false
example: '{ "name": { "selector": { "text": }, "description": "Name of the user", "required": "True" } } }, "age": { "selector": { "number": }, "description": "Age of the user" } }'
selector:
object:

View File

@ -15,6 +15,10 @@
"entity_id": { "entity_id": {
"name": "Entity ID", "name": "Entity ID",
"description": "Entity ID to run the task on. If not provided, the preferred entity will be used." "description": "Entity ID to run the task on. If not provided, the preferred entity will be used."
},
"structure": {
"name": "Structured output",
"description": "When set, the AI Task will output fields with this in structure. The structure is a dictionary where the keys are the field names and the values contain a 'description', a 'selector', and an optional 'required' field."
} }
} }
} }

View File

@ -5,6 +5,8 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
import voluptuous as vol
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@ -17,6 +19,7 @@ async def async_generate_data(
task_name: str, task_name: str,
entity_id: str | None = None, entity_id: str | None = None,
instructions: str, instructions: str,
structure: vol.Schema | None = None,
) -> GenDataTaskResult: ) -> GenDataTaskResult:
"""Run a task in the AI Task integration.""" """Run a task in the AI Task integration."""
if entity_id is None: if entity_id is None:
@ -38,6 +41,7 @@ async def async_generate_data(
GenDataTask( GenDataTask(
name=task_name, name=task_name,
instructions=instructions, instructions=instructions,
structure=structure,
) )
) )
@ -52,6 +56,9 @@ class GenDataTask:
instructions: str instructions: str
"""Instructions on what needs to be done.""" """Instructions on what needs to be done."""
structure: vol.Schema | None = None
"""Optional structure for the data to be generated."""
def __str__(self) -> str: def __str__(self) -> str:
"""Return task as a string.""" """Return task as a string."""
return f"<GenDataTask {self.name}: {id(self)}>" return f"<GenDataTask {self.name}: {id(self)}>"

View File

@ -86,6 +86,7 @@ ALL_SERVICE_DESCRIPTIONS_CACHE: HassKey[
def _base_components() -> dict[str, ModuleType]: def _base_components() -> dict[str, ModuleType]:
"""Return a cached lookup of base components.""" """Return a cached lookup of base components."""
from homeassistant.components import ( # noqa: PLC0415 from homeassistant.components import ( # noqa: PLC0415
ai_task,
alarm_control_panel, alarm_control_panel,
assist_satellite, assist_satellite,
calendar, calendar,
@ -107,6 +108,7 @@ def _base_components() -> dict[str, ModuleType]:
) )
return { return {
"ai_task": ai_task,
"alarm_control_panel": alarm_control_panel, "alarm_control_panel": alarm_control_panel,
"assist_satellite": assist_satellite, "assist_satellite": assist_satellite,
"calendar": calendar, "calendar": calendar,

View File

@ -1,5 +1,7 @@
"""Test helpers for AI Task integration.""" """Test helpers for AI Task integration."""
import json
import pytest import pytest
from homeassistant.components.ai_task import ( from homeassistant.components.ai_task import (
@ -45,12 +47,18 @@ class MockAITaskEntity(AITaskEntity):
) -> GenDataTaskResult: ) -> GenDataTaskResult:
"""Mock handling of generate data task.""" """Mock handling of generate data task."""
self.mock_generate_data_tasks.append(task) self.mock_generate_data_tasks.append(task)
if task.structure is not None:
data = {"name": "Tracy Chen", "age": 30}
data_chat_log = json.dumps(data)
else:
data = "Mock result"
data_chat_log = data
chat_log.async_add_assistant_content_without_tools( chat_log.async_add_assistant_content_without_tools(
AssistantContent(self.entity_id, "Mock result") AssistantContent(self.entity_id, data_chat_log)
) )
return GenDataTaskResult( return GenDataTaskResult(
conversation_id=chat_log.conversation_id, conversation_id=chat_log.conversation_id,
data="Mock result", data=data,
) )

View File

@ -1,10 +1,12 @@
"""Tests for the AI Task entity model.""" """Tests for the AI Task entity model."""
from freezegun import freeze_time from freezegun import freeze_time
import voluptuous as vol
from homeassistant.components.ai_task import async_generate_data from homeassistant.components.ai_task import async_generate_data
from homeassistant.const import STATE_UNKNOWN from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import selector
from .conftest import TEST_ENTITY_ID, MockAITaskEntity from .conftest import TEST_ENTITY_ID, MockAITaskEntity
@ -37,3 +39,40 @@ async def test_state_generate_data(
assert mock_ai_task_entity.mock_generate_data_tasks assert mock_ai_task_entity.mock_generate_data_tasks
task = mock_ai_task_entity.mock_generate_data_tasks[0] task = mock_ai_task_entity.mock_generate_data_tasks[0]
assert task.instructions == "Test prompt" assert task.instructions == "Test prompt"
async def test_generate_structured_data(
hass: HomeAssistant,
init_components: None,
mock_config_entry: MockConfigEntry,
mock_ai_task_entity: MockAITaskEntity,
) -> None:
"""Test the entity can generate structured data."""
result = await async_generate_data(
hass,
task_name="Test task",
entity_id=TEST_ENTITY_ID,
instructions="Please generate a profile for a new user",
structure=vol.Schema(
{
vol.Required("name"): selector.TextSelector(),
vol.Optional("age"): selector.NumberSelector(
config=selector.NumberSelectorConfig(
min=0,
max=120,
)
),
}
),
)
# Arbitrary data returned by the mock entity (not determined by above schema in test)
assert result.data == {
"name": "Tracy Chen",
"age": 30,
}
assert mock_ai_task_entity.mock_generate_data_tasks
task = mock_ai_task_entity.mock_generate_data_tasks[0]
assert task.instructions == "Please generate a profile for a new user"
assert task.structure
assert isinstance(task.structure, vol.Schema)

View File

@ -1,13 +1,17 @@
"""Test initialization of the AI Task component.""" """Test initialization of the AI Task component."""
from typing import Any
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
import pytest import pytest
import voluptuous as vol
from homeassistant.components.ai_task import AITaskPreferences from homeassistant.components.ai_task import AITaskPreferences
from homeassistant.components.ai_task.const import DATA_PREFERENCES from homeassistant.components.ai_task.const import DATA_PREFERENCES
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import selector
from .conftest import TEST_ENTITY_ID from .conftest import TEST_ENTITY_ID, MockAITaskEntity
from tests.common import flush_store from tests.common import flush_store
@ -82,3 +86,160 @@ async def test_generate_data_service(
) )
assert result["data"] == "Mock result" assert result["data"] == "Mock result"
async def test_generate_data_service_structure_fields(
hass: HomeAssistant,
init_components: None,
mock_ai_task_entity: MockAITaskEntity,
) -> None:
"""Test the entity can generate structured data with a top level object schema."""
result = await hass.services.async_call(
"ai_task",
"generate_data",
{
"task_name": "Profile Generation",
"instructions": "Please generate a profile for a new user",
"entity_id": TEST_ENTITY_ID,
"structure": {
"name": {
"description": "First and last name of the user such as Alice Smith",
"required": True,
"selector": {"text": {}},
},
"age": {
"description": "Age of the user",
"selector": {
"number": {
"min": 0,
"max": 120,
}
},
},
},
},
blocking=True,
return_response=True,
)
# Arbitrary data returned by the mock entity (not determined by above schema in test)
assert result["data"] == {
"name": "Tracy Chen",
"age": 30,
}
assert mock_ai_task_entity.mock_generate_data_tasks
task = mock_ai_task_entity.mock_generate_data_tasks[0]
assert task.instructions == "Please generate a profile for a new user"
assert task.structure
assert isinstance(task.structure, vol.Schema)
schema = list(task.structure.schema.items())
assert len(schema) == 2
name_key, name_value = schema[0]
assert name_key == "name"
assert isinstance(name_key, vol.Required)
assert name_key.description == "First and last name of the user such as Alice Smith"
assert isinstance(name_value, selector.TextSelector)
age_key, age_value = schema[1]
assert age_key == "age"
assert isinstance(age_key, vol.Optional)
assert age_key.description == "Age of the user"
assert isinstance(age_value, selector.NumberSelector)
assert age_value.config["min"] == 0
assert age_value.config["max"] == 120
@pytest.mark.parametrize(
("structure", "expected_exception", "expected_error"),
[
(
{
"name": {
"description": "First and last name of the user such as Alice Smith",
"selector": {"invalid-selector": {}},
},
},
vol.Invalid,
r"Unknown selector type invalid-selector.*",
),
(
{
"name": {
"description": "First and last name of the user such as Alice Smith",
"selector": {
"text": {
"extra-config": False,
}
},
},
},
vol.Invalid,
r"extra keys not allowed.*",
),
(
{
"name": {
"description": "First and last name of the user such as Alice Smith",
},
},
vol.Invalid,
r"required key not provided.*selector.*",
),
(12345, vol.Invalid, r"xpected a dictionary.*"),
("name", vol.Invalid, r"xpected a dictionary.*"),
(["name"], vol.Invalid, r"xpected a dictionary.*"),
(
{
"name": {
"description": "First and last name of the user such as Alice Smith",
"selector": {"text": {}},
"extra-fields": "Some extra fields",
},
},
vol.Invalid,
r"extra keys not allowed .*",
),
(
{
"name": {
"description": "First and last name of the user such as Alice Smith",
"selector": "invalid-schema",
},
},
vol.Invalid,
r"xpected a dictionary for dictionary.",
),
],
ids=(
"invalid-selector",
"invalid-selector-config",
"missing-selector",
"structure-is-int-not-object",
"structure-is-str-not-object",
"structure-is-list-not-object",
"extra-fields",
"invalid-selector-schema",
),
)
async def test_generate_data_service_invalid_structure(
hass: HomeAssistant,
init_components: None,
structure: Any,
expected_exception: Exception,
expected_error: str,
) -> None:
"""Test the entity can generate structured data."""
with pytest.raises(expected_exception, match=expected_error):
await hass.services.async_call(
"ai_task",
"generate_data",
{
"task_name": "Profile Generation",
"instructions": "Please generate a profile for a new user",
"entity_id": TEST_ENTITY_ID,
"structure": structure,
},
blocking=True,
return_response=True,
)