mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 09:17:10 +00:00
Add AI task structured output (#148083)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
99d63c49bb
commit
b3d9908cd9
@ -1,11 +1,12 @@
|
|||||||
"""Integration to offer AI tasks to Home Assistant."""
|
"""Integration to offer AI tasks to Home Assistant."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import ATTR_ENTITY_ID
|
from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR
|
||||||
from homeassistant.core import (
|
from homeassistant.core import (
|
||||||
HassJobType,
|
HassJobType,
|
||||||
HomeAssistant,
|
HomeAssistant,
|
||||||
@ -14,12 +15,14 @@ from homeassistant.core import (
|
|||||||
SupportsResponse,
|
SupportsResponse,
|
||||||
callback,
|
callback,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers import config_validation as cv, storage
|
from homeassistant.helpers import config_validation as cv, selector, storage
|
||||||
from homeassistant.helpers.entity_component import EntityComponent
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
|
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
ATTR_INSTRUCTIONS,
|
ATTR_INSTRUCTIONS,
|
||||||
|
ATTR_REQUIRED,
|
||||||
|
ATTR_STRUCTURE,
|
||||||
ATTR_TASK_NAME,
|
ATTR_TASK_NAME,
|
||||||
DATA_COMPONENT,
|
DATA_COMPONENT,
|
||||||
DATA_PREFERENCES,
|
DATA_PREFERENCES,
|
||||||
@ -47,6 +50,27 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
|
|
||||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||||
|
|
||||||
|
STRUCTURE_FIELD_SCHEMA = vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Optional(CONF_DESCRIPTION): str,
|
||||||
|
vol.Optional(ATTR_REQUIRED): bool,
|
||||||
|
vol.Required(CONF_SELECTOR): selector.validate_selector,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_structure_fields(value: dict[str, Any]) -> vol.Schema:
|
||||||
|
"""Validate the structure fields as a voluptuous Schema."""
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
raise vol.Invalid("Structure must be a dictionary")
|
||||||
|
fields = {}
|
||||||
|
for k, v in value.items():
|
||||||
|
field_class = vol.Required if v.get(ATTR_REQUIRED, False) else vol.Optional
|
||||||
|
fields[field_class(k, description=v.get(CONF_DESCRIPTION))] = selector.selector(
|
||||||
|
v[CONF_SELECTOR]
|
||||||
|
)
|
||||||
|
return vol.Schema(fields, extra=vol.PREVENT_EXTRA)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Register the process service."""
|
"""Register the process service."""
|
||||||
@ -64,6 +88,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
vol.Required(ATTR_TASK_NAME): cv.string,
|
vol.Required(ATTR_TASK_NAME): cv.string,
|
||||||
vol.Optional(ATTR_ENTITY_ID): cv.entity_id,
|
vol.Optional(ATTR_ENTITY_ID): cv.entity_id,
|
||||||
vol.Required(ATTR_INSTRUCTIONS): cv.string,
|
vol.Required(ATTR_INSTRUCTIONS): cv.string,
|
||||||
|
vol.Optional(ATTR_STRUCTURE): vol.All(
|
||||||
|
vol.Schema({str: STRUCTURE_FIELD_SCHEMA}),
|
||||||
|
_validate_structure_fields,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
supports_response=SupportsResponse.ONLY,
|
supports_response=SupportsResponse.ONLY,
|
||||||
|
@ -21,6 +21,8 @@ SERVICE_GENERATE_DATA = "generate_data"
|
|||||||
|
|
||||||
ATTR_INSTRUCTIONS: Final = "instructions"
|
ATTR_INSTRUCTIONS: Final = "instructions"
|
||||||
ATTR_TASK_NAME: Final = "task_name"
|
ATTR_TASK_NAME: Final = "task_name"
|
||||||
|
ATTR_STRUCTURE: Final = "structure"
|
||||||
|
ATTR_REQUIRED: Final = "required"
|
||||||
|
|
||||||
DEFAULT_SYSTEM_PROMPT = (
|
DEFAULT_SYSTEM_PROMPT = (
|
||||||
"You are a Home Assistant expert and help users with their tasks."
|
"You are a Home Assistant expert and help users with their tasks."
|
||||||
|
@ -17,3 +17,9 @@ generate_data:
|
|||||||
domain: ai_task
|
domain: ai_task
|
||||||
supported_features:
|
supported_features:
|
||||||
- ai_task.AITaskEntityFeature.GENERATE_DATA
|
- ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||||
|
structure:
|
||||||
|
advanced: true
|
||||||
|
required: false
|
||||||
|
example: '{ "name": { "selector": { "text": }, "description": "Name of the user", "required": "True" } } }, "age": { "selector": { "number": }, "description": "Age of the user" } }'
|
||||||
|
selector:
|
||||||
|
object:
|
||||||
|
@ -15,6 +15,10 @@
|
|||||||
"entity_id": {
|
"entity_id": {
|
||||||
"name": "Entity ID",
|
"name": "Entity ID",
|
||||||
"description": "Entity ID to run the task on. If not provided, the preferred entity will be used."
|
"description": "Entity ID to run the task on. If not provided, the preferred entity will be used."
|
||||||
|
},
|
||||||
|
"structure": {
|
||||||
|
"name": "Structured output",
|
||||||
|
"description": "When set, the AI Task will output fields with this in structure. The structure is a dictionary where the keys are the field names and the values contain a 'description', a 'selector', and an optional 'required' field."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,8 @@ from __future__ import annotations
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
@ -17,6 +19,7 @@ async def async_generate_data(
|
|||||||
task_name: str,
|
task_name: str,
|
||||||
entity_id: str | None = None,
|
entity_id: str | None = None,
|
||||||
instructions: str,
|
instructions: str,
|
||||||
|
structure: vol.Schema | None = None,
|
||||||
) -> GenDataTaskResult:
|
) -> GenDataTaskResult:
|
||||||
"""Run a task in the AI Task integration."""
|
"""Run a task in the AI Task integration."""
|
||||||
if entity_id is None:
|
if entity_id is None:
|
||||||
@ -38,6 +41,7 @@ async def async_generate_data(
|
|||||||
GenDataTask(
|
GenDataTask(
|
||||||
name=task_name,
|
name=task_name,
|
||||||
instructions=instructions,
|
instructions=instructions,
|
||||||
|
structure=structure,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -52,6 +56,9 @@ class GenDataTask:
|
|||||||
instructions: str
|
instructions: str
|
||||||
"""Instructions on what needs to be done."""
|
"""Instructions on what needs to be done."""
|
||||||
|
|
||||||
|
structure: vol.Schema | None = None
|
||||||
|
"""Optional structure for the data to be generated."""
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
"""Return task as a string."""
|
"""Return task as a string."""
|
||||||
return f"<GenDataTask {self.name}: {id(self)}>"
|
return f"<GenDataTask {self.name}: {id(self)}>"
|
||||||
|
@ -86,6 +86,7 @@ ALL_SERVICE_DESCRIPTIONS_CACHE: HassKey[
|
|||||||
def _base_components() -> dict[str, ModuleType]:
|
def _base_components() -> dict[str, ModuleType]:
|
||||||
"""Return a cached lookup of base components."""
|
"""Return a cached lookup of base components."""
|
||||||
from homeassistant.components import ( # noqa: PLC0415
|
from homeassistant.components import ( # noqa: PLC0415
|
||||||
|
ai_task,
|
||||||
alarm_control_panel,
|
alarm_control_panel,
|
||||||
assist_satellite,
|
assist_satellite,
|
||||||
calendar,
|
calendar,
|
||||||
@ -107,6 +108,7 @@ def _base_components() -> dict[str, ModuleType]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
"ai_task": ai_task,
|
||||||
"alarm_control_panel": alarm_control_panel,
|
"alarm_control_panel": alarm_control_panel,
|
||||||
"assist_satellite": assist_satellite,
|
"assist_satellite": assist_satellite,
|
||||||
"calendar": calendar,
|
"calendar": calendar,
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""Test helpers for AI Task integration."""
|
"""Test helpers for AI Task integration."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.ai_task import (
|
from homeassistant.components.ai_task import (
|
||||||
@ -45,12 +47,18 @@ class MockAITaskEntity(AITaskEntity):
|
|||||||
) -> GenDataTaskResult:
|
) -> GenDataTaskResult:
|
||||||
"""Mock handling of generate data task."""
|
"""Mock handling of generate data task."""
|
||||||
self.mock_generate_data_tasks.append(task)
|
self.mock_generate_data_tasks.append(task)
|
||||||
|
if task.structure is not None:
|
||||||
|
data = {"name": "Tracy Chen", "age": 30}
|
||||||
|
data_chat_log = json.dumps(data)
|
||||||
|
else:
|
||||||
|
data = "Mock result"
|
||||||
|
data_chat_log = data
|
||||||
chat_log.async_add_assistant_content_without_tools(
|
chat_log.async_add_assistant_content_without_tools(
|
||||||
AssistantContent(self.entity_id, "Mock result")
|
AssistantContent(self.entity_id, data_chat_log)
|
||||||
)
|
)
|
||||||
return GenDataTaskResult(
|
return GenDataTaskResult(
|
||||||
conversation_id=chat_log.conversation_id,
|
conversation_id=chat_log.conversation_id,
|
||||||
data="Mock result",
|
data=data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
"""Tests for the AI Task entity model."""
|
"""Tests for the AI Task entity model."""
|
||||||
|
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.ai_task import async_generate_data
|
from homeassistant.components.ai_task import async_generate_data
|
||||||
from homeassistant.const import STATE_UNKNOWN
|
from homeassistant.const import STATE_UNKNOWN
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import selector
|
||||||
|
|
||||||
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
||||||
|
|
||||||
@ -37,3 +39,40 @@ async def test_state_generate_data(
|
|||||||
assert mock_ai_task_entity.mock_generate_data_tasks
|
assert mock_ai_task_entity.mock_generate_data_tasks
|
||||||
task = mock_ai_task_entity.mock_generate_data_tasks[0]
|
task = mock_ai_task_entity.mock_generate_data_tasks[0]
|
||||||
assert task.instructions == "Test prompt"
|
assert task.instructions == "Test prompt"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_structured_data(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: None,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_ai_task_entity: MockAITaskEntity,
|
||||||
|
) -> None:
|
||||||
|
"""Test the entity can generate structured data."""
|
||||||
|
result = await async_generate_data(
|
||||||
|
hass,
|
||||||
|
task_name="Test task",
|
||||||
|
entity_id=TEST_ENTITY_ID,
|
||||||
|
instructions="Please generate a profile for a new user",
|
||||||
|
structure=vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Required("name"): selector.TextSelector(),
|
||||||
|
vol.Optional("age"): selector.NumberSelector(
|
||||||
|
config=selector.NumberSelectorConfig(
|
||||||
|
min=0,
|
||||||
|
max=120,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# Arbitrary data returned by the mock entity (not determined by above schema in test)
|
||||||
|
assert result.data == {
|
||||||
|
"name": "Tracy Chen",
|
||||||
|
"age": 30,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert mock_ai_task_entity.mock_generate_data_tasks
|
||||||
|
task = mock_ai_task_entity.mock_generate_data_tasks[0]
|
||||||
|
assert task.instructions == "Please generate a profile for a new user"
|
||||||
|
assert task.structure
|
||||||
|
assert isinstance(task.structure, vol.Schema)
|
||||||
|
@ -1,13 +1,17 @@
|
|||||||
"""Test initialization of the AI Task component."""
|
"""Test initialization of the AI Task component."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from freezegun.api import FrozenDateTimeFactory
|
from freezegun.api import FrozenDateTimeFactory
|
||||||
import pytest
|
import pytest
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.ai_task import AITaskPreferences
|
from homeassistant.components.ai_task import AITaskPreferences
|
||||||
from homeassistant.components.ai_task.const import DATA_PREFERENCES
|
from homeassistant.components.ai_task.const import DATA_PREFERENCES
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import selector
|
||||||
|
|
||||||
from .conftest import TEST_ENTITY_ID
|
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
||||||
|
|
||||||
from tests.common import flush_store
|
from tests.common import flush_store
|
||||||
|
|
||||||
@ -82,3 +86,160 @@ async def test_generate_data_service(
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert result["data"] == "Mock result"
|
assert result["data"] == "Mock result"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_data_service_structure_fields(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: None,
|
||||||
|
mock_ai_task_entity: MockAITaskEntity,
|
||||||
|
) -> None:
|
||||||
|
"""Test the entity can generate structured data with a top level object schema."""
|
||||||
|
result = await hass.services.async_call(
|
||||||
|
"ai_task",
|
||||||
|
"generate_data",
|
||||||
|
{
|
||||||
|
"task_name": "Profile Generation",
|
||||||
|
"instructions": "Please generate a profile for a new user",
|
||||||
|
"entity_id": TEST_ENTITY_ID,
|
||||||
|
"structure": {
|
||||||
|
"name": {
|
||||||
|
"description": "First and last name of the user such as Alice Smith",
|
||||||
|
"required": True,
|
||||||
|
"selector": {"text": {}},
|
||||||
|
},
|
||||||
|
"age": {
|
||||||
|
"description": "Age of the user",
|
||||||
|
"selector": {
|
||||||
|
"number": {
|
||||||
|
"min": 0,
|
||||||
|
"max": 120,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
# Arbitrary data returned by the mock entity (not determined by above schema in test)
|
||||||
|
assert result["data"] == {
|
||||||
|
"name": "Tracy Chen",
|
||||||
|
"age": 30,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert mock_ai_task_entity.mock_generate_data_tasks
|
||||||
|
task = mock_ai_task_entity.mock_generate_data_tasks[0]
|
||||||
|
assert task.instructions == "Please generate a profile for a new user"
|
||||||
|
assert task.structure
|
||||||
|
assert isinstance(task.structure, vol.Schema)
|
||||||
|
schema = list(task.structure.schema.items())
|
||||||
|
assert len(schema) == 2
|
||||||
|
|
||||||
|
name_key, name_value = schema[0]
|
||||||
|
assert name_key == "name"
|
||||||
|
assert isinstance(name_key, vol.Required)
|
||||||
|
assert name_key.description == "First and last name of the user such as Alice Smith"
|
||||||
|
assert isinstance(name_value, selector.TextSelector)
|
||||||
|
|
||||||
|
age_key, age_value = schema[1]
|
||||||
|
assert age_key == "age"
|
||||||
|
assert isinstance(age_key, vol.Optional)
|
||||||
|
assert age_key.description == "Age of the user"
|
||||||
|
assert isinstance(age_value, selector.NumberSelector)
|
||||||
|
assert age_value.config["min"] == 0
|
||||||
|
assert age_value.config["max"] == 120
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("structure", "expected_exception", "expected_error"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"name": {
|
||||||
|
"description": "First and last name of the user such as Alice Smith",
|
||||||
|
"selector": {"invalid-selector": {}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
vol.Invalid,
|
||||||
|
r"Unknown selector type invalid-selector.*",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"name": {
|
||||||
|
"description": "First and last name of the user such as Alice Smith",
|
||||||
|
"selector": {
|
||||||
|
"text": {
|
||||||
|
"extra-config": False,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
vol.Invalid,
|
||||||
|
r"extra keys not allowed.*",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"name": {
|
||||||
|
"description": "First and last name of the user such as Alice Smith",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
vol.Invalid,
|
||||||
|
r"required key not provided.*selector.*",
|
||||||
|
),
|
||||||
|
(12345, vol.Invalid, r"xpected a dictionary.*"),
|
||||||
|
("name", vol.Invalid, r"xpected a dictionary.*"),
|
||||||
|
(["name"], vol.Invalid, r"xpected a dictionary.*"),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"name": {
|
||||||
|
"description": "First and last name of the user such as Alice Smith",
|
||||||
|
"selector": {"text": {}},
|
||||||
|
"extra-fields": "Some extra fields",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
vol.Invalid,
|
||||||
|
r"extra keys not allowed .*",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"name": {
|
||||||
|
"description": "First and last name of the user such as Alice Smith",
|
||||||
|
"selector": "invalid-schema",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
vol.Invalid,
|
||||||
|
r"xpected a dictionary for dictionary.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
ids=(
|
||||||
|
"invalid-selector",
|
||||||
|
"invalid-selector-config",
|
||||||
|
"missing-selector",
|
||||||
|
"structure-is-int-not-object",
|
||||||
|
"structure-is-str-not-object",
|
||||||
|
"structure-is-list-not-object",
|
||||||
|
"extra-fields",
|
||||||
|
"invalid-selector-schema",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
async def test_generate_data_service_invalid_structure(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: None,
|
||||||
|
structure: Any,
|
||||||
|
expected_exception: Exception,
|
||||||
|
expected_error: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test the entity can generate structured data."""
|
||||||
|
with pytest.raises(expected_exception, match=expected_error):
|
||||||
|
await hass.services.async_call(
|
||||||
|
"ai_task",
|
||||||
|
"generate_data",
|
||||||
|
{
|
||||||
|
"task_name": "Profile Generation",
|
||||||
|
"instructions": "Please generate a profile for a new user",
|
||||||
|
"entity_id": TEST_ENTITY_ID,
|
||||||
|
"structure": structure,
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user