mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 01:07:10 +00:00
Add AI task structured output (#148083)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
99d63c49bb
commit
b3d9908cd9
@ -1,11 +1,12 @@
|
||||
"""Integration to offer AI tasks to Home Assistant."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import ATTR_ENTITY_ID
|
||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR
|
||||
from homeassistant.core import (
|
||||
HassJobType,
|
||||
HomeAssistant,
|
||||
@ -14,12 +15,14 @@ from homeassistant.core import (
|
||||
SupportsResponse,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.helpers import config_validation as cv, storage
|
||||
from homeassistant.helpers import config_validation as cv, selector, storage
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
|
||||
|
||||
from .const import (
|
||||
ATTR_INSTRUCTIONS,
|
||||
ATTR_REQUIRED,
|
||||
ATTR_STRUCTURE,
|
||||
ATTR_TASK_NAME,
|
||||
DATA_COMPONENT,
|
||||
DATA_PREFERENCES,
|
||||
@ -47,6 +50,27 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
|
||||
STRUCTURE_FIELD_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Optional(CONF_DESCRIPTION): str,
|
||||
vol.Optional(ATTR_REQUIRED): bool,
|
||||
vol.Required(CONF_SELECTOR): selector.validate_selector,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _validate_structure_fields(value: dict[str, Any]) -> vol.Schema:
|
||||
"""Validate the structure fields as a voluptuous Schema."""
|
||||
if not isinstance(value, dict):
|
||||
raise vol.Invalid("Structure must be a dictionary")
|
||||
fields = {}
|
||||
for k, v in value.items():
|
||||
field_class = vol.Required if v.get(ATTR_REQUIRED, False) else vol.Optional
|
||||
fields[field_class(k, description=v.get(CONF_DESCRIPTION))] = selector.selector(
|
||||
v[CONF_SELECTOR]
|
||||
)
|
||||
return vol.Schema(fields, extra=vol.PREVENT_EXTRA)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Register the process service."""
|
||||
@ -64,6 +88,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
vol.Required(ATTR_TASK_NAME): cv.string,
|
||||
vol.Optional(ATTR_ENTITY_ID): cv.entity_id,
|
||||
vol.Required(ATTR_INSTRUCTIONS): cv.string,
|
||||
vol.Optional(ATTR_STRUCTURE): vol.All(
|
||||
vol.Schema({str: STRUCTURE_FIELD_SCHEMA}),
|
||||
_validate_structure_fields,
|
||||
),
|
||||
}
|
||||
),
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
|
@ -21,6 +21,8 @@ SERVICE_GENERATE_DATA = "generate_data"
|
||||
|
||||
ATTR_INSTRUCTIONS: Final = "instructions"
|
||||
ATTR_TASK_NAME: Final = "task_name"
|
||||
ATTR_STRUCTURE: Final = "structure"
|
||||
ATTR_REQUIRED: Final = "required"
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = (
|
||||
"You are a Home Assistant expert and help users with their tasks."
|
||||
|
@ -17,3 +17,9 @@ generate_data:
|
||||
domain: ai_task
|
||||
supported_features:
|
||||
- ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||
structure:
|
||||
advanced: true
|
||||
required: false
|
||||
example: '{ "name": { "selector": { "text": }, "description": "Name of the user", "required": "True" } } }, "age": { "selector": { "number": }, "description": "Age of the user" } }'
|
||||
selector:
|
||||
object:
|
||||
|
@ -15,6 +15,10 @@
|
||||
"entity_id": {
|
||||
"name": "Entity ID",
|
||||
"description": "Entity ID to run the task on. If not provided, the preferred entity will be used."
|
||||
},
|
||||
"structure": {
|
||||
"name": "Structured output",
|
||||
"description": "When set, the AI Task will output fields with this in structure. The structure is a dictionary where the keys are the field names and the values contain a 'description', a 'selector', and an optional 'required' field."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,8 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
@ -17,6 +19,7 @@ async def async_generate_data(
|
||||
task_name: str,
|
||||
entity_id: str | None = None,
|
||||
instructions: str,
|
||||
structure: vol.Schema | None = None,
|
||||
) -> GenDataTaskResult:
|
||||
"""Run a task in the AI Task integration."""
|
||||
if entity_id is None:
|
||||
@ -38,6 +41,7 @@ async def async_generate_data(
|
||||
GenDataTask(
|
||||
name=task_name,
|
||||
instructions=instructions,
|
||||
structure=structure,
|
||||
)
|
||||
)
|
||||
|
||||
@ -52,6 +56,9 @@ class GenDataTask:
|
||||
instructions: str
|
||||
"""Instructions on what needs to be done."""
|
||||
|
||||
structure: vol.Schema | None = None
|
||||
"""Optional structure for the data to be generated."""
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return task as a string."""
|
||||
return f"<GenDataTask {self.name}: {id(self)}>"
|
||||
|
@ -86,6 +86,7 @@ ALL_SERVICE_DESCRIPTIONS_CACHE: HassKey[
|
||||
def _base_components() -> dict[str, ModuleType]:
|
||||
"""Return a cached lookup of base components."""
|
||||
from homeassistant.components import ( # noqa: PLC0415
|
||||
ai_task,
|
||||
alarm_control_panel,
|
||||
assist_satellite,
|
||||
calendar,
|
||||
@ -107,6 +108,7 @@ def _base_components() -> dict[str, ModuleType]:
|
||||
)
|
||||
|
||||
return {
|
||||
"ai_task": ai_task,
|
||||
"alarm_control_panel": alarm_control_panel,
|
||||
"assist_satellite": assist_satellite,
|
||||
"calendar": calendar,
|
||||
|
@ -1,5 +1,7 @@
|
||||
"""Test helpers for AI Task integration."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.ai_task import (
|
||||
@ -45,12 +47,18 @@ class MockAITaskEntity(AITaskEntity):
|
||||
) -> GenDataTaskResult:
|
||||
"""Mock handling of generate data task."""
|
||||
self.mock_generate_data_tasks.append(task)
|
||||
if task.structure is not None:
|
||||
data = {"name": "Tracy Chen", "age": 30}
|
||||
data_chat_log = json.dumps(data)
|
||||
else:
|
||||
data = "Mock result"
|
||||
data_chat_log = data
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
AssistantContent(self.entity_id, "Mock result")
|
||||
AssistantContent(self.entity_id, data_chat_log)
|
||||
)
|
||||
return GenDataTaskResult(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
data="Mock result",
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1,10 +1,12 @@
|
||||
"""Tests for the AI Task entity model."""
|
||||
|
||||
from freezegun import freeze_time
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.ai_task import async_generate_data
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import selector
|
||||
|
||||
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
||||
|
||||
@ -37,3 +39,40 @@ async def test_state_generate_data(
|
||||
assert mock_ai_task_entity.mock_generate_data_tasks
|
||||
task = mock_ai_task_entity.mock_generate_data_tasks[0]
|
||||
assert task.instructions == "Test prompt"
|
||||
|
||||
|
||||
async def test_generate_structured_data(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_ai_task_entity: MockAITaskEntity,
|
||||
) -> None:
|
||||
"""Test the entity can generate structured data."""
|
||||
result = await async_generate_data(
|
||||
hass,
|
||||
task_name="Test task",
|
||||
entity_id=TEST_ENTITY_ID,
|
||||
instructions="Please generate a profile for a new user",
|
||||
structure=vol.Schema(
|
||||
{
|
||||
vol.Required("name"): selector.TextSelector(),
|
||||
vol.Optional("age"): selector.NumberSelector(
|
||||
config=selector.NumberSelectorConfig(
|
||||
min=0,
|
||||
max=120,
|
||||
)
|
||||
),
|
||||
}
|
||||
),
|
||||
)
|
||||
# Arbitrary data returned by the mock entity (not determined by above schema in test)
|
||||
assert result.data == {
|
||||
"name": "Tracy Chen",
|
||||
"age": 30,
|
||||
}
|
||||
|
||||
assert mock_ai_task_entity.mock_generate_data_tasks
|
||||
task = mock_ai_task_entity.mock_generate_data_tasks[0]
|
||||
assert task.instructions == "Please generate a profile for a new user"
|
||||
assert task.structure
|
||||
assert isinstance(task.structure, vol.Schema)
|
||||
|
@ -1,13 +1,17 @@
|
||||
"""Test initialization of the AI Task component."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from freezegun.api import FrozenDateTimeFactory
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.ai_task import AITaskPreferences
|
||||
from homeassistant.components.ai_task.const import DATA_PREFERENCES
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import selector
|
||||
|
||||
from .conftest import TEST_ENTITY_ID
|
||||
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
||||
|
||||
from tests.common import flush_store
|
||||
|
||||
@ -82,3 +86,160 @@ async def test_generate_data_service(
|
||||
)
|
||||
|
||||
assert result["data"] == "Mock result"
|
||||
|
||||
|
||||
async def test_generate_data_service_structure_fields(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
mock_ai_task_entity: MockAITaskEntity,
|
||||
) -> None:
|
||||
"""Test the entity can generate structured data with a top level object schema."""
|
||||
result = await hass.services.async_call(
|
||||
"ai_task",
|
||||
"generate_data",
|
||||
{
|
||||
"task_name": "Profile Generation",
|
||||
"instructions": "Please generate a profile for a new user",
|
||||
"entity_id": TEST_ENTITY_ID,
|
||||
"structure": {
|
||||
"name": {
|
||||
"description": "First and last name of the user such as Alice Smith",
|
||||
"required": True,
|
||||
"selector": {"text": {}},
|
||||
},
|
||||
"age": {
|
||||
"description": "Age of the user",
|
||||
"selector": {
|
||||
"number": {
|
||||
"min": 0,
|
||||
"max": 120,
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
# Arbitrary data returned by the mock entity (not determined by above schema in test)
|
||||
assert result["data"] == {
|
||||
"name": "Tracy Chen",
|
||||
"age": 30,
|
||||
}
|
||||
|
||||
assert mock_ai_task_entity.mock_generate_data_tasks
|
||||
task = mock_ai_task_entity.mock_generate_data_tasks[0]
|
||||
assert task.instructions == "Please generate a profile for a new user"
|
||||
assert task.structure
|
||||
assert isinstance(task.structure, vol.Schema)
|
||||
schema = list(task.structure.schema.items())
|
||||
assert len(schema) == 2
|
||||
|
||||
name_key, name_value = schema[0]
|
||||
assert name_key == "name"
|
||||
assert isinstance(name_key, vol.Required)
|
||||
assert name_key.description == "First and last name of the user such as Alice Smith"
|
||||
assert isinstance(name_value, selector.TextSelector)
|
||||
|
||||
age_key, age_value = schema[1]
|
||||
assert age_key == "age"
|
||||
assert isinstance(age_key, vol.Optional)
|
||||
assert age_key.description == "Age of the user"
|
||||
assert isinstance(age_value, selector.NumberSelector)
|
||||
assert age_value.config["min"] == 0
|
||||
assert age_value.config["max"] == 120
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("structure", "expected_exception", "expected_error"),
|
||||
[
|
||||
(
|
||||
{
|
||||
"name": {
|
||||
"description": "First and last name of the user such as Alice Smith",
|
||||
"selector": {"invalid-selector": {}},
|
||||
},
|
||||
},
|
||||
vol.Invalid,
|
||||
r"Unknown selector type invalid-selector.*",
|
||||
),
|
||||
(
|
||||
{
|
||||
"name": {
|
||||
"description": "First and last name of the user such as Alice Smith",
|
||||
"selector": {
|
||||
"text": {
|
||||
"extra-config": False,
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
vol.Invalid,
|
||||
r"extra keys not allowed.*",
|
||||
),
|
||||
(
|
||||
{
|
||||
"name": {
|
||||
"description": "First and last name of the user such as Alice Smith",
|
||||
},
|
||||
},
|
||||
vol.Invalid,
|
||||
r"required key not provided.*selector.*",
|
||||
),
|
||||
(12345, vol.Invalid, r"xpected a dictionary.*"),
|
||||
("name", vol.Invalid, r"xpected a dictionary.*"),
|
||||
(["name"], vol.Invalid, r"xpected a dictionary.*"),
|
||||
(
|
||||
{
|
||||
"name": {
|
||||
"description": "First and last name of the user such as Alice Smith",
|
||||
"selector": {"text": {}},
|
||||
"extra-fields": "Some extra fields",
|
||||
},
|
||||
},
|
||||
vol.Invalid,
|
||||
r"extra keys not allowed .*",
|
||||
),
|
||||
(
|
||||
{
|
||||
"name": {
|
||||
"description": "First and last name of the user such as Alice Smith",
|
||||
"selector": "invalid-schema",
|
||||
},
|
||||
},
|
||||
vol.Invalid,
|
||||
r"xpected a dictionary for dictionary.",
|
||||
),
|
||||
],
|
||||
ids=(
|
||||
"invalid-selector",
|
||||
"invalid-selector-config",
|
||||
"missing-selector",
|
||||
"structure-is-int-not-object",
|
||||
"structure-is-str-not-object",
|
||||
"structure-is-list-not-object",
|
||||
"extra-fields",
|
||||
"invalid-selector-schema",
|
||||
),
|
||||
)
|
||||
async def test_generate_data_service_invalid_structure(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
structure: Any,
|
||||
expected_exception: Exception,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test the entity can generate structured data."""
|
||||
with pytest.raises(expected_exception, match=expected_error):
|
||||
await hass.services.async_call(
|
||||
"ai_task",
|
||||
"generate_data",
|
||||
{
|
||||
"task_name": "Profile Generation",
|
||||
"instructions": "Please generate a profile for a new user",
|
||||
"entity_id": TEST_ENTITY_ID,
|
||||
"structure": structure,
|
||||
},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user