Add AI task structured output

This commit is contained in:
Allen Porter 2025-07-03 21:26:20 +00:00
parent 8330ae2d3a
commit 789eb029fa
10 changed files with 305 additions and 5 deletions

View File

@ -1,11 +1,12 @@
"""Integration to offer AI tasks to Home Assistant.""" """Integration to offer AI tasks to Home Assistant."""
import logging import logging
from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR
from homeassistant.core import ( from homeassistant.core import (
HassJobType, HassJobType,
HomeAssistant, HomeAssistant,
@ -14,12 +15,14 @@ from homeassistant.core import (
SupportsResponse, SupportsResponse,
callback, callback,
) )
from homeassistant.helpers import config_validation as cv, storage from homeassistant.helpers import config_validation as cv, selector, storage
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
from .const import ( from .const import (
ATTR_INSTRUCTIONS, ATTR_INSTRUCTIONS,
ATTR_REQUIRED,
ATTR_STRUCTURE,
ATTR_TASK_NAME, ATTR_TASK_NAME,
DATA_COMPONENT, DATA_COMPONENT,
DATA_PREFERENCES, DATA_PREFERENCES,
@ -47,6 +50,29 @@ _LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
STRUCTURE_FIELD_SCHEMA = vol.Schema(
{
vol.Optional(CONF_DESCRIPTION): str,
vol.Optional(ATTR_REQUIRED): bool,
vol.Required(CONF_SELECTOR): selector.validate_selector,
}
)
def _validate_structure(value: dict[str, Any]) -> vol.Schema:
"""Validate the structure for the generate data task."""
if not isinstance(value, dict):
raise vol.Invalid("Structure must be a dictionary")
fields = {}
for k, v in value.items():
if not isinstance(v, dict):
raise vol.Invalid(f"Structure field '{k}' must be a dictionary")
field_class = vol.Required if v.get(ATTR_REQUIRED, False) else vol.Optional
fields[field_class(k, description=v.get(CONF_DESCRIPTION))] = selector.selector(
v[CONF_SELECTOR]
)
return vol.Schema(fields)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service.""" """Register the process service."""
@ -64,6 +90,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
vol.Required(ATTR_TASK_NAME): cv.string, vol.Required(ATTR_TASK_NAME): cv.string,
vol.Optional(ATTR_ENTITY_ID): cv.entity_id, vol.Optional(ATTR_ENTITY_ID): cv.entity_id,
vol.Required(ATTR_INSTRUCTIONS): cv.string, vol.Required(ATTR_INSTRUCTIONS): cv.string,
vol.Optional(ATTR_STRUCTURE): vol.All(
vol.Schema({str: STRUCTURE_FIELD_SCHEMA}),
_validate_structure,
),
} }
), ),
supports_response=SupportsResponse.ONLY, supports_response=SupportsResponse.ONLY,

View File

@ -21,6 +21,8 @@ SERVICE_GENERATE_DATA = "generate_data"
ATTR_INSTRUCTIONS: Final = "instructions" ATTR_INSTRUCTIONS: Final = "instructions"
ATTR_TASK_NAME: Final = "task_name" ATTR_TASK_NAME: Final = "task_name"
ATTR_STRUCTURE: Final = "structure"
ATTR_REQUIRED: Final = "required"
DEFAULT_SYSTEM_PROMPT = ( DEFAULT_SYSTEM_PROMPT = (
"You are a Home Assistant expert and help users with their tasks." "You are a Home Assistant expert and help users with their tasks."
@ -32,3 +34,6 @@ class AITaskEntityFeature(IntFlag):
GENERATE_DATA = 1 GENERATE_DATA = 1
"""Generate data based on instructions.""" """Generate data based on instructions."""
GENERATE_STRUCTURED_DATA = 2
"""Generate structured data based on instructions."""

View File

@ -17,3 +17,12 @@ generate_data:
domain: ai_task domain: ai_task
supported_features: supported_features:
- ai_task.AITaskEntityFeature.GENERATE_DATA - ai_task.AITaskEntityFeature.GENERATE_DATA
structure:
advanced: true
required: false
example: '{ "name": { "selector": { "text": }, "description": "Name of the user", "required": "True" } } }, "age": { "selector": { "number": }, "description": "Age of the user" } }'
selector:
object:
filter:
supported_features:
- ai_task.AITaskEntityFeature.GENERATE_STRUCTURED_DATA

View File

@ -15,6 +15,10 @@
"entity_id": { "entity_id": {
"name": "Entity ID", "name": "Entity ID",
"description": "Entity ID to run the task on. If not provided, the preferred entity will be used." "description": "Entity ID to run the task on. If not provided, the preferred entity will be used."
},
"structure": {
"name": "Structured output",
"description": "When set, the AI Task will output fields with this in structure. The structure is a dictionary where the keys are the field names and the values contain a 'description', a 'selector', and an optional 'required' field."
} }
} }
} }

View File

@ -5,6 +5,8 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
import voluptuous as vol
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@ -17,6 +19,7 @@ async def async_generate_data(
task_name: str, task_name: str,
entity_id: str | None = None, entity_id: str | None = None,
instructions: str, instructions: str,
structure: vol.Schema | None = None,
) -> GenDataTaskResult: ) -> GenDataTaskResult:
"""Run a task in the AI Task integration.""" """Run a task in the AI Task integration."""
if entity_id is None: if entity_id is None:
@ -34,10 +37,20 @@ async def async_generate_data(
f"AI Task entity {entity_id} does not support generating data" f"AI Task entity {entity_id} does not support generating data"
) )
if structure is not None:
if (
AITaskEntityFeature.GENERATE_STRUCTURED_DATA
not in entity.supported_features
):
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support generating structured data"
)
return await entity.internal_async_generate_data( return await entity.internal_async_generate_data(
GenDataTask( GenDataTask(
name=task_name, name=task_name,
instructions=instructions, instructions=instructions,
structure=structure,
) )
) )
@ -52,6 +65,9 @@ class GenDataTask:
instructions: str instructions: str
"""Instructions on what needs to be done.""" """Instructions on what needs to be done."""
structure: vol.Schema | None = None
"""Optional structure for the data to be generated."""
def __str__(self) -> str: def __str__(self) -> str:
"""Return task as a string.""" """Return task as a string."""
return f"<GenDataTask {self.name}: {id(self)}>" return f"<GenDataTask {self.name}: {id(self)}>"

View File

@ -86,6 +86,7 @@ ALL_SERVICE_DESCRIPTIONS_CACHE: HassKey[
def _base_components() -> dict[str, ModuleType]: def _base_components() -> dict[str, ModuleType]:
"""Return a cached lookup of base components.""" """Return a cached lookup of base components."""
from homeassistant.components import ( # noqa: PLC0415 from homeassistant.components import ( # noqa: PLC0415
ai_task,
alarm_control_panel, alarm_control_panel,
assist_satellite, assist_satellite,
calendar, calendar,
@ -107,6 +108,7 @@ def _base_components() -> dict[str, ModuleType]:
) )
return { return {
"ai_task": ai_task,
"alarm_control_panel": alarm_control_panel, "alarm_control_panel": alarm_control_panel,
"assist_satellite": assist_satellite, "assist_satellite": assist_satellite,
"calendar": calendar, "calendar": calendar,

View File

@ -1,5 +1,7 @@
"""Test helpers for AI Task integration.""" """Test helpers for AI Task integration."""
import json
import pytest import pytest
from homeassistant.components.ai_task import ( from homeassistant.components.ai_task import (
@ -33,7 +35,9 @@ class MockAITaskEntity(AITaskEntity):
"""Mock AI Task entity for testing.""" """Mock AI Task entity for testing."""
_attr_name = "Test Task Entity" _attr_name = "Test Task Entity"
_attr_supported_features = AITaskEntityFeature.GENERATE_DATA _attr_supported_features = (
AITaskEntityFeature.GENERATE_DATA | AITaskEntityFeature.GENERATE_STRUCTURED_DATA
)
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the mock entity.""" """Initialize the mock entity."""
@ -42,6 +46,25 @@ class MockAITaskEntity(AITaskEntity):
async def _async_generate_data( async def _async_generate_data(
self, task: GenDataTask, chat_log: ChatLog self, task: GenDataTask, chat_log: ChatLog
) -> GenDataTaskResult:
"""Mock handling of generate data task."""
self.mock_generate_data_tasks.append(task)
if task.structure is not None:
data = {"name": "Tracy Chen", "age": 30}
data_chat_log = json.dumps(data)
else:
data = "Mock result"
data_chat_log = data
chat_log.async_add_assistant_content_without_tools(
AssistantContent(self.entity_id, data_chat_log)
)
return GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=data,
)
async def _async_generate_structured_data(
self, task: GenDataTask, chat_log: ChatLog, structure: dict[str, dict]
) -> GenDataTaskResult: ) -> GenDataTaskResult:
"""Mock handling of generate data task.""" """Mock handling of generate data task."""
self.mock_generate_data_tasks.append(task) self.mock_generate_data_tasks.append(task)

View File

@ -1,10 +1,12 @@
"""Tests for the AI Task entity model.""" """Tests for the AI Task entity model."""
from freezegun import freeze_time from freezegun import freeze_time
import voluptuous as vol
from homeassistant.components.ai_task import async_generate_data from homeassistant.components.ai_task import async_generate_data
from homeassistant.const import STATE_UNKNOWN from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import selector
from .conftest import TEST_ENTITY_ID, MockAITaskEntity from .conftest import TEST_ENTITY_ID, MockAITaskEntity
@ -37,3 +39,40 @@ async def test_state_generate_data(
assert mock_ai_task_entity.mock_generate_data_tasks assert mock_ai_task_entity.mock_generate_data_tasks
task = mock_ai_task_entity.mock_generate_data_tasks[0] task = mock_ai_task_entity.mock_generate_data_tasks[0]
assert task.instructions == "Test prompt" assert task.instructions == "Test prompt"
async def test_generate_structured_data(
hass: HomeAssistant,
init_components: None,
mock_config_entry: MockConfigEntry,
mock_ai_task_entity: MockAITaskEntity,
) -> None:
"""Test the entity can generate structured data."""
result = await async_generate_data(
hass,
task_name="Test task",
entity_id=TEST_ENTITY_ID,
instructions="Please generate a profile for a new user",
structure=vol.Schema(
{
vol.Required("name"): selector.TextSelector(),
vol.Optional("age"): selector.NumberSelector(
config=selector.NumberSelectorConfig(
min=0,
max=120,
)
),
}
),
)
# Arbitrary data returned by the mock entity (not determined by above schema in test)
assert result.data == {
"name": "Tracy Chen",
"age": 30,
}
assert mock_ai_task_entity.mock_generate_data_tasks
task = mock_ai_task_entity.mock_generate_data_tasks[0]
assert task.instructions == "Please generate a profile for a new user"
assert task.structure
assert isinstance(task.structure, vol.Schema)

View File

@ -1,13 +1,17 @@
"""Test initialization of the AI Task component.""" """Test initialization of the AI Task component."""
from typing import Any
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
import pytest import pytest
import voluptuous as vol
from homeassistant.components.ai_task import AITaskPreferences from homeassistant.components.ai_task import AITaskPreferences
from homeassistant.components.ai_task.const import DATA_PREFERENCES from homeassistant.components.ai_task.const import DATA_PREFERENCES
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import selector
from .conftest import TEST_ENTITY_ID from .conftest import TEST_ENTITY_ID, MockAITaskEntity
from tests.common import flush_store from tests.common import flush_store
@ -82,3 +86,149 @@ async def test_generate_data_service(
) )
assert result["data"] == "Mock result" assert result["data"] == "Mock result"
async def test_generate_data_service_structure(
hass: HomeAssistant,
init_components: None,
mock_ai_task_entity: MockAITaskEntity,
) -> None:
"""Test the entity can generate structured data."""
result = await hass.services.async_call(
"ai_task",
"generate_data",
{
"task_name": "Profile Generation",
"instructions": "Please generate a profile for a new user",
"entity_id": TEST_ENTITY_ID,
"structure": {
"name": {
"description": "First and last name of the user such as Alice Smith",
"required": True,
"selector": {"text": {}},
},
"age": {
"description": "Age of the user",
"selector": {
"number": {
"min": 0,
"max": 120,
}
},
},
},
},
blocking=True,
return_response=True,
)
# Arbitrary data returned by the mock entity (not determined by above schema in test)
assert result["data"] == {
"name": "Tracy Chen",
"age": 30,
}
assert mock_ai_task_entity.mock_generate_data_tasks
task = mock_ai_task_entity.mock_generate_data_tasks[0]
assert task.instructions == "Please generate a profile for a new user"
assert task.structure
assert isinstance(task.structure, vol.Schema)
schema = list(task.structure.schema.items())
assert len(schema) == 2
name_key, name_value = schema[0]
assert name_key == "name"
assert isinstance(name_key, vol.Required)
assert name_key.description == "First and last name of the user such as Alice Smith"
assert isinstance(name_value, selector.TextSelector)
age_key, age_value = schema[1]
assert age_key == "age"
assert isinstance(age_key, vol.Optional)
assert age_key.description == "Age of the user"
assert isinstance(age_value, selector.NumberSelector)
assert age_value.config["min"] == 0
assert age_value.config["max"] == 120
@pytest.mark.parametrize(
("structure", "expected_exception", "expected_error"),
[
(
{
"name": {
"description": "First and last name of the user such as Alice Smith",
"selector": {"invalid-selector": {}},
},
},
vol.Invalid,
r"Unknown selector type invalid-selector.*",
),
(
{
"name": {
"description": "First and last name of the user such as Alice Smith",
"selector": {
"text": {
"extra-config": False,
}
},
},
},
vol.Invalid,
r"extra keys not allowed.*",
),
(
{
"name": {
"description": "First and last name of the user such as Alice Smith",
},
},
vol.Invalid,
r"required key not provided.*selector.*",
),
(12345, vol.Invalid, r"expected a dictionary.*"),
("name", vol.Invalid, r"expected a dictionary.*"),
(["name"], vol.Invalid, r"expected a dictionary.*"),
(
{
"name": {
"description": "First and last name of the user such as Alice Smith",
"selector": {"text": {}},
"extra-fields": "Some extra fields",
},
},
vol.Invalid,
r"extra keys not allowed .*",
),
],
ids=(
"invalid-selector",
"invalid-selector-config",
"missing-selector",
"structure-is-int-not-object",
"structure-is-str-not-object",
"structure-is-list-not-object",
"extra-fields",
),
)
async def test_generate_data_service_invalid_structure(
hass: HomeAssistant,
init_components: None,
structure: Any,
expected_exception: Exception,
expected_error: str,
) -> None:
"""Test the entity can generate structured data."""
with pytest.raises(expected_exception, match=expected_error):
await hass.services.async_call(
"ai_task",
"generate_data",
{
"task_name": "Profile Generation",
"instructions": "Please generate a profile for a new user",
"entity_id": TEST_ENTITY_ID,
"structure": structure,
},
blocking=True,
return_response=True,
)

View File

@ -3,13 +3,14 @@
from freezegun import freeze_time from freezegun import freeze_time
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_data from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_data
from homeassistant.components.conversation import async_get_chat_log from homeassistant.components.conversation import async_get_chat_log
from homeassistant.const import STATE_UNKNOWN from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session from homeassistant.helpers import chat_session, selector
from .conftest import TEST_ENTITY_ID, MockAITaskEntity from .conftest import TEST_ENTITY_ID, MockAITaskEntity
@ -127,3 +128,24 @@ async def test_run_data_task_updates_chat_log(
async_get_chat_log(hass, session) as chat_log, async_get_chat_log(hass, session) as chat_log,
): ):
assert chat_log.content == snapshot assert chat_log.content == snapshot
async def test_run_task_structure_unsupported_feature(
hass: HomeAssistant,
init_components: None,
mock_ai_task_entity: MockAITaskEntity,
) -> None:
"""Test running a task with an unknown entity."""
mock_ai_task_entity.supported_features = AITaskEntityFeature.GENERATE_DATA
with pytest.raises(
HomeAssistantError,
match="AI Task entity ai_task.test_task_entity does not support generating structured data",
):
await async_generate_data(
hass,
task_name="Test Task",
instructions="Test prompt",
entity_id=TEST_ENTITY_ID,
structure=vol.Schema({vol.Required("name"): selector.TextSelector()}),
)