mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 15:47:12 +00:00
Make attachments native to chat log (#148693)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
f3ad6bd9b6
commit
23a8442abe
@ -33,7 +33,7 @@ from .const import (
|
|||||||
)
|
)
|
||||||
from .entity import AITaskEntity
|
from .entity import AITaskEntity
|
||||||
from .http import async_setup as async_setup_http
|
from .http import async_setup as async_setup_http
|
||||||
from .task import GenDataTask, GenDataTaskResult, PlayMediaWithId, async_generate_data
|
from .task import GenDataTask, GenDataTaskResult, async_generate_data
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
@ -41,7 +41,6 @@ __all__ = [
|
|||||||
"AITaskEntityFeature",
|
"AITaskEntityFeature",
|
||||||
"GenDataTask",
|
"GenDataTask",
|
||||||
"GenDataTaskResult",
|
"GenDataTaskResult",
|
||||||
"PlayMediaWithId",
|
|
||||||
"async_generate_data",
|
"async_generate_data",
|
||||||
"async_setup",
|
"async_setup",
|
||||||
"async_setup_entry",
|
"async_setup_entry",
|
||||||
|
@ -79,7 +79,9 @@ class AITaskEntity(RestoreEntity):
|
|||||||
user_llm_prompt=DEFAULT_SYSTEM_PROMPT,
|
user_llm_prompt=DEFAULT_SYSTEM_PROMPT,
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_log.async_add_user_content(UserContent(task.instructions))
|
chat_log.async_add_user_content(
|
||||||
|
UserContent(task.instructions, attachments=task.attachments)
|
||||||
|
)
|
||||||
|
|
||||||
yield chat_log
|
yield chat_log
|
||||||
|
|
||||||
|
@ -2,30 +2,18 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import media_source
|
from homeassistant.components import conversation, media_source
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
|
||||||
class PlayMediaWithId(media_source.PlayMedia):
|
|
||||||
"""Play media with a media content ID."""
|
|
||||||
|
|
||||||
media_content_id: str
|
|
||||||
"""Media source ID to play."""
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
"""Return media source ID as a string."""
|
|
||||||
return f"<PlayMediaWithId {self.media_content_id}>"
|
|
||||||
|
|
||||||
|
|
||||||
async def async_generate_data(
|
async def async_generate_data(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
*,
|
*,
|
||||||
@ -52,7 +40,7 @@ async def async_generate_data(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Resolve attachments
|
# Resolve attachments
|
||||||
resolved_attachments: list[PlayMediaWithId] | None = None
|
resolved_attachments: list[conversation.Attachment] | None = None
|
||||||
|
|
||||||
if attachments:
|
if attachments:
|
||||||
if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features:
|
if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features:
|
||||||
@ -66,13 +54,16 @@ async def async_generate_data(
|
|||||||
media = await media_source.async_resolve_media(
|
media = await media_source.async_resolve_media(
|
||||||
hass, attachment["media_content_id"], None
|
hass, attachment["media_content_id"], None
|
||||||
)
|
)
|
||||||
|
if media.path is None:
|
||||||
|
raise HomeAssistantError(
|
||||||
|
"Only local attachments are currently supported"
|
||||||
|
)
|
||||||
resolved_attachments.append(
|
resolved_attachments.append(
|
||||||
PlayMediaWithId(
|
conversation.Attachment(
|
||||||
**{
|
|
||||||
field.name: getattr(media, field.name)
|
|
||||||
for field in fields(media)
|
|
||||||
},
|
|
||||||
media_content_id=attachment["media_content_id"],
|
media_content_id=attachment["media_content_id"],
|
||||||
|
url=media.url,
|
||||||
|
mime_type=media.mime_type,
|
||||||
|
path=media.path,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -99,7 +90,7 @@ class GenDataTask:
|
|||||||
structure: vol.Schema | None = None
|
structure: vol.Schema | None = None
|
||||||
"""Optional structure for the data to be generated."""
|
"""Optional structure for the data to be generated."""
|
||||||
|
|
||||||
attachments: list[PlayMediaWithId] | None = None
|
attachments: list[conversation.Attachment] | None = None
|
||||||
"""List of attachments to go along the instructions."""
|
"""List of attachments to go along the instructions."""
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
@ -34,6 +34,7 @@ from .agent_manager import (
|
|||||||
from .chat_log import (
|
from .chat_log import (
|
||||||
AssistantContent,
|
AssistantContent,
|
||||||
AssistantContentDeltaDict,
|
AssistantContentDeltaDict,
|
||||||
|
Attachment,
|
||||||
ChatLog,
|
ChatLog,
|
||||||
Content,
|
Content,
|
||||||
ConverseError,
|
ConverseError,
|
||||||
@ -66,6 +67,7 @@ __all__ = [
|
|||||||
"HOME_ASSISTANT_AGENT",
|
"HOME_ASSISTANT_AGENT",
|
||||||
"AssistantContent",
|
"AssistantContent",
|
||||||
"AssistantContentDeltaDict",
|
"AssistantContentDeltaDict",
|
||||||
|
"Attachment",
|
||||||
"ChatLog",
|
"ChatLog",
|
||||||
"Content",
|
"Content",
|
||||||
"ConversationEntity",
|
"ConversationEntity",
|
||||||
|
@ -8,6 +8,7 @@ from contextlib import contextmanager
|
|||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from dataclasses import asdict, dataclass, field, replace
|
from dataclasses import asdict, dataclass, field, replace
|
||||||
import logging
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Literal, TypedDict
|
from typing import Any, Literal, TypedDict
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@ -136,6 +137,24 @@ class UserContent:
|
|||||||
|
|
||||||
role: Literal["user"] = field(init=False, default="user")
|
role: Literal["user"] = field(init=False, default="user")
|
||||||
content: str
|
content: str
|
||||||
|
attachments: list[Attachment] | None = field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Attachment:
|
||||||
|
"""Attachment for a chat message."""
|
||||||
|
|
||||||
|
media_content_id: str
|
||||||
|
"""Media content ID of the attachment."""
|
||||||
|
|
||||||
|
url: str
|
||||||
|
"""URL of the attachment."""
|
||||||
|
|
||||||
|
mime_type: str
|
||||||
|
"""MIME type of the attachment."""
|
||||||
|
|
||||||
|
path: Path
|
||||||
|
"""Path to the attachment on disk."""
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
@ -48,7 +48,7 @@ class GoogleGenerativeAITaskEntity(
|
|||||||
chat_log: conversation.ChatLog,
|
chat_log: conversation.ChatLog,
|
||||||
) -> ai_task.GenDataTaskResult:
|
) -> ai_task.GenDataTaskResult:
|
||||||
"""Handle a generate data task."""
|
"""Handle a generate data task."""
|
||||||
await self._async_handle_chat_log(chat_log, task.structure, task.attachments)
|
await self._async_handle_chat_log(chat_log, task.structure)
|
||||||
|
|
||||||
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
||||||
LOGGER.error(
|
LOGGER.error(
|
||||||
|
@ -30,7 +30,7 @@ from google.genai.types import (
|
|||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
from voluptuous_openapi import convert
|
from voluptuous_openapi import convert
|
||||||
|
|
||||||
from homeassistant.components import ai_task, conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.config_entries import ConfigSubentry
|
from homeassistant.config_entries import ConfigSubentry
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
@ -338,7 +338,6 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
|
|||||||
self,
|
self,
|
||||||
chat_log: conversation.ChatLog,
|
chat_log: conversation.ChatLog,
|
||||||
structure: vol.Schema | None = None,
|
structure: vol.Schema | None = None,
|
||||||
attachments: list[ai_task.PlayMediaWithId] | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Generate an answer for the chat log."""
|
"""Generate an answer for the chat log."""
|
||||||
options = self.subentry.data
|
options = self.subentry.data
|
||||||
@ -442,15 +441,11 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
|
|||||||
user_message = chat_log.content[-1]
|
user_message = chat_log.content[-1]
|
||||||
assert isinstance(user_message, conversation.UserContent)
|
assert isinstance(user_message, conversation.UserContent)
|
||||||
chat_request: str | list[Part] = user_message.content
|
chat_request: str | list[Part] = user_message.content
|
||||||
if attachments:
|
if user_message.attachments:
|
||||||
if any(a.path is None for a in attachments):
|
|
||||||
raise HomeAssistantError(
|
|
||||||
"Only local attachments are currently supported"
|
|
||||||
)
|
|
||||||
files = await async_prepare_files_for_prompt(
|
files = await async_prepare_files_for_prompt(
|
||||||
self.hass,
|
self.hass,
|
||||||
self._genai_client,
|
self._genai_client,
|
||||||
[a.path for a in attachments], # type: ignore[misc]
|
[a.path for a in user_message.attachments],
|
||||||
)
|
)
|
||||||
chat_request = [chat_request, *files]
|
chat_request = [chat_request, *files]
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
'role': 'system',
|
'role': 'system',
|
||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
|
'attachments': None,
|
||||||
'content': 'Test prompt',
|
'content': 'Test prompt',
|
||||||
'role': 'user',
|
'role': 'user',
|
||||||
}),
|
}),
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Test initialization of the AI Task component."""
|
"""Test initialization of the AI Task component."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
@ -89,6 +90,7 @@ async def test_generate_data_service(
|
|||||||
return_value=media_source.PlayMedia(
|
return_value=media_source.PlayMedia(
|
||||||
url="http://example.com/media.mp4",
|
url="http://example.com/media.mp4",
|
||||||
mime_type="video/mp4",
|
mime_type="video/mp4",
|
||||||
|
path=Path("media.mp4"),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
result = await hass.services.async_call(
|
result = await hass.services.async_call(
|
||||||
@ -118,9 +120,7 @@ async def test_generate_data_service(
|
|||||||
assert attachment.url == "http://example.com/media.mp4"
|
assert attachment.url == "http://example.com/media.mp4"
|
||||||
assert attachment.mime_type == "video/mp4"
|
assert attachment.mime_type == "video/mp4"
|
||||||
assert attachment.media_content_id == msg_attachment["media_content_id"]
|
assert attachment.media_content_id == msg_attachment["media_content_id"]
|
||||||
assert (
|
assert attachment.path == Path("media.mp4")
|
||||||
str(attachment) == f"<PlayMediaWithId {msg_attachment['media_content_id']}>"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_generate_data_service_structure_fields(
|
async def test_generate_data_service_structure_fields(
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
'role': 'system',
|
'role': 'system',
|
||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
|
'attachments': None,
|
||||||
'content': 'Please call the test function',
|
'content': 'Please call the test function',
|
||||||
'role': 'user',
|
'role': 'user',
|
||||||
}),
|
}),
|
||||||
|
@ -185,7 +185,7 @@ async def test_generate_data(
|
|||||||
)
|
)
|
||||||
assert result.data == {"characters": ["Mario", "Luigi"]}
|
assert result.data == {"characters": ["Mario", "Luigi"]}
|
||||||
|
|
||||||
assert len(mock_chat_create.mock_calls) == 4
|
assert len(mock_chat_create.mock_calls) == 3
|
||||||
config = mock_chat_create.mock_calls[-1][2]["config"]
|
config = mock_chat_create.mock_calls[-1][2]["config"]
|
||||||
assert config.response_mime_type == "application/json"
|
assert config.response_mime_type == "application/json"
|
||||||
assert config.response_schema == {
|
assert config.response_schema == {
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
# name: test_function_call
|
# name: test_function_call
|
||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
|
'attachments': None,
|
||||||
'content': 'Please call the test function',
|
'content': 'Please call the test function',
|
||||||
'role': 'user',
|
'role': 'user',
|
||||||
}),
|
}),
|
||||||
@ -58,6 +59,7 @@
|
|||||||
# name: test_function_call_without_reasoning
|
# name: test_function_call_without_reasoning
|
||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
|
'attachments': None,
|
||||||
'content': 'Please call the test function',
|
'content': 'Please call the test function',
|
||||||
'role': 'user',
|
'role': 'user',
|
||||||
}),
|
}),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user