Add attachment support to AI task (#148120)

This commit is contained in:
Paulus Schoutsen 2025-07-06 19:33:41 +02:00 committed by GitHub
parent 699c60f293
commit 008e2a3d10
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 147 additions and 21 deletions

View File

@ -20,6 +20,7 @@ from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
from .const import ( from .const import (
ATTR_ATTACHMENTS,
ATTR_INSTRUCTIONS, ATTR_INSTRUCTIONS,
ATTR_REQUIRED, ATTR_REQUIRED,
ATTR_STRUCTURE, ATTR_STRUCTURE,
@ -32,7 +33,7 @@ from .const import (
) )
from .entity import AITaskEntity from .entity import AITaskEntity
from .http import async_setup as async_setup_http from .http import async_setup as async_setup_http
from .task import GenDataTask, GenDataTaskResult, async_generate_data from .task import GenDataTask, GenDataTaskResult, PlayMediaWithId, async_generate_data
__all__ = [ __all__ = [
"DOMAIN", "DOMAIN",
@ -40,6 +41,7 @@ __all__ = [
"AITaskEntityFeature", "AITaskEntityFeature",
"GenDataTask", "GenDataTask",
"GenDataTaskResult", "GenDataTaskResult",
"PlayMediaWithId",
"async_generate_data", "async_generate_data",
"async_setup", "async_setup",
"async_setup_entry", "async_setup_entry",
@ -92,6 +94,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
vol.Schema({str: STRUCTURE_FIELD_SCHEMA}), vol.Schema({str: STRUCTURE_FIELD_SCHEMA}),
_validate_structure_fields, _validate_structure_fields,
), ),
vol.Optional(ATTR_ATTACHMENTS): vol.All(
cv.ensure_list, [selector.MediaSelector({"accept": ["*/*"]})]
),
} }
), ),
supports_response=SupportsResponse.ONLY, supports_response=SupportsResponse.ONLY,

View File

@ -23,6 +23,7 @@ ATTR_INSTRUCTIONS: Final = "instructions"
ATTR_TASK_NAME: Final = "task_name" ATTR_TASK_NAME: Final = "task_name"
ATTR_STRUCTURE: Final = "structure" ATTR_STRUCTURE: Final = "structure"
ATTR_REQUIRED: Final = "required" ATTR_REQUIRED: Final = "required"
ATTR_ATTACHMENTS: Final = "attachments"
DEFAULT_SYSTEM_PROMPT = ( DEFAULT_SYSTEM_PROMPT = (
"You are a Home Assistant expert and help users with their tasks." "You are a Home Assistant expert and help users with their tasks."
@ -34,3 +35,6 @@ class AITaskEntityFeature(IntFlag):
GENERATE_DATA = 1 GENERATE_DATA = 1
"""Generate data based on instructions.""" """Generate data based on instructions."""
SUPPORT_ATTACHMENTS = 2
"""Support attachments with generate data."""

View File

@ -2,7 +2,7 @@
"domain": "ai_task", "domain": "ai_task",
"name": "AI Task", "name": "AI Task",
"codeowners": ["@home-assistant/core"], "codeowners": ["@home-assistant/core"],
"dependencies": ["conversation"], "dependencies": ["conversation", "media_source"],
"documentation": "https://www.home-assistant.io/integrations/ai_task", "documentation": "https://www.home-assistant.io/integrations/ai_task",
"integration_type": "system", "integration_type": "system",
"quality_scale": "internal" "quality_scale": "internal"

View File

@ -23,3 +23,9 @@ generate_data:
example: '{ "name": { "selector": { "text": }, "description": "Name of the user", "required": "True" } } }, "age": { "selector": { "number": }, "description": "Age of the user" } }' example: '{ "name": { "selector": { "text": }, "description": "Name of the user", "required": "True" } } }, "age": { "selector": { "number": }, "description": "Age of the user" } }'
selector: selector:
object: object:
attachments:
required: false
selector:
media:
accept:
- "*"

View File

@ -19,6 +19,10 @@
"structure": { "structure": {
"name": "Structured output", "name": "Structured output",
"description": "When set, the AI Task will output fields with this in structure. The structure is a dictionary where the keys are the field names and the values contain a 'description', a 'selector', and an optional 'required' field." "description": "When set, the AI Task will output fields with this in structure. The structure is a dictionary where the keys are the field names and the values contain a 'description', a 'selector', and an optional 'required' field."
},
"attachments": {
"name": "Attachments",
"description": "List of files to attach for multi-modal AI analysis."
} }
} }
} }

