Add attachment support to AI task (#148120)

This commit is contained in:
Paulus Schoutsen 2025-07-06 19:33:41 +02:00 committed by GitHub
parent 699c60f293
commit 008e2a3d10
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 147 additions and 21 deletions

View File

@ -20,6 +20,7 @@ from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
from .const import (
ATTR_ATTACHMENTS,
ATTR_INSTRUCTIONS,
ATTR_REQUIRED,
ATTR_STRUCTURE,
@ -32,7 +33,7 @@ from .const import (
)
from .entity import AITaskEntity
from .http import async_setup as async_setup_http
from .task import GenDataTask, GenDataTaskResult, async_generate_data
from .task import GenDataTask, GenDataTaskResult, PlayMediaWithId, async_generate_data
__all__ = [
"DOMAIN",
@ -40,6 +41,7 @@ __all__ = [
"AITaskEntityFeature",
"GenDataTask",
"GenDataTaskResult",
"PlayMediaWithId",
"async_generate_data",
"async_setup",
"async_setup_entry",
@ -92,6 +94,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
vol.Schema({str: STRUCTURE_FIELD_SCHEMA}),
_validate_structure_fields,
),
vol.Optional(ATTR_ATTACHMENTS): vol.All(
cv.ensure_list, [selector.MediaSelector({"accept": ["*/*"]})]
),
}
),
supports_response=SupportsResponse.ONLY,

View File

@ -23,6 +23,7 @@ ATTR_INSTRUCTIONS: Final = "instructions"
ATTR_TASK_NAME: Final = "task_name"
ATTR_STRUCTURE: Final = "structure"
ATTR_REQUIRED: Final = "required"
ATTR_ATTACHMENTS: Final = "attachments"
DEFAULT_SYSTEM_PROMPT = (
"You are a Home Assistant expert and help users with their tasks."
@ -34,3 +35,6 @@ class AITaskEntityFeature(IntFlag):
GENERATE_DATA = 1
"""Generate data based on instructions."""
SUPPORT_ATTACHMENTS = 2
"""Support attachments with generate data."""

View File

@ -2,7 +2,7 @@
"domain": "ai_task",
"name": "AI Task",
"codeowners": ["@home-assistant/core"],
"dependencies": ["conversation"],
"dependencies": ["conversation", "media_source"],
"documentation": "https://www.home-assistant.io/integrations/ai_task",
"integration_type": "system",
"quality_scale": "internal"

View File

@ -23,3 +23,9 @@ generate_data:
example: '{ "name": { "selector": { "text": {} }, "description": "Name of the user", "required": "True" }, "age": { "selector": { "number": {} }, "description": "Age of the user" } }'
selector:
object:
attachments:
required: false
selector:
media:
accept:
- "*"

View File

@ -19,6 +19,10 @@
"structure": {
"name": "Structured output",
"description": "When set, the AI Task will output fields with this in structure. The structure is a dictionary where the keys are the field names and the values contain a 'description', a 'selector', and an optional 'required' field."
},
"attachments": {
"name": "Attachments",
"description": "List of files to attach for multi-modal AI analysis."
}
}
}

View File

@ -2,17 +2,30 @@
from __future__ import annotations
from dataclasses import dataclass
from dataclasses import dataclass, fields
from typing import Any
import voluptuous as vol
from homeassistant.components import media_source
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
@dataclass(slots=True)
class PlayMediaWithId(media_source.PlayMedia):
    """Playable media that additionally carries its media content ID."""

    media_content_id: str
    """Media source ID associated with this playable media."""

    def __str__(self) -> str:
        """Return a short tag that embeds the media content ID."""
        content_id = self.media_content_id
        return f"<PlayMediaWithId {content_id}>"
async def async_generate_data(
hass: HomeAssistant,
*,
@ -20,6 +33,7 @@ async def async_generate_data(
entity_id: str | None = None,
instructions: str,
structure: vol.Schema | None = None,
attachments: list[dict] | None = None,
) -> GenDataTaskResult:
"""Run a task in the AI Task integration."""
if entity_id is None:
@ -37,11 +51,37 @@ async def async_generate_data(
f"AI Task entity {entity_id} does not support generating data"
)
# Resolve attachments
resolved_attachments: list[PlayMediaWithId] | None = None
if attachments:
if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features:
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support attachments"
)
resolved_attachments = []
for attachment in attachments:
media = await media_source.async_resolve_media(
hass, attachment["media_content_id"], None
)
resolved_attachments.append(
PlayMediaWithId(
**{
field.name: getattr(media, field.name)
for field in fields(media)
},
media_content_id=attachment["media_content_id"],
)
)
return await entity.internal_async_generate_data(
GenDataTask(
name=task_name,
instructions=instructions,
structure=structure,
attachments=resolved_attachments,
)
)
@ -59,6 +99,9 @@ class GenDataTask:
structure: vol.Schema | None = None
"""Optional structure for the data to be generated."""
attachments: list[PlayMediaWithId] | None = None
"""List of attachments to go along the instructions."""
def __str__(self) -> str:
"""Return task as a string."""
return f"<GenDataTask {self.name}: {id(self)}>"

View File

