Add attachment support to Google Gemini (#148208)

This commit is contained in:
Paulus Schoutsen 2025-07-10 23:45:11 +02:00 committed by GitHub
parent a2220cc2e6
commit 19b3b6cb28
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 118 additions and 12 deletions

View File

@ -37,7 +37,10 @@ class GoogleGenerativeAITaskEntity(
):
"""Google Generative AI AI Task entity."""
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
_attr_supported_features = (
ai_task.AITaskEntityFeature.GENERATE_DATA
| ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
)
async def _async_generate_data(
self,
@ -45,7 +48,7 @@ class GoogleGenerativeAITaskEntity(
chat_log: conversation.ChatLog,
) -> ai_task.GenDataTaskResult:
"""Handle a generate data task."""
await self._async_handle_chat_log(chat_log, task.structure)
await self._async_handle_chat_log(chat_log, task.structure, task.attachments)
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
LOGGER.error(

View File

@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator, Callable
from dataclasses import replace
import mimetypes
from pathlib import Path
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast
from google.genai import Client
from google.genai.errors import APIError, ClientError
@ -30,8 +30,8 @@ from google.genai.types import (
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigSubentry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, llm
@ -60,6 +60,9 @@ from .const import (
TIMEOUT_MILLIS,
)
if TYPE_CHECKING:
from . import GoogleGenerativeAIConfigEntry
# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
@ -313,7 +316,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
def __init__(
self,
entry: ConfigEntry,
entry: GoogleGenerativeAIConfigEntry,
subentry: ConfigSubentry,
default_model: str = RECOMMENDED_CHAT_MODEL,
) -> None:
@ -335,6 +338,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
self,
chat_log: conversation.ChatLog,
structure: vol.Schema | None = None,
attachments: list[ai_task.PlayMediaWithId] | None = None,
) -> None:
"""Generate an answer for the chat log."""
options = self.subentry.data
@ -438,6 +442,18 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
user_message = chat_log.content[-1]
assert isinstance(user_message, conversation.UserContent)
chat_request: str | list[Part] = user_message.content
if attachments:
if any(a.path is None for a in attachments):
raise HomeAssistantError(
"Only local attachments are currently supported"
)
files = await async_prepare_files_for_prompt(
self.hass,
self._genai_client,
[a.path for a in attachments], # type: ignore[misc]
)
chat_request = [chat_request, *files]
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
@ -508,7 +524,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
async def async_prepare_files_for_prompt(
hass: HomeAssistant, client: Client, files: list[Path]
) -> list[File]:
"""Append files to a prompt.
"""Upload files so they can be attached to a prompt.
Caller needs to ensure that the files are allowed.
"""

View File

@ -1,12 +1,13 @@
"""Test AI Task platform of Google Generative AI Conversation integration."""
from unittest.mock import AsyncMock
from pathlib import Path
from unittest.mock import AsyncMock, patch
from google.genai.types import GenerateContentResponse
from google.genai.types import File, FileState, GenerateContentResponse
import pytest
import voluptuous as vol
from homeassistant.components import ai_task
from homeassistant.components import ai_task, media_source
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry as er, selector
@ -64,6 +65,93 @@ async def test_generate_data(
)
assert result.data == "Hi there!"
# Test with attachments
mock_send_message_stream.return_value = [
[
GenerateContentResponse(
candidates=[
{
"content": {
"parts": [{"text": "Hi there!"}],
"role": "model",
},
}
],
),
],
]
file1 = File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE)
file2 = File(name="context.txt", state=FileState.ACTIVE)
with (
patch(
"homeassistant.components.media_source.async_resolve_media",
side_effect=[
media_source.PlayMedia(
url="http://example.com/doorbell_snapshot.jpg",
mime_type="image/jpeg",
path=Path("doorbell_snapshot.jpg"),
),
media_source.PlayMedia(
url="http://example.com/context.txt",
mime_type="text/plain",
path=Path("context.txt"),
),
],
),
patch(
"google.genai.files.Files.upload",
side_effect=[file1, file2],
) as mock_upload,
patch("pathlib.Path.exists", return_value=True),
patch.object(hass.config, "is_allowed_path", return_value=True),
patch("mimetypes.guess_type", return_value=["image/jpeg"]),
):
result = await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Test prompt",
attachments=[
{"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
{"media_content_id": "media-source://media/context.txt"},
],
)
outgoing_message = mock_send_message_stream.mock_calls[1][2]["message"]
assert outgoing_message == ["Test prompt", file1, file2]
assert result.data == "Hi there!"
assert len(mock_upload.mock_calls) == 2
assert mock_upload.mock_calls[0][2]["file"] == Path("doorbell_snapshot.jpg")
assert mock_upload.mock_calls[1][2]["file"] == Path("context.txt")
# Test attachments require play media with a path
with (
patch(
"homeassistant.components.media_source.async_resolve_media",
side_effect=[
media_source.PlayMedia(
url="http://example.com/doorbell_snapshot.jpg",
mime_type="image/jpeg",
path=None,
),
],
),
pytest.raises(
HomeAssistantError, match="Only local attachments are currently supported"
),
):
result = await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Test prompt",
attachments=[
{"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
],
)
# Test with structure
mock_send_message_stream.return_value = [
[
GenerateContentResponse(
@ -97,7 +185,7 @@ async def test_generate_data(
)
assert result.data == {"characters": ["Mario", "Luigi"]}
assert len(mock_chat_create.mock_calls) == 2
assert len(mock_chat_create.mock_calls) == 4
config = mock_chat_create.mock_calls[-1][2]["config"]
assert config.response_mime_type == "application/json"
assert config.response_schema == {

View File

@ -87,7 +87,6 @@ async def test_generate_content_service_with_image(
),
patch("pathlib.Path.exists", return_value=True),
patch.object(hass.config, "is_allowed_path", return_value=True),
patch("builtins.open", mock_open(read_data="this is an image")),
patch("mimetypes.guess_type", return_value=["image/jpeg"]),
):
response = await hass.services.async_call(