mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 07:37:34 +00:00
Make attachments native to chat log (#148693)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
f3ad6bd9b6
commit
23a8442abe
@ -33,7 +33,7 @@ from .const import (
|
||||
)
|
||||
from .entity import AITaskEntity
|
||||
from .http import async_setup as async_setup_http
|
||||
from .task import GenDataTask, GenDataTaskResult, PlayMediaWithId, async_generate_data
|
||||
from .task import GenDataTask, GenDataTaskResult, async_generate_data
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
@ -41,7 +41,6 @@ __all__ = [
|
||||
"AITaskEntityFeature",
|
||||
"GenDataTask",
|
||||
"GenDataTaskResult",
|
||||
"PlayMediaWithId",
|
||||
"async_generate_data",
|
||||
"async_setup",
|
||||
"async_setup_entry",
|
||||
|
@ -79,7 +79,9 @@ class AITaskEntity(RestoreEntity):
|
||||
user_llm_prompt=DEFAULT_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
chat_log.async_add_user_content(UserContent(task.instructions))
|
||||
chat_log.async_add_user_content(
|
||||
UserContent(task.instructions, attachments=task.attachments)
|
||||
)
|
||||
|
||||
yield chat_log
|
||||
|
||||
|
@ -2,30 +2,18 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import media_source
|
||||
from homeassistant.components import conversation, media_source
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PlayMediaWithId(media_source.PlayMedia):
|
||||
"""Play media with a media content ID."""
|
||||
|
||||
media_content_id: str
|
||||
"""Media source ID to play."""
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return media source ID as a string."""
|
||||
return f"<PlayMediaWithId {self.media_content_id}>"
|
||||
|
||||
|
||||
async def async_generate_data(
|
||||
hass: HomeAssistant,
|
||||
*,
|
||||
@ -52,7 +40,7 @@ async def async_generate_data(
|
||||
)
|
||||
|
||||
# Resolve attachments
|
||||
resolved_attachments: list[PlayMediaWithId] | None = None
|
||||
resolved_attachments: list[conversation.Attachment] | None = None
|
||||
|
||||
if attachments:
|
||||
if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features:
|
||||
@ -66,13 +54,16 @@ async def async_generate_data(
|
||||
media = await media_source.async_resolve_media(
|
||||
hass, attachment["media_content_id"], None
|
||||
)
|
||||
if media.path is None:
|
||||
raise HomeAssistantError(
|
||||
"Only local attachments are currently supported"
|
||||
)
|
||||
resolved_attachments.append(
|
||||
PlayMediaWithId(
|
||||
**{
|
||||
field.name: getattr(media, field.name)
|
||||
for field in fields(media)
|
||||
},
|
||||
conversation.Attachment(
|
||||
media_content_id=attachment["media_content_id"],
|
||||
url=media.url,
|
||||
mime_type=media.mime_type,
|
||||
path=media.path,
|
||||
)
|
||||
)
|
||||
|
||||
@ -99,7 +90,7 @@ class GenDataTask:
|
||||
structure: vol.Schema | None = None
|
||||
"""Optional structure for the data to be generated."""
|
||||
|
||||
attachments: list[PlayMediaWithId] | None = None
|
||||
attachments: list[conversation.Attachment] | None = None
|
||||
"""List of attachments to go along the instructions."""
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
@ -34,6 +34,7 @@ from .agent_manager import (
|
||||
from .chat_log import (
|
||||
AssistantContent,
|
||||
AssistantContentDeltaDict,
|
||||
Attachment,
|
||||
ChatLog,
|
||||
Content,
|
||||
ConverseError,
|
||||
@ -66,6 +67,7 @@ __all__ = [
|
||||
"HOME_ASSISTANT_AGENT",
|
||||
"AssistantContent",
|
||||
"AssistantContentDeltaDict",
|
||||
"Attachment",
|
||||
"ChatLog",
|
||||
"Content",
|
||||
"ConversationEntity",
|
||||
|
@ -8,6 +8,7 @@ from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import asdict, dataclass, field, replace
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
import voluptuous as vol
|
||||
@ -136,6 +137,24 @@ class UserContent:
|
||||
|
||||
role: Literal["user"] = field(init=False, default="user")
|
||||
content: str
|
||||
attachments: list[Attachment] | None = field(default=None)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Attachment:
|
||||
"""Attachment for a chat message."""
|
||||
|
||||
media_content_id: str
|
||||
"""Media content ID of the attachment."""
|
||||
|
||||
url: str
|
||||
"""URL of the attachment."""
|
||||
|
||||
mime_type: str
|
||||
"""MIME type of the attachment."""
|
||||
|
||||
path: Path
|
||||
"""Path to the attachment on disk."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -48,7 +48,7 @@ class GoogleGenerativeAITaskEntity(
|
||||
chat_log: conversation.ChatLog,
|
||||
) -> ai_task.GenDataTaskResult:
|
||||
"""Handle a generate data task."""
|
||||
await self._async_handle_chat_log(chat_log, task.structure, task.attachments)
|
||||
await self._async_handle_chat_log(chat_log, task.structure)
|
||||
|
||||
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
||||
LOGGER.error(
|
||||
|
@ -30,7 +30,7 @@ from google.genai.types import (
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import ai_task, conversation
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigSubentry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
@ -338,7 +338,6 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
|
||||
self,
|
||||
chat_log: conversation.ChatLog,
|
||||
structure: vol.Schema | None = None,
|
||||
attachments: list[ai_task.PlayMediaWithId] | None = None,
|
||||
) -> None:
|
||||
"""Generate an answer for the chat log."""
|
||||
options = self.subentry.data
|
||||
@ -442,15 +441,11 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
|
||||
user_message = chat_log.content[-1]
|
||||
assert isinstance(user_message, conversation.UserContent)
|
||||
chat_request: str | list[Part] = user_message.content
|
||||
if attachments:
|
||||
if any(a.path is None for a in attachments):
|
||||
raise HomeAssistantError(
|
||||
"Only local attachments are currently supported"
|
||||
)
|
||||
if user_message.attachments:
|
||||
files = await async_prepare_files_for_prompt(
|
||||
self.hass,
|
||||
self._genai_client,
|
||||
[a.path for a in attachments], # type: ignore[misc]
|
||||
[a.path for a in user_message.attachments],
|
||||
)
|
||||
chat_request = [chat_request, *files]
|
||||
|
||||
|
@ -9,6 +9,7 @@
|
||||
'role': 'system',
|
||||
}),
|
||||
dict({
|
||||
'attachments': None,
|
||||
'content': 'Test prompt',
|
||||
'role': 'user',
|
||||
}),
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Test initialization of the AI Task component."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -89,6 +90,7 @@ async def test_generate_data_service(
|
||||
return_value=media_source.PlayMedia(
|
||||
url="http://example.com/media.mp4",
|
||||
mime_type="video/mp4",
|
||||
path=Path("media.mp4"),
|
||||
),
|
||||
):
|
||||
result = await hass.services.async_call(
|
||||
@ -118,9 +120,7 @@ async def test_generate_data_service(
|
||||
assert attachment.url == "http://example.com/media.mp4"
|
||||
assert attachment.mime_type == "video/mp4"
|
||||
assert attachment.media_content_id == msg_attachment["media_content_id"]
|
||||
assert (
|
||||
str(attachment) == f"<PlayMediaWithId {msg_attachment['media_content_id']}>"
|
||||
)
|
||||
assert attachment.path == Path("media.mp4")
|
||||
|
||||
|
||||
async def test_generate_data_service_structure_fields(
|
||||
|
@ -12,6 +12,7 @@
|
||||
'role': 'system',
|
||||
}),
|
||||
dict({
|
||||
'attachments': None,
|
||||
'content': 'Please call the test function',
|
||||
'role': 'user',
|
||||
}),
|
||||
|
@ -185,7 +185,7 @@ async def test_generate_data(
|
||||
)
|
||||
assert result.data == {"characters": ["Mario", "Luigi"]}
|
||||
|
||||
assert len(mock_chat_create.mock_calls) == 4
|
||||
assert len(mock_chat_create.mock_calls) == 3
|
||||
config = mock_chat_create.mock_calls[-1][2]["config"]
|
||||
assert config.response_mime_type == "application/json"
|
||||
assert config.response_schema == {
|
||||
|
@ -2,6 +2,7 @@
|
||||
# name: test_function_call
|
||||
list([
|
||||
dict({
|
||||
'attachments': None,
|
||||
'content': 'Please call the test function',
|
||||
'role': 'user',
|
||||
}),
|
||||
@ -58,6 +59,7 @@
|
||||
# name: test_function_call_without_reasoning
|
||||
list([
|
||||
dict({
|
||||
'attachments': None,
|
||||
'content': 'Please call the test function',
|
||||
'role': 'user',
|
||||
}),
|
||||
|
Loading…
x
Reference in New Issue
Block a user