Make attachments native to chat log (#148693)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Paulus Schoutsen 2025-07-13 19:35:11 +02:00 committed by GitHub
parent f3ad6bd9b6
commit 23a8442abe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 49 additions and 37 deletions

View File

@ -33,7 +33,7 @@ from .const import (
)
from .entity import AITaskEntity
from .http import async_setup as async_setup_http
from .task import GenDataTask, GenDataTaskResult, PlayMediaWithId, async_generate_data
from .task import GenDataTask, GenDataTaskResult, async_generate_data
__all__ = [
"DOMAIN",
@ -41,7 +41,6 @@ __all__ = [
"AITaskEntityFeature",
"GenDataTask",
"GenDataTaskResult",
"PlayMediaWithId",
"async_generate_data",
"async_setup",
"async_setup_entry",

View File

@ -79,7 +79,9 @@ class AITaskEntity(RestoreEntity):
user_llm_prompt=DEFAULT_SYSTEM_PROMPT,
)
chat_log.async_add_user_content(UserContent(task.instructions))
chat_log.async_add_user_content(
UserContent(task.instructions, attachments=task.attachments)
)
yield chat_log

View File

@ -2,30 +2,18 @@
from __future__ import annotations
from dataclasses import dataclass, fields
from dataclasses import dataclass
from typing import Any
import voluptuous as vol
from homeassistant.components import media_source
from homeassistant.components import conversation, media_source
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
@dataclass(slots=True)
class PlayMediaWithId(media_source.PlayMedia):
    """Play media with a media content ID.

    Subclass of ``media_source.PlayMedia`` that adds one extra field,
    ``media_content_id``, alongside whatever fields the base class declares.
    """

    # Declared without a default, so it must be supplied at construction time.
    media_content_id: str
    """Media source ID to play."""

    def __str__(self) -> str:
        """Return media source ID as a string."""
        # Only the media content ID is rendered, wrapped in a
        # "<PlayMediaWithId ...>" tag; no other field appears in the output.
        return f"<PlayMediaWithId {self.media_content_id}>"
async def async_generate_data(
hass: HomeAssistant,
*,
@ -52,7 +40,7 @@ async def async_generate_data(
)
# Resolve attachments
resolved_attachments: list[PlayMediaWithId] | None = None
resolved_attachments: list[conversation.Attachment] | None = None
if attachments:
if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features:
@ -66,13 +54,16 @@ async def async_generate_data(
media = await media_source.async_resolve_media(
hass, attachment["media_content_id"], None
)
if media.path is None:
raise HomeAssistantError(
"Only local attachments are currently supported"
)
resolved_attachments.append(
PlayMediaWithId(
**{
field.name: getattr(media, field.name)
for field in fields(media)
},
conversation.Attachment(
media_content_id=attachment["media_content_id"],
url=media.url,
mime_type=media.mime_type,
path=media.path,
)
)
@ -99,7 +90,7 @@ class GenDataTask:
structure: vol.Schema | None = None
"""Optional structure for the data to be generated."""
attachments: list[PlayMediaWithId] | None = None
attachments: list[conversation.Attachment] | None = None
"""List of attachments to go along the instructions."""
def __str__(self) -> str:

View File

@ -34,6 +34,7 @@ from .agent_manager import (
from .chat_log import (
AssistantContent,
AssistantContentDeltaDict,
Attachment,
ChatLog,
Content,
ConverseError,
@ -66,6 +67,7 @@ __all__ = [
"HOME_ASSISTANT_AGENT",
"AssistantContent",
"AssistantContentDeltaDict",
"Attachment",
"ChatLog",
"Content",
"ConversationEntity",

View File

@ -8,6 +8,7 @@ from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import asdict, dataclass, field, replace
import logging
from pathlib import Path
from typing import Any, Literal, TypedDict
import voluptuous as vol
@ -136,6 +137,24 @@ class UserContent:
role: Literal["user"] = field(init=False, default="user")
content: str
attachments: list[Attachment] | None = field(default=None)
@dataclass(frozen=True)
class Attachment:
    """Attachment for a chat message.

    Frozen dataclass: instances cannot have their fields reassigned after
    construction. Every field is required because none declares a default.
    """

    # Identifier of the attachment as a media content ID string.
    media_content_id: str
    """Media content ID of the attachment."""

    # NOTE(review): presumably the resolved, playable URL for the media
    # content ID above -- confirm against the code that builds Attachment.
    url: str
    """URL of the attachment."""

    mime_type: str
    """MIME type of the attachment."""

    # Annotated as a plain Path (not "Path | None"), so the type promises a
    # filesystem path is always present.
    path: Path
    """Path to the attachment on disk."""
@dataclass(frozen=True)

View File

@ -48,7 +48,7 @@ class GoogleGenerativeAITaskEntity(
chat_log: conversation.ChatLog,
) -> ai_task.GenDataTaskResult:
"""Handle a generate data task."""
await self._async_handle_chat_log(chat_log, task.structure, task.attachments)
await self._async_handle_chat_log(chat_log, task.structure)
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
LOGGER.error(

View File

@ -30,7 +30,7 @@ from google.genai.types import (
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import ai_task, conversation
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigSubentry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
@ -338,7 +338,6 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
self,
chat_log: conversation.ChatLog,
structure: vol.Schema | None = None,
attachments: list[ai_task.PlayMediaWithId] | None = None,
) -> None:
"""Generate an answer for the chat log."""
options = self.subentry.data
@ -442,15 +441,11 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
user_message = chat_log.content[-1]
assert isinstance(user_message, conversation.UserContent)
chat_request: str | list[Part] = user_message.content
if attachments:
if any(a.path is None for a in attachments):
raise HomeAssistantError(
"Only local attachments are currently supported"
)
if user_message.attachments:
files = await async_prepare_files_for_prompt(
self.hass,
self._genai_client,
[a.path for a in attachments], # type: ignore[misc]
[a.path for a in user_message.attachments],
)
chat_request = [chat_request, *files]

View File

@ -9,6 +9,7 @@
'role': 'system',
}),
dict({
'attachments': None,
'content': 'Test prompt',
'role': 'user',
}),

View File

@ -1,5 +1,6 @@
"""Test initialization of the AI Task component."""
from pathlib import Path
from typing import Any
from unittest.mock import patch
@ -89,6 +90,7 @@ async def test_generate_data_service(
return_value=media_source.PlayMedia(
url="http://example.com/media.mp4",
mime_type="video/mp4",
path=Path("media.mp4"),
),
):
result = await hass.services.async_call(
@ -118,9 +120,7 @@ async def test_generate_data_service(
assert attachment.url == "http://example.com/media.mp4"
assert attachment.mime_type == "video/mp4"
assert attachment.media_content_id == msg_attachment["media_content_id"]
assert (
str(attachment) == f"<PlayMediaWithId {msg_attachment['media_content_id']}>"
)
assert attachment.path == Path("media.mp4")
async def test_generate_data_service_structure_fields(

View File

@ -12,6 +12,7 @@
'role': 'system',
}),
dict({
'attachments': None,
'content': 'Please call the test function',
'role': 'user',
}),

View File

@ -185,7 +185,7 @@ async def test_generate_data(
)
assert result.data == {"characters": ["Mario", "Luigi"]}
assert len(mock_chat_create.mock_calls) == 4
assert len(mock_chat_create.mock_calls) == 3
config = mock_chat_create.mock_calls[-1][2]["config"]
assert config.response_mime_type == "application/json"
assert config.response_schema == {

View File

@ -2,6 +2,7 @@
# name: test_function_call
list([
dict({
'attachments': None,
'content': 'Please call the test function',
'role': 'user',
}),
@ -58,6 +59,7 @@
# name: test_function_call_without_reasoning
list([
dict({
'attachments': None,
'content': 'Please call the test function',
'role': 'user',
}),