mirror of
https://github.com/home-assistant/core.git
synced 2025-07-08 05:47:10 +00:00
Add attachment support to AI task (#148120)
This commit is contained in:
parent
699c60f293
commit
008e2a3d10
@ -20,6 +20,7 @@ from homeassistant.helpers.entity_component import EntityComponent
|
|||||||
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
|
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
|
ATTR_ATTACHMENTS,
|
||||||
ATTR_INSTRUCTIONS,
|
ATTR_INSTRUCTIONS,
|
||||||
ATTR_REQUIRED,
|
ATTR_REQUIRED,
|
||||||
ATTR_STRUCTURE,
|
ATTR_STRUCTURE,
|
||||||
@ -32,7 +33,7 @@ from .const import (
|
|||||||
)
|
)
|
||||||
from .entity import AITaskEntity
|
from .entity import AITaskEntity
|
||||||
from .http import async_setup as async_setup_http
|
from .http import async_setup as async_setup_http
|
||||||
from .task import GenDataTask, GenDataTaskResult, async_generate_data
|
from .task import GenDataTask, GenDataTaskResult, PlayMediaWithId, async_generate_data
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
@ -40,6 +41,7 @@ __all__ = [
|
|||||||
"AITaskEntityFeature",
|
"AITaskEntityFeature",
|
||||||
"GenDataTask",
|
"GenDataTask",
|
||||||
"GenDataTaskResult",
|
"GenDataTaskResult",
|
||||||
|
"PlayMediaWithId",
|
||||||
"async_generate_data",
|
"async_generate_data",
|
||||||
"async_setup",
|
"async_setup",
|
||||||
"async_setup_entry",
|
"async_setup_entry",
|
||||||
@ -92,6 +94,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
vol.Schema({str: STRUCTURE_FIELD_SCHEMA}),
|
vol.Schema({str: STRUCTURE_FIELD_SCHEMA}),
|
||||||
_validate_structure_fields,
|
_validate_structure_fields,
|
||||||
),
|
),
|
||||||
|
vol.Optional(ATTR_ATTACHMENTS): vol.All(
|
||||||
|
cv.ensure_list, [selector.MediaSelector({"accept": ["*/*"]})]
|
||||||
|
),
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
supports_response=SupportsResponse.ONLY,
|
supports_response=SupportsResponse.ONLY,
|
||||||
|
@ -23,6 +23,7 @@ ATTR_INSTRUCTIONS: Final = "instructions"
|
|||||||
ATTR_TASK_NAME: Final = "task_name"
|
ATTR_TASK_NAME: Final = "task_name"
|
||||||
ATTR_STRUCTURE: Final = "structure"
|
ATTR_STRUCTURE: Final = "structure"
|
||||||
ATTR_REQUIRED: Final = "required"
|
ATTR_REQUIRED: Final = "required"
|
||||||
|
ATTR_ATTACHMENTS: Final = "attachments"
|
||||||
|
|
||||||
DEFAULT_SYSTEM_PROMPT = (
|
DEFAULT_SYSTEM_PROMPT = (
|
||||||
"You are a Home Assistant expert and help users with their tasks."
|
"You are a Home Assistant expert and help users with their tasks."
|
||||||
@ -34,3 +35,6 @@ class AITaskEntityFeature(IntFlag):
|
|||||||
|
|
||||||
GENERATE_DATA = 1
|
GENERATE_DATA = 1
|
||||||
"""Generate data based on instructions."""
|
"""Generate data based on instructions."""
|
||||||
|
|
||||||
|
SUPPORT_ATTACHMENTS = 2
|
||||||
|
"""Support attachments with generate data."""
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
"domain": "ai_task",
|
"domain": "ai_task",
|
||||||
"name": "AI Task",
|
"name": "AI Task",
|
||||||
"codeowners": ["@home-assistant/core"],
|
"codeowners": ["@home-assistant/core"],
|
||||||
"dependencies": ["conversation"],
|
"dependencies": ["conversation", "media_source"],
|
||||||
"documentation": "https://www.home-assistant.io/integrations/ai_task",
|
"documentation": "https://www.home-assistant.io/integrations/ai_task",
|
||||||
"integration_type": "system",
|
"integration_type": "system",
|
||||||
"quality_scale": "internal"
|
"quality_scale": "internal"
|
||||||
|
@ -23,3 +23,9 @@ generate_data:
|
|||||||
example: '{ "name": { "selector": { "text": }, "description": "Name of the user", "required": "True" } } }, "age": { "selector": { "number": }, "description": "Age of the user" } }'
|
example: '{ "name": { "selector": { "text": }, "description": "Name of the user", "required": "True" } } }, "age": { "selector": { "number": }, "description": "Age of the user" } }'
|
||||||
selector:
|
selector:
|
||||||
object:
|
object:
|
||||||
|
attachments:
|
||||||
|
required: false
|
||||||
|
selector:
|
||||||
|
media:
|
||||||
|
accept:
|
||||||
|
- "*"
|
||||||
|
@ -19,6 +19,10 @@
|
|||||||
"structure": {
|
"structure": {
|
||||||
"name": "Structured output",
|
"name": "Structured output",
|
||||||
"description": "When set, the AI Task will output fields with this in structure. The structure is a dictionary where the keys are the field names and the values contain a 'description', a 'selector', and an optional 'required' field."
|
"description": "When set, the AI Task will output fields with this in structure. The structure is a dictionary where the keys are the field names and the values contain a 'description', a 'selector', and an optional 'required' field."
|
||||||
|
},
|
||||||
|
"attachments": {
|
||||||
|
"name": "Attachments",
|
||||||
|
"description": "List of files to attach for multi-modal AI analysis."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,17 +2,30 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, fields
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components import media_source
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class PlayMediaWithId(media_source.PlayMedia):
|
||||||
|
"""Play media with a media content ID."""
|
||||||
|
|
||||||
|
media_content_id: str
|
||||||
|
"""Media source ID to play."""
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""Return media source ID as a string."""
|
||||||
|
return f"<PlayMediaWithId {self.media_content_id}>"
|
||||||
|
|
||||||
|
|
||||||
async def async_generate_data(
|
async def async_generate_data(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
*,
|
*,
|
||||||
@ -20,6 +33,7 @@ async def async_generate_data(
|
|||||||
entity_id: str | None = None,
|
entity_id: str | None = None,
|
||||||
instructions: str,
|
instructions: str,
|
||||||
structure: vol.Schema | None = None,
|
structure: vol.Schema | None = None,
|
||||||
|
attachments: list[dict] | None = None,
|
||||||
) -> GenDataTaskResult:
|
) -> GenDataTaskResult:
|
||||||
"""Run a task in the AI Task integration."""
|
"""Run a task in the AI Task integration."""
|
||||||
if entity_id is None:
|
if entity_id is None:
|
||||||
@ -37,11 +51,37 @@ async def async_generate_data(
|
|||||||
f"AI Task entity {entity_id} does not support generating data"
|
f"AI Task entity {entity_id} does not support generating data"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Resolve attachments
|
||||||
|
resolved_attachments: list[PlayMediaWithId] | None = None
|
||||||
|
|
||||||
|
if attachments:
|
||||||
|
if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features:
|
||||||
|
raise HomeAssistantError(
|
||||||
|
f"AI Task entity {entity_id} does not support attachments"
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved_attachments = []
|
||||||
|
|
||||||
|
for attachment in attachments:
|
||||||
|
media = await media_source.async_resolve_media(
|
||||||
|
hass, attachment["media_content_id"], None
|
||||||
|
)
|
||||||
|
resolved_attachments.append(
|
||||||
|
PlayMediaWithId(
|
||||||
|
**{
|
||||||
|
field.name: getattr(media, field.name)
|
||||||
|
for field in fields(media)
|
||||||
|
},
|
||||||
|
media_content_id=attachment["media_content_id"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return await entity.internal_async_generate_data(
|
return await entity.internal_async_generate_data(
|
||||||
GenDataTask(
|
GenDataTask(
|
||||||
name=task_name,
|
name=task_name,
|
||||||
instructions=instructions,
|
instructions=instructions,
|
||||||
structure=structure,
|
structure=structure,
|
||||||
|
attachments=resolved_attachments,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -59,6 +99,9 @@ class GenDataTask:
|
|||||||
structure: vol.Schema | None = None
|
structure: vol.Schema | None = None
|
||||||
"""Optional structure for the data to be generated."""
|
"""Optional structure for the data to be generated."""
|
||||||
|
|
||||||
|
attachments: list[PlayMediaWithId] | None = None
|
||||||
|
"""List of attachments to go along the instructions."""
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
"""Return task as a string."""
|
"""Return task as a string."""
|
||||||
return f"<GenDataTask {self.name}: {id(self)}>"
|
return f"<GenDataTask {self.name}: {id(self)}>"
|
||||||
|
@ -35,7 +35,9 @@ class MockAITaskEntity(AITaskEntity):
|
|||||||
"""Mock AI Task entity for testing."""
|
"""Mock AI Task entity for testing."""
|
||||||
|
|
||||||
_attr_name = "Test Task Entity"
|
_attr_name = "Test Task Entity"
|
||||||
_attr_supported_features = AITaskEntityFeature.GENERATE_DATA
|
_attr_supported_features = (
|
||||||
|
AITaskEntityFeature.GENERATE_DATA | AITaskEntityFeature.SUPPORT_ATTACHMENTS
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize the mock entity."""
|
"""Initialize the mock entity."""
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
"""Test initialization of the AI Task component."""
|
"""Test initialization of the AI Task component."""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
from freezegun.api import FrozenDateTimeFactory
|
from freezegun.api import FrozenDateTimeFactory
|
||||||
import pytest
|
import pytest
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components import media_source
|
||||||
from homeassistant.components.ai_task import AITaskPreferences
|
from homeassistant.components.ai_task import AITaskPreferences
|
||||||
from homeassistant.components.ai_task.const import DATA_PREFERENCES
|
from homeassistant.components.ai_task.const import DATA_PREFERENCES
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@ -58,7 +60,15 @@ async def test_preferences_storage_load(
|
|||||||
),
|
),
|
||||||
(
|
(
|
||||||
{},
|
{},
|
||||||
{"entity_id": TEST_ENTITY_ID},
|
{
|
||||||
|
"entity_id": TEST_ENTITY_ID,
|
||||||
|
"attachments": [
|
||||||
|
{
|
||||||
|
"media_content_id": "media-source://mock/blah_blah_blah.mp4",
|
||||||
|
"media_content_type": "video/mp4",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -68,25 +78,50 @@ async def test_generate_data_service(
|
|||||||
freezer: FrozenDateTimeFactory,
|
freezer: FrozenDateTimeFactory,
|
||||||
set_preferences: dict[str, str | None],
|
set_preferences: dict[str, str | None],
|
||||||
msg_extra: dict[str, str],
|
msg_extra: dict[str, str],
|
||||||
|
mock_ai_task_entity: MockAITaskEntity,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the generate data service."""
|
"""Test the generate data service."""
|
||||||
preferences = hass.data[DATA_PREFERENCES]
|
preferences = hass.data[DATA_PREFERENCES]
|
||||||
preferences.async_set_preferences(**set_preferences)
|
preferences.async_set_preferences(**set_preferences)
|
||||||
|
|
||||||
result = await hass.services.async_call(
|
with patch(
|
||||||
"ai_task",
|
"homeassistant.components.media_source.async_resolve_media",
|
||||||
"generate_data",
|
return_value=media_source.PlayMedia(
|
||||||
{
|
url="http://example.com/media.mp4",
|
||||||
"task_name": "Test Name",
|
mime_type="video/mp4",
|
||||||
"instructions": "Test prompt",
|
),
|
||||||
}
|
):
|
||||||
| msg_extra,
|
result = await hass.services.async_call(
|
||||||
blocking=True,
|
"ai_task",
|
||||||
return_response=True,
|
"generate_data",
|
||||||
)
|
{
|
||||||
|
"task_name": "Test Name",
|
||||||
|
"instructions": "Test prompt",
|
||||||
|
}
|
||||||
|
| msg_extra,
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
|
||||||
assert result["data"] == "Mock result"
|
assert result["data"] == "Mock result"
|
||||||
|
|
||||||
|
assert len(mock_ai_task_entity.mock_generate_data_tasks) == 1
|
||||||
|
task = mock_ai_task_entity.mock_generate_data_tasks[0]
|
||||||
|
|
||||||
|
assert len(task.attachments or []) == len(
|
||||||
|
msg_attachments := msg_extra.get("attachments", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
for msg_attachment, attachment in zip(
|
||||||
|
msg_attachments, task.attachments or [], strict=False
|
||||||
|
):
|
||||||
|
assert attachment.url == "http://example.com/media.mp4"
|
||||||
|
assert attachment.mime_type == "video/mp4"
|
||||||
|
assert attachment.media_content_id == msg_attachment["media_content_id"]
|
||||||
|
assert (
|
||||||
|
str(attachment) == f"<PlayMediaWithId {msg_attachment['media_content_id']}>"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_generate_data_service_structure_fields(
|
async def test_generate_data_service_structure_fields(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
@ -16,13 +16,13 @@ from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
|||||||
from tests.typing import WebSocketGenerator
|
from tests.typing import WebSocketGenerator
|
||||||
|
|
||||||
|
|
||||||
async def test_run_task_preferred_entity(
|
async def test_generate_data_preferred_entity(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
init_components: None,
|
init_components: None,
|
||||||
mock_ai_task_entity: MockAITaskEntity,
|
mock_ai_task_entity: MockAITaskEntity,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test running a task with an unknown entity."""
|
"""Test generating data with entity via preferences."""
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
@ -90,11 +90,11 @@ async def test_run_task_preferred_entity(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_run_data_task_unknown_entity(
|
async def test_generate_data_unknown_entity(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
init_components: None,
|
init_components: None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test running a data task with an unknown entity."""
|
"""Test generating data with an unknown entity."""
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found"
|
HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found"
|
||||||
@ -113,7 +113,7 @@ async def test_run_data_task_updates_chat_log(
|
|||||||
init_components: None,
|
init_components: None,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that running a data task updates the chat log."""
|
"""Test that generating data updates the chat log."""
|
||||||
result = await async_generate_data(
|
result = await async_generate_data(
|
||||||
hass,
|
hass,
|
||||||
task_name="Test Task",
|
task_name="Test Task",
|
||||||
@ -127,3 +127,30 @@ async def test_run_data_task_updates_chat_log(
|
|||||||
async_get_chat_log(hass, session) as chat_log,
|
async_get_chat_log(hass, session) as chat_log,
|
||||||
):
|
):
|
||||||
assert chat_log.content == snapshot
|
assert chat_log.content == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_data_attachments_not_supported(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: None,
|
||||||
|
mock_ai_task_entity: MockAITaskEntity,
|
||||||
|
) -> None:
|
||||||
|
"""Test generating data with attachments when entity doesn't support them."""
|
||||||
|
# Remove attachment support from the entity
|
||||||
|
mock_ai_task_entity._attr_supported_features = AITaskEntityFeature.GENERATE_DATA
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
HomeAssistantError,
|
||||||
|
match="AI Task entity ai_task.test_task_entity does not support attachments",
|
||||||
|
):
|
||||||
|
await async_generate_data(
|
||||||
|
hass,
|
||||||
|
task_name="Test Task",
|
||||||
|
entity_id=TEST_ENTITY_ID,
|
||||||
|
instructions="Test prompt",
|
||||||
|
attachments=[
|
||||||
|
{
|
||||||
|
"media_content_id": "media-source://mock/test.mp4",
|
||||||
|
"media_content_type": "video/mp4",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user