mirror of
https://github.com/home-assistant/core.git
synced 2025-07-08 13:57:10 +00:00
Add attachment support to AI task (#148120)
This commit is contained in:
parent
699c60f293
commit
008e2a3d10
@ -20,6 +20,7 @@ from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
|
||||
|
||||
from .const import (
|
||||
ATTR_ATTACHMENTS,
|
||||
ATTR_INSTRUCTIONS,
|
||||
ATTR_REQUIRED,
|
||||
ATTR_STRUCTURE,
|
||||
@ -32,7 +33,7 @@ from .const import (
|
||||
)
|
||||
from .entity import AITaskEntity
|
||||
from .http import async_setup as async_setup_http
|
||||
from .task import GenDataTask, GenDataTaskResult, async_generate_data
|
||||
from .task import GenDataTask, GenDataTaskResult, PlayMediaWithId, async_generate_data
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
@ -40,6 +41,7 @@ __all__ = [
|
||||
"AITaskEntityFeature",
|
||||
"GenDataTask",
|
||||
"GenDataTaskResult",
|
||||
"PlayMediaWithId",
|
||||
"async_generate_data",
|
||||
"async_setup",
|
||||
"async_setup_entry",
|
||||
@ -92,6 +94,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
vol.Schema({str: STRUCTURE_FIELD_SCHEMA}),
|
||||
_validate_structure_fields,
|
||||
),
|
||||
vol.Optional(ATTR_ATTACHMENTS): vol.All(
|
||||
cv.ensure_list, [selector.MediaSelector({"accept": ["*/*"]})]
|
||||
),
|
||||
}
|
||||
),
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
|
@ -23,6 +23,7 @@ ATTR_INSTRUCTIONS: Final = "instructions"
|
||||
ATTR_TASK_NAME: Final = "task_name"
|
||||
ATTR_STRUCTURE: Final = "structure"
|
||||
ATTR_REQUIRED: Final = "required"
|
||||
ATTR_ATTACHMENTS: Final = "attachments"
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = (
|
||||
"You are a Home Assistant expert and help users with their tasks."
|
||||
@ -34,3 +35,6 @@ class AITaskEntityFeature(IntFlag):
|
||||
|
||||
GENERATE_DATA = 1
|
||||
"""Generate data based on instructions."""
|
||||
|
||||
SUPPORT_ATTACHMENTS = 2
|
||||
"""Support attachments with generate data."""
|
||||
|
@ -2,7 +2,7 @@
|
||||
"domain": "ai_task",
|
||||
"name": "AI Task",
|
||||
"codeowners": ["@home-assistant/core"],
|
||||
"dependencies": ["conversation"],
|
||||
"dependencies": ["conversation", "media_source"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/ai_task",
|
||||
"integration_type": "system",
|
||||
"quality_scale": "internal"
|
||||
|
@ -23,3 +23,9 @@ generate_data:
|
||||
example: '{ "name": { "selector": { "text": }, "description": "Name of the user", "required": "True" } } }, "age": { "selector": { "number": }, "description": "Age of the user" } }'
|
||||
selector:
|
||||
object:
|
||||
attachments:
|
||||
required: false
|
||||
selector:
|
||||
media:
|
||||
accept:
|
||||
- "*"
|
||||
|
@ -19,6 +19,10 @@
|
||||
"structure": {
|
||||
"name": "Structured output",
|
||||
"description": "When set, the AI Task will output fields with this in structure. The structure is a dictionary where the keys are the field names and the values contain a 'description', a 'selector', and an optional 'required' field."
|
||||
},
|
||||
"attachments": {
|
||||
"name": "Attachments",
|
||||
"description": "List of files to attach for multi-modal AI analysis."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2,17 +2,30 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import media_source
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PlayMediaWithId(media_source.PlayMedia):
|
||||
"""Play media with a media content ID."""
|
||||
|
||||
media_content_id: str
|
||||
"""Media source ID to play."""
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return media source ID as a string."""
|
||||
return f"<PlayMediaWithId {self.media_content_id}>"
|
||||
|
||||
|
||||
async def async_generate_data(
|
||||
hass: HomeAssistant,
|
||||
*,
|
||||
@ -20,6 +33,7 @@ async def async_generate_data(
|
||||
entity_id: str | None = None,
|
||||
instructions: str,
|
||||
structure: vol.Schema | None = None,
|
||||
attachments: list[dict] | None = None,
|
||||
) -> GenDataTaskResult:
|
||||
"""Run a task in the AI Task integration."""
|
||||
if entity_id is None:
|
||||
@ -37,11 +51,37 @@ async def async_generate_data(
|
||||
f"AI Task entity {entity_id} does not support generating data"
|
||||
)
|
||||
|
||||
# Resolve attachments
|
||||
resolved_attachments: list[PlayMediaWithId] | None = None
|
||||
|
||||
if attachments:
|
||||
if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features:
|
||||
raise HomeAssistantError(
|
||||
f"AI Task entity {entity_id} does not support attachments"
|
||||
)
|
||||
|
||||
resolved_attachments = []
|
||||
|
||||
for attachment in attachments:
|
||||
media = await media_source.async_resolve_media(
|
||||
hass, attachment["media_content_id"], None
|
||||
)
|
||||
resolved_attachments.append(
|
||||
PlayMediaWithId(
|
||||
**{
|
||||
field.name: getattr(media, field.name)
|
||||
for field in fields(media)
|
||||
},
|
||||
media_content_id=attachment["media_content_id"],
|
||||
)
|
||||
)
|
||||
|
||||
return await entity.internal_async_generate_data(
|
||||
GenDataTask(
|
||||
name=task_name,
|
||||
instructions=instructions,
|
||||
structure=structure,
|
||||
attachments=resolved_attachments,
|
||||
)
|
||||
)
|
||||
|
||||
@ -59,6 +99,9 @@ class GenDataTask:
|
||||
structure: vol.Schema | None = None
|
||||
"""Optional structure for the data to be generated."""
|
||||
|
||||
attachments: list[PlayMediaWithId] | None = None
|
||||
"""List of attachments to go along the instructions."""
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return task as a string."""
|
||||
return f"<GenDataTask {self.name}: {id(self)}>"
|
||||
|
@ -35,7 +35,9 @@ class MockAITaskEntity(AITaskEntity):
|
||||
"""Mock AI Task entity for testing."""
|
||||
|
||||
_attr_name = "Test Task Entity"
|
||||
_attr_supported_features = AITaskEntityFeature.GENERATE_DATA
|
||||
_attr_supported_features = (
|
||||
AITaskEntityFeature.GENERATE_DATA | AITaskEntityFeature.SUPPORT_ATTACHMENTS
|
||||
)
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the mock entity."""
|
||||
|
@ -1,11 +1,13 @@
|
||||
"""Test initialization of the AI Task component."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from freezegun.api import FrozenDateTimeFactory
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import media_source
|
||||
from homeassistant.components.ai_task import AITaskPreferences
|
||||
from homeassistant.components.ai_task.const import DATA_PREFERENCES
|
||||
from homeassistant.core import HomeAssistant
|
||||
@ -58,7 +60,15 @@ async def test_preferences_storage_load(
|
||||
),
|
||||
(
|
||||
{},
|
||||
{"entity_id": TEST_ENTITY_ID},
|
||||
{
|
||||
"entity_id": TEST_ENTITY_ID,
|
||||
"attachments": [
|
||||
{
|
||||
"media_content_id": "media-source://mock/blah_blah_blah.mp4",
|
||||
"media_content_type": "video/mp4",
|
||||
}
|
||||
],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
@ -68,11 +78,19 @@ async def test_generate_data_service(
|
||||
freezer: FrozenDateTimeFactory,
|
||||
set_preferences: dict[str, str | None],
|
||||
msg_extra: dict[str, str],
|
||||
mock_ai_task_entity: MockAITaskEntity,
|
||||
) -> None:
|
||||
"""Test the generate data service."""
|
||||
preferences = hass.data[DATA_PREFERENCES]
|
||||
preferences.async_set_preferences(**set_preferences)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.media_source.async_resolve_media",
|
||||
return_value=media_source.PlayMedia(
|
||||
url="http://example.com/media.mp4",
|
||||
mime_type="video/mp4",
|
||||
),
|
||||
):
|
||||
result = await hass.services.async_call(
|
||||
"ai_task",
|
||||
"generate_data",
|
||||
@ -87,6 +105,23 @@ async def test_generate_data_service(
|
||||
|
||||
assert result["data"] == "Mock result"
|
||||
|
||||
assert len(mock_ai_task_entity.mock_generate_data_tasks) == 1
|
||||
task = mock_ai_task_entity.mock_generate_data_tasks[0]
|
||||
|
||||
assert len(task.attachments or []) == len(
|
||||
msg_attachments := msg_extra.get("attachments", [])
|
||||
)
|
||||
|
||||
for msg_attachment, attachment in zip(
|
||||
msg_attachments, task.attachments or [], strict=False
|
||||
):
|
||||
assert attachment.url == "http://example.com/media.mp4"
|
||||
assert attachment.mime_type == "video/mp4"
|
||||
assert attachment.media_content_id == msg_attachment["media_content_id"]
|
||||
assert (
|
||||
str(attachment) == f"<PlayMediaWithId {msg_attachment['media_content_id']}>"
|
||||
)
|
||||
|
||||
|
||||
async def test_generate_data_service_structure_fields(
|
||||
hass: HomeAssistant,
|
||||
|
@ -16,13 +16,13 @@ from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
|
||||
async def test_run_task_preferred_entity(
|
||||
async def test_generate_data_preferred_entity(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
mock_ai_task_entity: MockAITaskEntity,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test running a task with an unknown entity."""
|
||||
"""Test generating data with entity via preferences."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with pytest.raises(
|
||||
@ -90,11 +90,11 @@ async def test_run_task_preferred_entity(
|
||||
)
|
||||
|
||||
|
||||
async def test_run_data_task_unknown_entity(
|
||||
async def test_generate_data_unknown_entity(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
) -> None:
|
||||
"""Test running a data task with an unknown entity."""
|
||||
"""Test generating data with an unknown entity."""
|
||||
|
||||
with pytest.raises(
|
||||
HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found"
|
||||
@ -113,7 +113,7 @@ async def test_run_data_task_updates_chat_log(
|
||||
init_components: None,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that running a data task updates the chat log."""
|
||||
"""Test that generating data updates the chat log."""
|
||||
result = await async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
@ -127,3 +127,30 @@ async def test_run_data_task_updates_chat_log(
|
||||
async_get_chat_log(hass, session) as chat_log,
|
||||
):
|
||||
assert chat_log.content == snapshot
|
||||
|
||||
|
||||
async def test_generate_data_attachments_not_supported(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
mock_ai_task_entity: MockAITaskEntity,
|
||||
) -> None:
|
||||
"""Test generating data with attachments when entity doesn't support them."""
|
||||
# Remove attachment support from the entity
|
||||
mock_ai_task_entity._attr_supported_features = AITaskEntityFeature.GENERATE_DATA
|
||||
|
||||
with pytest.raises(
|
||||
HomeAssistantError,
|
||||
match="AI Task entity ai_task.test_task_entity does not support attachments",
|
||||
):
|
||||
await async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=TEST_ENTITY_ID,
|
||||
instructions="Test prompt",
|
||||
attachments=[
|
||||
{
|
||||
"media_content_id": "media-source://mock/test.mp4",
|
||||
"media_content_type": "video/mp4",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user