mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Add attachment support to Google Gemini (#148208)
This commit is contained in:
parent
a2220cc2e6
commit
19b3b6cb28
@ -37,7 +37,10 @@ class GoogleGenerativeAITaskEntity(
|
||||
):
|
||||
"""Google Generative AI AI Task entity."""
|
||||
|
||||
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||
_attr_supported_features = (
|
||||
ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||
| ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
|
||||
)
|
||||
|
||||
async def _async_generate_data(
|
||||
self,
|
||||
@ -45,7 +48,7 @@ class GoogleGenerativeAITaskEntity(
|
||||
chat_log: conversation.ChatLog,
|
||||
) -> ai_task.GenDataTaskResult:
|
||||
"""Handle a generate data task."""
|
||||
await self._async_handle_chat_log(chat_log, task.structure)
|
||||
await self._async_handle_chat_log(chat_log, task.structure, task.attachments)
|
||||
|
||||
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
||||
LOGGER.error(
|
||||
|
@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator, Callable
|
||||
from dataclasses import replace
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from google.genai import Client
|
||||
from google.genai.errors import APIError, ClientError
|
||||
@ -30,8 +30,8 @@ from google.genai.types import (
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
||||
from homeassistant.components import ai_task, conversation
|
||||
from homeassistant.config_entries import ConfigSubentry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import device_registry as dr, llm
|
||||
@ -60,6 +60,9 @@ from .const import (
|
||||
TIMEOUT_MILLIS,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import GoogleGenerativeAIConfigEntry
|
||||
|
||||
# Max number of back and forth with the LLM to generate a response
|
||||
MAX_TOOL_ITERATIONS = 10
|
||||
|
||||
@ -313,7 +316,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entry: ConfigEntry,
|
||||
entry: GoogleGenerativeAIConfigEntry,
|
||||
subentry: ConfigSubentry,
|
||||
default_model: str = RECOMMENDED_CHAT_MODEL,
|
||||
) -> None:
|
||||
@ -335,6 +338,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
|
||||
self,
|
||||
chat_log: conversation.ChatLog,
|
||||
structure: vol.Schema | None = None,
|
||||
attachments: list[ai_task.PlayMediaWithId] | None = None,
|
||||
) -> None:
|
||||
"""Generate an answer for the chat log."""
|
||||
options = self.subentry.data
|
||||
@ -438,6 +442,18 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
|
||||
user_message = chat_log.content[-1]
|
||||
assert isinstance(user_message, conversation.UserContent)
|
||||
chat_request: str | list[Part] = user_message.content
|
||||
if attachments:
|
||||
if any(a.path is None for a in attachments):
|
||||
raise HomeAssistantError(
|
||||
"Only local attachments are currently supported"
|
||||
)
|
||||
files = await async_prepare_files_for_prompt(
|
||||
self.hass,
|
||||
self._genai_client,
|
||||
[a.path for a in attachments], # type: ignore[misc]
|
||||
)
|
||||
chat_request = [chat_request, *files]
|
||||
|
||||
# To prevent infinite loops, we limit the number of iterations
|
||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||
try:
|
||||
@ -508,7 +524,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
|
||||
async def async_prepare_files_for_prompt(
|
||||
hass: HomeAssistant, client: Client, files: list[Path]
|
||||
) -> list[File]:
|
||||
"""Append files to a prompt.
|
||||
"""Upload files so they can be attached to a prompt.
|
||||
|
||||
Caller needs to ensure that the files are allowed.
|
||||
"""
|
||||
|
@ -1,12 +1,13 @@
|
||||
"""Test AI Task platform of Google Generative AI Conversation integration."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from google.genai.types import GenerateContentResponse
|
||||
from google.genai.types import File, FileState, GenerateContentResponse
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import ai_task
|
||||
from homeassistant.components import ai_task, media_source
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import entity_registry as er, selector
|
||||
@ -64,6 +65,93 @@ async def test_generate_data(
|
||||
)
|
||||
assert result.data == "Hi there!"
|
||||
|
||||
# Test with attachments
|
||||
mock_send_message_stream.return_value = [
|
||||
[
|
||||
GenerateContentResponse(
|
||||
candidates=[
|
||||
{
|
||||
"content": {
|
||||
"parts": [{"text": "Hi there!"}],
|
||||
"role": "model",
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
]
|
||||
file1 = File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE)
|
||||
file2 = File(name="context.txt", state=FileState.ACTIVE)
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.media_source.async_resolve_media",
|
||||
side_effect=[
|
||||
media_source.PlayMedia(
|
||||
url="http://example.com/doorbell_snapshot.jpg",
|
||||
mime_type="image/jpeg",
|
||||
path=Path("doorbell_snapshot.jpg"),
|
||||
),
|
||||
media_source.PlayMedia(
|
||||
url="http://example.com/context.txt",
|
||||
mime_type="text/plain",
|
||||
path=Path("context.txt"),
|
||||
),
|
||||
],
|
||||
),
|
||||
patch(
|
||||
"google.genai.files.Files.upload",
|
||||
side_effect=[file1, file2],
|
||||
) as mock_upload,
|
||||
patch("pathlib.Path.exists", return_value=True),
|
||||
patch.object(hass.config, "is_allowed_path", return_value=True),
|
||||
patch("mimetypes.guess_type", return_value=["image/jpeg"]),
|
||||
):
|
||||
result = await ai_task.async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=entity_id,
|
||||
instructions="Test prompt",
|
||||
attachments=[
|
||||
{"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
|
||||
{"media_content_id": "media-source://media/context.txt"},
|
||||
],
|
||||
)
|
||||
|
||||
outgoing_message = mock_send_message_stream.mock_calls[1][2]["message"]
|
||||
assert outgoing_message == ["Test prompt", file1, file2]
|
||||
|
||||
assert result.data == "Hi there!"
|
||||
assert len(mock_upload.mock_calls) == 2
|
||||
assert mock_upload.mock_calls[0][2]["file"] == Path("doorbell_snapshot.jpg")
|
||||
assert mock_upload.mock_calls[1][2]["file"] == Path("context.txt")
|
||||
|
||||
# Test attachments require play media with a path
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.media_source.async_resolve_media",
|
||||
side_effect=[
|
||||
media_source.PlayMedia(
|
||||
url="http://example.com/doorbell_snapshot.jpg",
|
||||
mime_type="image/jpeg",
|
||||
path=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
pytest.raises(
|
||||
HomeAssistantError, match="Only local attachments are currently supported"
|
||||
),
|
||||
):
|
||||
result = await ai_task.async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=entity_id,
|
||||
instructions="Test prompt",
|
||||
attachments=[
|
||||
{"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
|
||||
],
|
||||
)
|
||||
|
||||
# Test with structure
|
||||
mock_send_message_stream.return_value = [
|
||||
[
|
||||
GenerateContentResponse(
|
||||
@ -97,7 +185,7 @@ async def test_generate_data(
|
||||
)
|
||||
assert result.data == {"characters": ["Mario", "Luigi"]}
|
||||
|
||||
assert len(mock_chat_create.mock_calls) == 2
|
||||
assert len(mock_chat_create.mock_calls) == 4
|
||||
config = mock_chat_create.mock_calls[-1][2]["config"]
|
||||
assert config.response_mime_type == "application/json"
|
||||
assert config.response_schema == {
|
||||
|
@ -87,7 +87,6 @@ async def test_generate_content_service_with_image(
|
||||
),
|
||||
patch("pathlib.Path.exists", return_value=True),
|
||||
patch.object(hass.config, "is_allowed_path", return_value=True),
|
||||
patch("builtins.open", mock_open(read_data="this is an image")),
|
||||
patch("mimetypes.guess_type", return_value=["image/jpeg"]),
|
||||
):
|
||||
response = await hass.services.async_call(
|
||||
|
Loading…
x
Reference in New Issue
Block a user