@ -35,7 +35,9 @@ class MockAITaskEntity(AITaskEntity):
"""Mock AI Task entity for testing."""
_attr_name = "Test Task Entity"
_attr_supported_features = AITaskEntityFeature.GENERATE_DATA
_attr_supported_features = (
AITaskEntityFeature.GENERATE_DATA | AITaskEntityFeature.SUPPORT_ATTACHMENTS
)
def __init__(self) -> None:
"""Initialize the mock entity."""

View File

@ -1,11 +1,13 @@
"""Test initialization of the AI Task component."""
from typing import Any
from unittest.mock import patch
from freezegun.api import FrozenDateTimeFactory
import pytest
import voluptuous as vol
from homeassistant.components import media_source
from homeassistant.components.ai_task import AITaskPreferences
from homeassistant.components.ai_task.const import DATA_PREFERENCES
from homeassistant.core import HomeAssistant
@ -58,7 +60,15 @@ async def test_preferences_storage_load(
),
(
{},
{"entity_id": TEST_ENTITY_ID},
{
"entity_id": TEST_ENTITY_ID,
"attachments": [
{
"media_content_id": "media-source://mock/blah_blah_blah.mp4",
"media_content_type": "video/mp4",
}
],
},
),
],
)
@ -68,25 +78,50 @@ async def test_generate_data_service(
freezer: FrozenDateTimeFactory,
set_preferences: dict[str, str | None],
msg_extra: dict[str, str],
mock_ai_task_entity: MockAITaskEntity,
) -> None:
"""Test the generate data service."""
preferences = hass.data[DATA_PREFERENCES]
preferences.async_set_preferences(**set_preferences)
result = await hass.services.async_call(
"ai_task",
"generate_data",
{
"task_name": "Test Name",
"instructions": "Test prompt",
}
| msg_extra,
blocking=True,
return_response=True,
)
with patch(
"homeassistant.components.media_source.async_resolve_media",
return_value=media_source.PlayMedia(
url="http://example.com/media.mp4",
mime_type="video/mp4",
),
):
result = await hass.services.async_call(
"ai_task",
"generate_data",
{
"task_name": "Test Name",
"instructions": "Test prompt",
}
| msg_extra,
blocking=True,
return_response=True,
)
assert result["data"] == "Mock result"
assert len(mock_ai_task_entity.mock_generate_data_tasks) == 1
task = mock_ai_task_entity.mock_generate_data_tasks[0]
assert len(task.attachments or []) == len(
msg_attachments := msg_extra.get("attachments", [])
)
for msg_attachment, attachment in zip(
msg_attachments, task.attachments or [], strict=False
):
assert attachment.url == "http://example.com/media.mp4"
assert attachment.mime_type == "video/mp4"
assert attachment.media_content_id == msg_attachment["media_content_id"]
assert (
str(attachment) == f"<PlayMediaWithId {msg_attachment['media_content_id']}>"
)
async def test_generate_data_service_structure_fields(
hass: HomeAssistant,

View File

@ -16,13 +16,13 @@ from .conftest import TEST_ENTITY_ID, MockAITaskEntity
from tests.typing import WebSocketGenerator
async def test_run_task_preferred_entity(
async def test_generate_data_preferred_entity(
hass: HomeAssistant,
init_components: None,
mock_ai_task_entity: MockAITaskEntity,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test running a task with an unknown entity."""
"""Test generating data with entity via preferences."""
client = await hass_ws_client(hass)
with pytest.raises(
@ -90,11 +90,11 @@ async def test_run_task_preferred_entity(
)
async def test_run_data_task_unknown_entity(
async def test_generate_data_unknown_entity(
hass: HomeAssistant,
init_components: None,
) -> None:
"""Test running a data task with an unknown entity."""
"""Test generating data with an unknown entity."""
with pytest.raises(
HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found"
@ -113,7 +113,7 @@ async def test_run_data_task_updates_chat_log(
init_components: None,
snapshot: SnapshotAssertion,
) -> None:
"""Test that running a data task updates the chat log."""
"""Test that generating data updates the chat log."""
result = await async_generate_data(
hass,
task_name="Test Task",
@ -127,3 +127,30 @@ async def test_run_data_task_updates_chat_log(
async_get_chat_log(hass, session) as chat_log,
):
assert chat_log.content == snapshot
async def test_generate_data_attachments_not_supported(
    hass: HomeAssistant,
    init_components: None,
    mock_ai_task_entity: MockAITaskEntity,
) -> None:
    """Test attachments are rejected when the entity lacks support for them."""
    # Leave the mock entity advertising only data generation, so the
    # attachment feature flag is absent when the task is submitted.
    mock_ai_task_entity._attr_supported_features = AITaskEntityFeature.GENERATE_DATA

    video_attachment = {
        "media_content_id": "media-source://mock/test.mp4",
        "media_content_type": "video/mp4",
    }
    expected_error = (
        "AI Task entity ai_task.test_task_entity does not support attachments"
    )

    with pytest.raises(HomeAssistantError, match=expected_error):
        await async_generate_data(
            hass,
            task_name="Test Task",
            entity_id=TEST_ENTITY_ID,
            instructions="Test prompt",
            attachments=[video_attachment],
        )