View File

@ -2,17 +2,30 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass, fields
from typing import Any from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant.components import media_source
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
@dataclass(slots=True)
class PlayMediaWithId(media_source.PlayMedia):
    """A ``PlayMedia`` entry that additionally records its media content ID."""

    media_content_id: str
    """Media source ID associated with this media."""

    def __str__(self) -> str:
        """Return a short tag that embeds the media content ID."""
        content_id = self.media_content_id
        return f"<PlayMediaWithId {content_id}>"
async def async_generate_data( async def async_generate_data(
hass: HomeAssistant, hass: HomeAssistant,
*, *,
@ -20,6 +33,7 @@ async def async_generate_data(
entity_id: str | None = None, entity_id: str | None = None,
instructions: str, instructions: str,
structure: vol.Schema | None = None, structure: vol.Schema | None = None,
attachments: list[dict] | None = None,
) -> GenDataTaskResult: ) -> GenDataTaskResult:
"""Run a task in the AI Task integration.""" """Run a task in the AI Task integration."""
if entity_id is None: if entity_id is None:
@ -37,11 +51,37 @@ async def async_generate_data(
f"AI Task entity {entity_id} does not support generating data" f"AI Task entity {entity_id} does not support generating data"
) )
# Resolve attachments
resolved_attachments: list[PlayMediaWithId] | None = None
if attachments:
if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features:
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support attachments"
)
resolved_attachments = []
for attachment in attachments:
media = await media_source.async_resolve_media(
hass, attachment["media_content_id"], None
)
resolved_attachments.append(
PlayMediaWithId(
**{
field.name: getattr(media, field.name)
for field in fields(media)
},
media_content_id=attachment["media_content_id"],
)
)
return await entity.internal_async_generate_data( return await entity.internal_async_generate_data(
GenDataTask( GenDataTask(
name=task_name, name=task_name,
instructions=instructions, instructions=instructions,
structure=structure, structure=structure,
attachments=resolved_attachments,
) )
) )
@ -59,6 +99,9 @@ class GenDataTask:
structure: vol.Schema | None = None structure: vol.Schema | None = None
"""Optional structure for the data to be generated.""" """Optional structure for the data to be generated."""
attachments: list[PlayMediaWithId] | None = None
"""List of attachments to go along with the instructions."""
def __str__(self) -> str: def __str__(self) -> str:
"""Return task as a string.""" """Return task as a string."""
return f"<GenDataTask {self.name}: {id(self)}>" return f"<GenDataTask {self.name}: {id(self)}>"

View File

