Add AI task structured output (#148083)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
Allen Porter 2025-07-04 06:03:34 -07:00 committed by GitHub
parent 99d63c49bb
commit b3d9908cd9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 262 additions and 5 deletions

View File

@ -1,11 +1,12 @@
"""Integration to offer AI tasks to Home Assistant."""
import logging
from typing import Any
import voluptuous as vol
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR
from homeassistant.core import (
HassJobType,
HomeAssistant,
@ -14,12 +15,14 @@ from homeassistant.core import (
SupportsResponse,
callback,
)
from homeassistant.helpers import config_validation as cv, storage
from homeassistant.helpers import config_validation as cv, selector, storage
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
from .const import (
ATTR_INSTRUCTIONS,
ATTR_REQUIRED,
ATTR_STRUCTURE,
ATTR_TASK_NAME,
DATA_COMPONENT,
DATA_PREFERENCES,
@ -47,6 +50,27 @@ _LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
# Schema for a single field definition inside the ``structure`` service
# parameter: a mandatory selector definition plus an optional description
# string and an optional boolean "required" flag.
STRUCTURE_FIELD_SCHEMA = vol.Schema(
{
vol.Optional(CONF_DESCRIPTION): str,
vol.Optional(ATTR_REQUIRED): bool,
vol.Required(CONF_SELECTOR): selector.validate_selector,
}
)
def _validate_structure_fields(value: dict[str, Any]) -> vol.Schema:
    """Build a voluptuous Schema from a mapping of structure field definitions.

    Each key of ``value`` becomes a schema key: ``vol.Required`` when the
    field's ``ATTR_REQUIRED`` entry is truthy, ``vol.Optional`` otherwise.
    The key carries the field's description and maps to the selector built
    from the field's ``CONF_SELECTOR`` entry. The resulting schema rejects
    keys that were not declared.

    Raises ``vol.Invalid`` when ``value`` is not a dictionary.
    """
    if not isinstance(value, dict):
        raise vol.Invalid("Structure must be a dictionary")

    schema_fields = {}
    for field_name, field_config in value.items():
        field_selector = selector.selector(field_config[CONF_SELECTOR])
        description = field_config.get(CONF_DESCRIPTION)
        if field_config.get(ATTR_REQUIRED, False):
            marker = vol.Required(field_name, description=description)
        else:
            marker = vol.Optional(field_name, description=description)
        schema_fields[marker] = field_selector

    return vol.Schema(schema_fields, extra=vol.PREVENT_EXTRA)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service."""
@ -64,6 +88,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
vol.Required(ATTR_TASK_NAME): cv.string,
vol.Optional(ATTR_ENTITY_ID): cv.entity_id,
vol.Required(ATTR_INSTRUCTIONS): cv.string,
vol.Optional(ATTR_STRUCTURE): vol.All(
vol.Schema({str: STRUCTURE_FIELD_SCHEMA}),
_validate_structure_fields,
),
}
),
supports_response=SupportsResponse.ONLY,

View File

@ -21,6 +21,8 @@ SERVICE_GENERATE_DATA = "generate_data"
ATTR_INSTRUCTIONS: Final = "instructions"
ATTR_TASK_NAME: Final = "task_name"
ATTR_STRUCTURE: Final = "structure"
ATTR_REQUIRED: Final = "required"
DEFAULT_SYSTEM_PROMPT = (
"You are a Home Assistant expert and help users with their tasks."

View File

@ -17,3 +17,9 @@ generate_data:
domain: ai_task
supported_features:
- ai_task.AITaskEntityFeature.GENERATE_DATA
structure:
advanced: true
required: false
example: '{ "name": { "selector": { "text": {} }, "description": "Name of the user", "required": true }, "age": { "selector": { "number": {} }, "description": "Age of the user" } }'
selector:
object:

View File

@ -15,6 +15,10 @@
"entity_id": {
"name": "Entity ID",
"description": "Entity ID to run the task on. If not provided, the preferred entity will be used."
},
"structure": {
"name": "Structured output",
"description": "When set, the AI Task will output fields with this structure. The structure is a dictionary where the keys are the field names and the values contain a 'selector', an optional 'description', and an optional 'required' field."
}
}
}

View File

@ -5,6 +5,8 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import voluptuous as vol
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
@ -17,6 +19,7 @@ async def async_generate_data(
task_name: str,
entity_id: str | None = None,
instructions: str,
structure: vol.Schema | None = None,
) -> GenDataTaskResult:
"""Run a task in the AI Task integration."""
if entity_id is None:
@ -38,6 +41,7 @@ async def async_generate_data(
GenDataTask(
name=task_name,
instructions=instructions,
structure=structure,
)
)
@ -52,6 +56,9 @@ class GenDataTask:
instructions: str
"""Instructions on what needs to be done."""
structure: vol.Schema | None = None
"""Optional structure for the data to be generated."""
def __str__(self) -> str:
"""Return task as a string."""
return f"<GenDataTask {self.name}: {id(self)}>"

View File

@ -86,6 +86,7 @@ ALL_SERVICE_DESCRIPTIONS_CACHE: HassKey[
def _base_components() -> dict[str, ModuleType]:
"""Return a cached lookup of base components."""
from homeassistant.components import ( # noqa: PLC0415
ai_task,
alarm_control_panel,
assist_satellite,
calendar,
@ -107,6 +108,7 @@ def _base_components() -> dict[str, ModuleType]:
)
return {
"ai_task": ai_task,
"alarm_control_panel": alarm_control_panel,
"assist_satellite": assist_satellite,
"calendar": calendar,

View File

@ -1,5 +1,7 @@
"""Test helpers for AI Task integration."""
import json
import pytest
from homeassistant.components.ai_task import (
@ -45,12 +47,18 @@ class MockAITaskEntity(AITaskEntity):
) -> GenDataTaskResult:
"""Mock handling of generate data task."""
self.mock_generate_data_tasks.append(task)
if task.structure is not None:
data = {"name": "Tracy Chen", "age": 30}
data_chat_log = json.dumps(data)
else:
data = "Mock result"
data_chat_log = data
chat_log.async_add_assistant_content_without_tools(
AssistantContent(self.entity_id, "Mock result")
AssistantContent(self.entity_id, data_chat_log)
)
return GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data="Mock result",
data=data,
)

View File

@ -1,10 +1,12 @@
"""Tests for the AI Task entity model."""
from freezegun import freeze_time
import voluptuous as vol
from homeassistant.components.ai_task import async_generate_data
from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant
from homeassistant.helpers import selector
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
@ -37,3 +39,40 @@ async def test_state_generate_data(
assert mock_ai_task_entity.mock_generate_data_tasks
task = mock_ai_task_entity.mock_generate_data_tasks[0]
assert task.instructions == "Test prompt"
async def test_generate_structured_data(
    hass: HomeAssistant,
    init_components: None,
    mock_config_entry: MockConfigEntry,
    mock_ai_task_entity: MockAITaskEntity,
) -> None:
    """Check that a structure schema passed to async_generate_data reaches the entity."""
    age_selector = selector.NumberSelector(
        config=selector.NumberSelectorConfig(
            min=0,
            max=120,
        )
    )
    profile_schema = vol.Schema(
        {
            vol.Required("name"): selector.TextSelector(),
            vol.Optional("age"): age_selector,
        }
    )

    result = await async_generate_data(
        hass,
        task_name="Test task",
        entity_id=TEST_ENTITY_ID,
        instructions="Please generate a profile for a new user",
        structure=profile_schema,
    )

    # The mock entity returns fixed data; it is not derived from the schema above.
    expected_data = {
        "name": "Tracy Chen",
        "age": 30,
    }
    assert result.data == expected_data

    recorded_tasks = mock_ai_task_entity.mock_generate_data_tasks
    assert recorded_tasks
    first_task = recorded_tasks[0]
    assert first_task.instructions == "Please generate a profile for a new user"
    assert first_task.structure
    assert isinstance(first_task.structure, vol.Schema)

View File

@ -1,13 +1,17 @@
"""Test initialization of the AI Task component."""
from typing import Any
from freezegun.api import FrozenDateTimeFactory
import pytest
import voluptuous as vol
from homeassistant.components.ai_task import AITaskPreferences
from homeassistant.components.ai_task.const import DATA_PREFERENCES
from homeassistant.core import HomeAssistant
from homeassistant.helpers import selector
from .conftest import TEST_ENTITY_ID
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
from tests.common import flush_store
@ -82,3 +86,160 @@ async def test_generate_data_service(
)
assert result["data"] == "Mock result"
async def test_generate_data_service_structure_fields(
    hass: HomeAssistant,
    init_components: None,
    mock_ai_task_entity: MockAITaskEntity,
) -> None:
    """Check the service converts a structure definition into a voluptuous Schema."""
    structure_definition = {
        "name": {
            "description": "First and last name of the user such as Alice Smith",
            "required": True,
            "selector": {"text": {}},
        },
        "age": {
            "description": "Age of the user",
            "selector": {
                "number": {
                    "min": 0,
                    "max": 120,
                }
            },
        },
    }
    service_data = {
        "task_name": "Profile Generation",
        "instructions": "Please generate a profile for a new user",
        "entity_id": TEST_ENTITY_ID,
        "structure": structure_definition,
    }

    result = await hass.services.async_call(
        "ai_task",
        "generate_data",
        service_data,
        blocking=True,
        return_response=True,
    )

    # The mock entity returns fixed data; it is not derived from the schema above.
    expected_data = {
        "name": "Tracy Chen",
        "age": 30,
    }
    assert result["data"] == expected_data

    recorded_tasks = mock_ai_task_entity.mock_generate_data_tasks
    assert recorded_tasks
    first_task = recorded_tasks[0]
    assert first_task.instructions == "Please generate a profile for a new user"
    assert first_task.structure
    assert isinstance(first_task.structure, vol.Schema)

    schema_entries = list(first_task.structure.schema.items())
    assert len(schema_entries) == 2

    # First field: marked required, text selector.
    name_key, name_value = schema_entries[0]
    assert name_key == "name"
    assert isinstance(name_key, vol.Required)
    assert name_key.description == "First and last name of the user such as Alice Smith"
    assert isinstance(name_value, selector.TextSelector)

    # Second field: no "required" flag given, number selector with bounds.
    age_key, age_value = schema_entries[1]
    assert age_key == "age"
    assert isinstance(age_key, vol.Optional)
    assert age_key.description == "Age of the user"
    assert isinstance(age_value, selector.NumberSelector)
    assert age_value.config["min"] == 0
    assert age_value.config["max"] == 120
@pytest.mark.parametrize(
    ("structure", "expected_exception", "expected_error"),
    [
        (
            {
                "name": {
                    "description": "First and last name of the user such as Alice Smith",
                    "selector": {"invalid-selector": {}},
                },
            },
            vol.Invalid,
            r"Unknown selector type invalid-selector.*",
        ),
        (
            {
                "name": {
                    "description": "First and last name of the user such as Alice Smith",
                    "selector": {
                        "text": {
                            "extra-config": False,
                        }
                    },
                },
            },
            vol.Invalid,
            r"extra keys not allowed.*",
        ),
        (
            {
                "name": {
                    "description": "First and last name of the user such as Alice Smith",
                },
            },
            vol.Invalid,
            r"required key not provided.*selector.*",
        ),
        (12345, vol.Invalid, r"xpected a dictionary.*"),
        ("name", vol.Invalid, r"xpected a dictionary.*"),
        (["name"], vol.Invalid, r"xpected a dictionary.*"),
        (
            {
                "name": {
                    "description": "First and last name of the user such as Alice Smith",
                    "selector": {"text": {}},
                    "extra-fields": "Some extra fields",
                },
            },
            vol.Invalid,
            r"extra keys not allowed .*",
        ),
        (
            {
                "name": {
                    "description": "First and last name of the user such as Alice Smith",
                    "selector": "invalid-schema",
                },
            },
            vol.Invalid,
            r"xpected a dictionary for dictionary.",
        ),
    ],
    ids=(
        "invalid-selector",
        "invalid-selector-config",
        "missing-selector",
        "structure-is-int-not-object",
        "structure-is-str-not-object",
        "structure-is-list-not-object",
        "extra-fields",
        "invalid-selector-schema",
    ),
)
async def test_generate_data_service_invalid_structure(
    hass: HomeAssistant,
    init_components: None,
    structure: Any,
    expected_exception: type[Exception],
    expected_error: str,
) -> None:
    """Test the generate_data service rejects an invalid structure.

    Each parametrized case supplies a malformed ``structure`` value, the
    exception class the service call is expected to raise, and a regular
    expression that the exception message must match.
    """
    # expected_exception is an exception class (not an instance), which is
    # what pytest.raises takes as its first argument.
    with pytest.raises(expected_exception, match=expected_error):
        await hass.services.async_call(
            "ai_task",
            "generate_data",
            {
                "task_name": "Profile Generation",
                "instructions": "Please generate a profile for a new user",
                "entity_id": TEST_ENTITY_ID,
                "structure": structure,
            },
            blocking=True,
            return_response=True,
        )