Mirror of https://github.com/home-assistant/core.git, synced 2025-07-23 21:27:38 +00:00

Commit 19b3b6cb28 (parent a2220cc2e6): Add attachment support to Google Gemini (#148208)
@@ -37,7 +37,10 @@ class GoogleGenerativeAITaskEntity(
 ):
     """Google Generative AI AI Task entity."""
 
-    _attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
+    _attr_supported_features = (
+        ai_task.AITaskEntityFeature.GENERATE_DATA
+        | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
+    )
 
     async def _async_generate_data(
         self,
@@ -45,7 +48,7 @@ class GoogleGenerativeAITaskEntity(
         chat_log: conversation.ChatLog,
     ) -> ai_task.GenDataTaskResult:
         """Handle a generate data task."""
-        await self._async_handle_chat_log(chat_log, task.structure)
+        await self._async_handle_chat_log(chat_log, task.structure, task.attachments)
 
         if not isinstance(chat_log.content[-1], conversation.AssistantContent):
            LOGGER.error(
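Note: declaring SUPPORT_ATTACHMENTS is what lets callers hand media to this task entity. As a caller-side illustration only (the entity_id and media-source IDs below are placeholders; the helper signature is the one exercised by the tests added further down), generating data with an attachment looks roughly like this:

# Caller-side sketch; entity_id and media-source IDs are illustrative
# placeholders, and `hass` is a running HomeAssistant instance.
from homeassistant.components import ai_task
from homeassistant.core import HomeAssistant


async def describe_snapshot(hass: HomeAssistant) -> None:
    result = await ai_task.async_generate_data(
        hass,
        task_name="Describe the doorbell snapshot",
        entity_id="ai_task.google_generative_ai",
        instructions="What is in front of the door?",
        attachments=[
            {"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
        ],
    )
    print(result.data)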
@@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator, Callable
 from dataclasses import replace
 import mimetypes
 from pathlib import Path
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any, cast
 
 from google.genai import Client
 from google.genai.errors import APIError, ClientError
@@ -30,8 +30,8 @@ from google.genai.types import (
 import voluptuous as vol
 from voluptuous_openapi import convert
 
-from homeassistant.components import conversation
-from homeassistant.config_entries import ConfigEntry, ConfigSubentry
+from homeassistant.components import ai_task, conversation
+from homeassistant.config_entries import ConfigSubentry
 from homeassistant.core import HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import device_registry as dr, llm
@@ -60,6 +60,9 @@ from .const import (
     TIMEOUT_MILLIS,
 )
 
+if TYPE_CHECKING:
+    from . import GoogleGenerativeAIConfigEntry
+
 # Max number of back and forth with the LLM to generate a response
 MAX_TOOL_ITERATIONS = 10
 
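The TYPE_CHECKING guard makes the GoogleGenerativeAIConfigEntry import visible to type checkers only; the name is needed purely for annotations, and keeping it out of the runtime import graph is the standard way to sidestep potential circular imports. A generic sketch of the pattern, with placeholder module and class names:

# Generic illustration of the TYPE_CHECKING pattern; module and class names
# here are placeholders, not part of this integration.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers (mypy, pyright), never at runtime.
    from my_package import MyConfigEntry


def describe(entry: MyConfigEntry) -> str:
    # With `from __future__ import annotations`, the annotation above stays a
    # string at runtime, so the module-level import is never needed.
    return f"entry {entry!r}"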
@@ -313,7 +316,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
 
     def __init__(
         self,
-        entry: ConfigEntry,
+        entry: GoogleGenerativeAIConfigEntry,
         subentry: ConfigSubentry,
         default_model: str = RECOMMENDED_CHAT_MODEL,
     ) -> None:
@@ -335,6 +338,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
         self,
         chat_log: conversation.ChatLog,
         structure: vol.Schema | None = None,
+        attachments: list[ai_task.PlayMediaWithId] | None = None,
     ) -> None:
         """Generate an answer for the chat log."""
         options = self.subentry.data
@@ -438,6 +442,18 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
         user_message = chat_log.content[-1]
         assert isinstance(user_message, conversation.UserContent)
         chat_request: str | list[Part] = user_message.content
+        if attachments:
+            if any(a.path is None for a in attachments):
+                raise HomeAssistantError(
+                    "Only local attachments are currently supported"
+                )
+            files = await async_prepare_files_for_prompt(
+                self.hass,
+                self._genai_client,
+                [a.path for a in attachments],  # type: ignore[misc]
+            )
+            chat_request = [chat_request, *files]
 
         # To prevent infinite loops, we limit the number of iterations
         for _iteration in range(MAX_TOOL_ITERATIONS):
             try:
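The net effect on the request payload: when attachments are present, the first message sent to the model is a list containing the prompt text followed by the uploaded File handles. A compressed, self-contained illustration using the same values the new test asserts on:

# Illustration of the resulting message shape; values mirror the new test,
# which asserts the outgoing message equals ["Test prompt", file1, file2].
from google.genai.types import File, FileState

file1 = File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE)
file2 = File(name="context.txt", state=FileState.ACTIVE)

chat_request: str | list = "Test prompt"        # user_message.content
chat_request = [chat_request, *[file1, file2]]  # prompt text plus File handles
assert chat_request == ["Test prompt", file1, file2]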
@@ -508,7 +524,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
 async def async_prepare_files_for_prompt(
     hass: HomeAssistant, client: Client, files: list[Path]
 ) -> list[File]:
-    """Append files to a prompt.
+    """Upload files so they can be attached to a prompt.
 
     Caller needs to ensure that the files are allowed.
     """
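For readers unfamiliar with this helper, a much-simplified sketch of an upload routine of this shape is below. It is an assumption-laden illustration, not the integration's implementation: it only shows the blocking client.files.upload(file=...) call (the one the tests patch) being pushed to an executor thread, and skips concerns such as mime-type detection and waiting for the uploaded file to finish processing.

# Simplified sketch only -- not the integration's actual helper. It uploads
# each local file with the synchronous google-genai Files API from an
# executor thread so the event loop is not blocked.
from functools import partial
from pathlib import Path

from google.genai import Client
from google.genai.types import File

from homeassistant.core import HomeAssistant


async def upload_files_sketch(
    hass: HomeAssistant, client: Client, files: list[Path]
) -> list[File]:
    """Upload local files and return the resulting File handles."""
    uploaded: list[File] = []
    for path in files:
        # client.files.upload is blocking, so run it in the executor.
        uploaded.append(
            await hass.async_add_executor_job(partial(client.files.upload, file=path))
        )
    return uploaded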
@@ -1,12 +1,13 @@
 """Test AI Task platform of Google Generative AI Conversation integration."""
 
-from unittest.mock import AsyncMock
+from pathlib import Path
+from unittest.mock import AsyncMock, patch
 
-from google.genai.types import GenerateContentResponse
+from google.genai.types import File, FileState, GenerateContentResponse
 import pytest
 import voluptuous as vol
 
-from homeassistant.components import ai_task
+from homeassistant.components import ai_task, media_source
 from homeassistant.core import HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import entity_registry as er, selector
@@ -64,6 +65,93 @@ async def test_generate_data(
     )
     assert result.data == "Hi there!"
 
+    # Test with attachments
+    mock_send_message_stream.return_value = [
+        [
+            GenerateContentResponse(
+                candidates=[
+                    {
+                        "content": {
+                            "parts": [{"text": "Hi there!"}],
+                            "role": "model",
+                        },
+                    }
+                ],
+            ),
+        ],
+    ]
+    file1 = File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE)
+    file2 = File(name="context.txt", state=FileState.ACTIVE)
+    with (
+        patch(
+            "homeassistant.components.media_source.async_resolve_media",
+            side_effect=[
+                media_source.PlayMedia(
+                    url="http://example.com/doorbell_snapshot.jpg",
+                    mime_type="image/jpeg",
+                    path=Path("doorbell_snapshot.jpg"),
+                ),
+                media_source.PlayMedia(
+                    url="http://example.com/context.txt",
+                    mime_type="text/plain",
+                    path=Path("context.txt"),
+                ),
+            ],
+        ),
+        patch(
+            "google.genai.files.Files.upload",
+            side_effect=[file1, file2],
+        ) as mock_upload,
+        patch("pathlib.Path.exists", return_value=True),
+        patch.object(hass.config, "is_allowed_path", return_value=True),
+        patch("mimetypes.guess_type", return_value=["image/jpeg"]),
+    ):
+        result = await ai_task.async_generate_data(
+            hass,
+            task_name="Test Task",
+            entity_id=entity_id,
+            instructions="Test prompt",
+            attachments=[
+                {"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
+                {"media_content_id": "media-source://media/context.txt"},
+            ],
+        )
+
+    outgoing_message = mock_send_message_stream.mock_calls[1][2]["message"]
+    assert outgoing_message == ["Test prompt", file1, file2]
+
+    assert result.data == "Hi there!"
+    assert len(mock_upload.mock_calls) == 2
+    assert mock_upload.mock_calls[0][2]["file"] == Path("doorbell_snapshot.jpg")
+    assert mock_upload.mock_calls[1][2]["file"] == Path("context.txt")
+
+    # Test attachments require play media with a path
+    with (
+        patch(
+            "homeassistant.components.media_source.async_resolve_media",
+            side_effect=[
+                media_source.PlayMedia(
+                    url="http://example.com/doorbell_snapshot.jpg",
+                    mime_type="image/jpeg",
+                    path=None,
+                ),
+            ],
+        ),
+        pytest.raises(
+            HomeAssistantError, match="Only local attachments are currently supported"
+        ),
+    ):
+        result = await ai_task.async_generate_data(
+            hass,
+            task_name="Test Task",
+            entity_id=entity_id,
+            instructions="Test prompt",
+            attachments=[
+                {"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
+            ],
+        )
+
+    # Test with structure
     mock_send_message_stream.return_value = [
         [
             GenerateContentResponse(
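A short note on the assertion style above: mock.mock_calls[i] is a call object, and indexing it with [2] yields the keyword arguments of that call, so mock_upload.mock_calls[0][2]["file"] is the file= argument of the first upload. An equivalent, arguably more readable spelling:

# Equivalent ways to read back the keyword arguments of a recorded call.
from pathlib import Path
from unittest.mock import MagicMock

mock_upload = MagicMock()
mock_upload(file=Path("doorbell_snapshot.jpg"))

# Index form used in the test: call[2] is the kwargs dict.
assert mock_upload.mock_calls[0][2]["file"] == Path("doorbell_snapshot.jpg")
# Named form: the same data via call_args_list.
assert mock_upload.call_args_list[0].kwargs["file"] == Path("doorbell_snapshot.jpg")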
@@ -97,7 +185,7 @@ async def test_generate_data(
     )
     assert result.data == {"characters": ["Mario", "Luigi"]}
 
-    assert len(mock_chat_create.mock_calls) == 2
+    assert len(mock_chat_create.mock_calls) == 4
     config = mock_chat_create.mock_calls[-1][2]["config"]
     assert config.response_mime_type == "application/json"
     assert config.response_schema == {
@@ -87,7 +87,6 @@ async def test_generate_content_service_with_image(
         ),
         patch("pathlib.Path.exists", return_value=True),
         patch.object(hass.config, "is_allowed_path", return_value=True),
-        patch("builtins.open", mock_open(read_data="this is an image")),
         patch("mimetypes.guess_type", return_value=["image/jpeg"]),
     ):
         response = await hass.services.async_call(