@ -35,7 +35,9 @@ class MockAITaskEntity(AITaskEntity):
"""Mock AI Task entity for testing.""" """Mock AI Task entity for testing."""
_attr_name = "Test Task Entity" _attr_name = "Test Task Entity"
_attr_supported_features = AITaskEntityFeature.GENERATE_DATA _attr_supported_features = (
AITaskEntityFeature.GENERATE_DATA | AITaskEntityFeature.SUPPORT_ATTACHMENTS
)
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the mock entity.""" """Initialize the mock entity."""

View File

@ -1,11 +1,13 @@
"""Test initialization of the AI Task component.""" """Test initialization of the AI Task component."""
from typing import Any from typing import Any
from unittest.mock import patch
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
import pytest import pytest
import voluptuous as vol import voluptuous as vol
from homeassistant.components import media_source
from homeassistant.components.ai_task import AITaskPreferences from homeassistant.components.ai_task import AITaskPreferences
from homeassistant.components.ai_task.const import DATA_PREFERENCES from homeassistant.components.ai_task.const import DATA_PREFERENCES
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -58,7 +60,15 @@ async def test_preferences_storage_load(
), ),
( (
{}, {},
{"entity_id": TEST_ENTITY_ID}, {
"entity_id": TEST_ENTITY_ID,
"attachments": [
{
"media_content_id": "media-source://mock/blah_blah_blah.mp4",
"media_content_type": "video/mp4",
}
],
},
), ),
], ],
) )
@ -68,25 +78,50 @@ async def test_generate_data_service(
freezer: FrozenDateTimeFactory, freezer: FrozenDateTimeFactory,
set_preferences: dict[str, str | None], set_preferences: dict[str, str | None],
msg_extra: dict[str, str], msg_extra: dict[str, str],
mock_ai_task_entity: MockAITaskEntity,
) -> None: ) -> None:
"""Test the generate data service.""" """Test the generate data service."""
preferences = hass.data[DATA_PREFERENCES] preferences = hass.data[DATA_PREFERENCES]
preferences.async_set_preferences(**set_preferences) preferences.async_set_preferences(**set_preferences)
result = await hass.services.async_call( with patch(
"ai_task", "homeassistant.components.media_source.async_resolve_media",
"generate_data", return_value=media_source.PlayMedia(
{ url="http://example.com/media.mp4",
"task_name": "Test Name", mime_type="video/mp4",
"instructions": "Test prompt", ),
} ):
| msg_extra, result = await hass.services.async_call(
blocking=True, "ai_task",
return_response=True, "generate_data",
) {
"task_name": "Test Name",
"instructions": "Test prompt",
}
| msg_extra,
blocking=True,
return_response=True,
)
assert result["data"] == "Mock result" assert result["data"] == "Mock result"
assert len(mock_ai_task_entity.mock_generate_data_tasks) == 1
task = mock_ai_task_entity.mock_generate_data_tasks[0]
assert len(task.attachments or []) == len(
msg_attachments := msg_extra.get("attachments", [])
)
for msg_attachment, attachment in zip(
msg_attachments, task.attachments or [], strict=False
):
assert attachment.url == "http://example.com/media.mp4"
assert attachment.mime_type == "video/mp4"
assert attachment.media_content_id == msg_attachment["media_content_id"]
assert (
str(attachment) == f"<PlayMediaWithId {msg_attachment['media_content_id']}>"
)
async def test_generate_data_service_structure_fields( async def test_generate_data_service_structure_fields(
hass: HomeAssistant, hass: HomeAssistant,

View File

@ -16,13 +16,13 @@ from .conftest import TEST_ENTITY_ID, MockAITaskEntity
from tests.typing import WebSocketGenerator from tests.typing import WebSocketGenerator
async def test_run_task_preferred_entity( async def test_generate_data_preferred_entity(
hass: HomeAssistant, hass: HomeAssistant,
init_components: None, init_components: None,
mock_ai_task_entity: MockAITaskEntity, mock_ai_task_entity: MockAITaskEntity,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
) -> None: ) -> None:
"""Test running a task with an unknown entity.""" """Test generating data with entity via preferences."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
with pytest.raises( with pytest.raises(
@ -90,11 +90,11 @@ async def test_run_task_preferred_entity(
) )
async def test_run_data_task_unknown_entity( async def test_generate_data_unknown_entity(
hass: HomeAssistant, hass: HomeAssistant,
init_components: None, init_components: None,
) -> None: ) -> None:
"""Test running a data task with an unknown entity.""" """Test generating data with an unknown entity."""
with pytest.raises( with pytest.raises(
HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found" HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found"
@ -113,7 +113,7 @@ async def test_run_data_task_updates_chat_log(
init_components: None, init_components: None,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test that running a data task updates the chat log.""" """Test that generating data updates the chat log."""
result = await async_generate_data( result = await async_generate_data(
hass, hass,
task_name="Test Task", task_name="Test Task",
@ -127,3 +127,30 @@ async def test_run_data_task_updates_chat_log(
async_get_chat_log(hass, session) as chat_log, async_get_chat_log(hass, session) as chat_log,
): ):
assert chat_log.content == snapshot assert chat_log.content == snapshot
async def test_generate_data_attachments_not_supported(
    hass: HomeAssistant,
    init_components: None,
    mock_ai_task_entity: MockAITaskEntity,
) -> None:
    """Verify attachments are rejected when the entity lacks attachment support."""
    # Restrict the mock entity to plain data generation so that the
    # attachment feature flag is no longer set.
    mock_ai_task_entity._attr_supported_features = AITaskEntityFeature.GENERATE_DATA

    video_attachment = {
        "media_content_id": "media-source://mock/test.mp4",
        "media_content_type": "video/mp4",
    }
    expected_error = (
        "AI Task entity ai_task.test_task_entity does not support attachments"
    )

    with pytest.raises(HomeAssistantError, match=expected_error):
        await async_generate_data(
            hass,
            task_name="Test Task",
            entity_id=TEST_ENTITY_ID,
            instructions="Test prompt",
            attachments=[video_attachment],
        )