diff --git a/homeassistant/components/google_generative_ai_conversation/ai_task.py b/homeassistant/components/google_generative_ai_conversation/ai_task.py index b4f9d73e38d..80d5a1dfa06 100644 --- a/homeassistant/components/google_generative_ai_conversation/ai_task.py +++ b/homeassistant/components/google_generative_ai_conversation/ai_task.py @@ -37,7 +37,10 @@ class GoogleGenerativeAITaskEntity( ): """Google Generative AI AI Task entity.""" - _attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA + _attr_supported_features = ( + ai_task.AITaskEntityFeature.GENERATE_DATA + | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS + ) async def _async_generate_data( self, @@ -45,7 +48,7 @@ class GoogleGenerativeAITaskEntity( chat_log: conversation.ChatLog, ) -> ai_task.GenDataTaskResult: """Handle a generate data task.""" - await self._async_handle_chat_log(chat_log, task.structure) + await self._async_handle_chat_log(chat_log, task.structure, task.attachments) if not isinstance(chat_log.content[-1], conversation.AssistantContent): LOGGER.error( diff --git a/homeassistant/components/google_generative_ai_conversation/entity.py b/homeassistant/components/google_generative_ai_conversation/entity.py index 8f8edea18cb..fce1fdd40e7 100644 --- a/homeassistant/components/google_generative_ai_conversation/entity.py +++ b/homeassistant/components/google_generative_ai_conversation/entity.py @@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator, Callable from dataclasses import replace import mimetypes from pathlib import Path -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from google.genai import Client from google.genai.errors import APIError, ClientError @@ -30,8 +30,8 @@ from google.genai.types import ( import voluptuous as vol from voluptuous_openapi import convert -from homeassistant.components import conversation -from homeassistant.config_entries import ConfigEntry, ConfigSubentry +from homeassistant.components import ai_task, conversation +from homeassistant.config_entries import ConfigSubentry from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import device_registry as dr, llm @@ -60,6 +60,9 @@ from .const import ( TIMEOUT_MILLIS, ) +if TYPE_CHECKING: + from . import GoogleGenerativeAIConfigEntry + # Max number of back and forth with the LLM to generate a response MAX_TOOL_ITERATIONS = 10 @@ -313,7 +316,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity): def __init__( self, - entry: ConfigEntry, + entry: GoogleGenerativeAIConfigEntry, subentry: ConfigSubentry, default_model: str = RECOMMENDED_CHAT_MODEL, ) -> None: @@ -335,6 +338,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity): self, chat_log: conversation.ChatLog, structure: vol.Schema | None = None, + attachments: list[ai_task.PlayMediaWithId] | None = None, ) -> None: """Generate an answer for the chat log.""" options = self.subentry.data @@ -438,6 +442,18 @@ class GoogleGenerativeAILLMBaseEntity(Entity): user_message = chat_log.content[-1] assert isinstance(user_message, conversation.UserContent) chat_request: str | list[Part] = user_message.content + if attachments: + if any(a.path is None for a in attachments): + raise HomeAssistantError( + "Only local attachments are currently supported" + ) + files = await async_prepare_files_for_prompt( + self.hass, + self._genai_client, + [a.path for a in attachments], # type: ignore[misc] + ) + chat_request = [chat_request, *files] + # To prevent infinite loops, we limit the number of iterations for _iteration in range(MAX_TOOL_ITERATIONS): try: @@ -508,7 +524,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity): async def async_prepare_files_for_prompt( hass: HomeAssistant, client: Client, files: list[Path] ) -> list[File]: - """Append files to a prompt. + """Upload files so they can be attached to a prompt. Caller needs to ensure that the files are allowed. """ diff --git a/tests/components/google_generative_ai_conversation/test_ai_task.py b/tests/components/google_generative_ai_conversation/test_ai_task.py index b2b44aa1cd6..653b41fcb6e 100644 --- a/tests/components/google_generative_ai_conversation/test_ai_task.py +++ b/tests/components/google_generative_ai_conversation/test_ai_task.py @@ -1,12 +1,13 @@ """Test AI Task platform of Google Generative AI Conversation integration.""" -from unittest.mock import AsyncMock +from pathlib import Path +from unittest.mock import AsyncMock, patch -from google.genai.types import GenerateContentResponse +from google.genai.types import File, FileState, GenerateContentResponse import pytest import voluptuous as vol -from homeassistant.components import ai_task +from homeassistant.components import ai_task, media_source from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import entity_registry as er, selector @@ -64,6 +65,93 @@ async def test_generate_data( ) assert result.data == "Hi there!" + # Test with attachments + mock_send_message_stream.return_value = [ + [ + GenerateContentResponse( + candidates=[ + { + "content": { + "parts": [{"text": "Hi there!"}], + "role": "model", + }, + } + ], + ), + ], + ] + file1 = File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE) + file2 = File(name="context.txt", state=FileState.ACTIVE) + with ( + patch( + "homeassistant.components.media_source.async_resolve_media", + side_effect=[ + media_source.PlayMedia( + url="http://example.com/doorbell_snapshot.jpg", + mime_type="image/jpeg", + path=Path("doorbell_snapshot.jpg"), + ), + media_source.PlayMedia( + url="http://example.com/context.txt", + mime_type="text/plain", + path=Path("context.txt"), + ), + ], + ), + patch( + "google.genai.files.Files.upload", + side_effect=[file1, file2], + ) as mock_upload, + patch("pathlib.Path.exists", return_value=True), + patch.object(hass.config, "is_allowed_path", return_value=True), + patch("mimetypes.guess_type", return_value=["image/jpeg"]), + ): + result = await ai_task.async_generate_data( + hass, + task_name="Test Task", + entity_id=entity_id, + instructions="Test prompt", + attachments=[ + {"media_content_id": "media-source://media/doorbell_snapshot.jpg"}, + {"media_content_id": "media-source://media/context.txt"}, + ], + ) + + outgoing_message = mock_send_message_stream.mock_calls[1][2]["message"] + assert outgoing_message == ["Test prompt", file1, file2] + + assert result.data == "Hi there!" + assert len(mock_upload.mock_calls) == 2 + assert mock_upload.mock_calls[0][2]["file"] == Path("doorbell_snapshot.jpg") + assert mock_upload.mock_calls[1][2]["file"] == Path("context.txt") + + # Test attachments require play media with a path + with ( + patch( + "homeassistant.components.media_source.async_resolve_media", + side_effect=[ + media_source.PlayMedia( + url="http://example.com/doorbell_snapshot.jpg", + mime_type="image/jpeg", + path=None, + ), + ], + ), + pytest.raises( + HomeAssistantError, match="Only local attachments are currently supported" + ), + ): + result = await ai_task.async_generate_data( + hass, + task_name="Test Task", + entity_id=entity_id, + instructions="Test prompt", + attachments=[ + {"media_content_id": "media-source://media/doorbell_snapshot.jpg"}, + ], + ) + + # Test with structure mock_send_message_stream.return_value = [ [ GenerateContentResponse( @@ -97,7 +185,7 @@ async def test_generate_data( ) assert result.data == {"characters": ["Mario", "Luigi"]} - assert len(mock_chat_create.mock_calls) == 2 + assert len(mock_chat_create.mock_calls) == 4 config = mock_chat_create.mock_calls[-1][2]["config"] assert config.response_mime_type == "application/json" assert config.response_schema == { diff --git a/tests/components/google_generative_ai_conversation/test_init.py b/tests/components/google_generative_ai_conversation/test_init.py index 351895c89fb..351293e7ac0 100644 --- a/tests/components/google_generative_ai_conversation/test_init.py +++ b/tests/components/google_generative_ai_conversation/test_init.py @@ -87,7 +87,6 @@ async def test_generate_content_service_with_image( ), patch("pathlib.Path.exists", return_value=True), patch.object(hass.config, "is_allowed_path", return_value=True), - patch("builtins.open", mock_open(read_data="this is an image")), patch("mimetypes.guess_type", return_value=["image/jpeg"]), ): response = await hass.services.async_call(