diff --git a/homeassistant/components/google_generative_ai_conversation/__init__.py b/homeassistant/components/google_generative_ai_conversation/__init__.py index 88a51446cda..79d092a60c3 100644 --- a/homeassistant/components/google_generative_ai_conversation/__init__.py +++ b/homeassistant/components/google_generative_ai_conversation/__init__.py @@ -2,11 +2,13 @@ from __future__ import annotations +import asyncio import mimetypes from pathlib import Path from google.genai import Client from google.genai.errors import APIError, ClientError +from google.genai.types import File, FileState from requests.exceptions import Timeout import voluptuous as vol @@ -32,6 +34,8 @@ from .const import ( CONF_CHAT_MODEL, CONF_PROMPT, DOMAIN, + FILE_POLLING_INTERVAL_SECONDS, + LOGGER, RECOMMENDED_CHAT_MODEL, TIMEOUT_MILLIS, ) @@ -91,8 +95,40 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: ) prompt_parts.append(uploaded_file) + async def wait_for_file_processing(uploaded_file: File) -> None: + """Wait for file processing to complete.""" + while True: + uploaded_file = await client.aio.files.get( + name=uploaded_file.name, + config={"http_options": {"timeout": TIMEOUT_MILLIS}}, + ) + if uploaded_file.state not in ( + FileState.STATE_UNSPECIFIED, + FileState.PROCESSING, + ): + break + LOGGER.debug( + "Waiting for file `%s` to be processed, current state: %s", + uploaded_file.name, + uploaded_file.state, + ) + await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS) + + if uploaded_file.state == FileState.FAILED: + raise HomeAssistantError( + f"File `{uploaded_file.name}` processing failed, reason: {uploaded_file.error.message}" + ) + await hass.async_add_executor_job(append_files_to_prompt) + tasks = [ + asyncio.create_task(wait_for_file_processing(part)) + for part in prompt_parts + if isinstance(part, File) and part.state != FileState.ACTIVE + ] + async with asyncio.timeout(TIMEOUT_MILLIS / 1000): + await asyncio.gather(*tasks) + try: response = await client.aio.models.generate_content( model=RECOMMENDED_CHAT_MODEL, contents=prompt_parts diff --git a/homeassistant/components/google_generative_ai_conversation/const.py b/homeassistant/components/google_generative_ai_conversation/const.py index a7dd584ebee..239b3ff763e 100644 --- a/homeassistant/components/google_generative_ai_conversation/const.py +++ b/homeassistant/components/google_generative_ai_conversation/const.py @@ -26,3 +26,4 @@ CONF_USE_GOOGLE_SEARCH_TOOL = "enable_google_search_tool" RECOMMENDED_USE_GOOGLE_SEARCH_TOOL = False TIMEOUT_MILLIS = 10000 +FILE_POLLING_INTERVAL_SECONDS = 0.05 diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr index ce882adf6e6..d8e54b15f61 100644 --- a/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr +++ b/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr @@ -1,4 +1,21 @@ # serializer version: 1 +# name: test_generate_content_file_processing_succeeds + list([ + tuple( + '', + tuple( + ), + dict({ + 'contents': list([ + 'Describe this image from my doorbell camera', + File(name='doorbell_snapshot.jpg', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=, source=None, video_metadata=None, error=None), + File(name='context.txt', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=, source=None, video_metadata=None, error=None), + ]), + 'model': 'models/gemini-2.0-flash', + }), + ), + ]) +# --- # name: test_generate_content_service_with_image list([ tuple( diff --git a/tests/components/google_generative_ai_conversation/test_init.py b/tests/components/google_generative_ai_conversation/test_init.py index a08acc0df3f..94308260f74 100644 --- a/tests/components/google_generative_ai_conversation/test_init.py +++ b/tests/components/google_generative_ai_conversation/test_init.py @@ -2,6 +2,7 @@ from unittest.mock import AsyncMock, Mock, mock_open, patch +from google.genai.types import File, FileState import pytest from requests.exceptions import Timeout from syrupy.assertion import SnapshotAssertion @@ -91,6 +92,117 @@ async def test_generate_content_service_with_image( assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot +@pytest.mark.usefixtures("mock_init_component") +async def test_generate_content_file_processing_succeeds( + hass: HomeAssistant, snapshot: SnapshotAssertion +) -> None: + """Test generate content service.""" + stubbed_generated_content = ( + "A mail carrier is at your front door delivering a package" + ) + + with ( + patch( + "google.genai.models.AsyncModels.generate_content", + return_value=Mock( + text=stubbed_generated_content, + prompt_feedback=None, + candidates=[Mock()], + ), + ) as mock_generate, + patch("pathlib.Path.exists", return_value=True), + patch.object(hass.config, "is_allowed_path", return_value=True), + patch("builtins.open", mock_open(read_data="this is an image")), + patch("mimetypes.guess_type", return_value=["image/jpeg"]), + patch( + "google.genai.files.Files.upload", + side_effect=[ + File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE), + File(name="context.txt", state=FileState.PROCESSING), + ], + ), + patch( + "google.genai.files.AsyncFiles.get", + side_effect=[ + File(name="context.txt", state=FileState.PROCESSING), + File(name="context.txt", state=FileState.ACTIVE), + ], + ), + ): + response = await hass.services.async_call( + "google_generative_ai_conversation", + "generate_content", + { + "prompt": "Describe this image from my doorbell camera", + "filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"], + }, + blocking=True, + return_response=True, + ) + + assert response == { + "text": stubbed_generated_content, + } + assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot + + +@pytest.mark.usefixtures("mock_init_component") +async def test_generate_content_file_processing_fails( + hass: HomeAssistant, snapshot: SnapshotAssertion +) -> None: + """Test generate content service.""" + stubbed_generated_content = ( + "A mail carrier is at your front door delivering a package" + ) + + with ( + patch( + "google.genai.models.AsyncModels.generate_content", + return_value=Mock( + text=stubbed_generated_content, + prompt_feedback=None, + candidates=[Mock()], + ), + ), + patch("pathlib.Path.exists", return_value=True), + patch.object(hass.config, "is_allowed_path", return_value=True), + patch("builtins.open", mock_open(read_data="this is an image")), + patch("mimetypes.guess_type", return_value=["image/jpeg"]), + patch( + "google.genai.files.Files.upload", + side_effect=[ + File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE), + File(name="context.txt", state=FileState.PROCESSING), + ], + ), + patch( + "google.genai.files.AsyncFiles.get", + side_effect=[ + File(name="context.txt", state=FileState.PROCESSING), + File( + name="context.txt", + state=FileState.FAILED, + error={"message": "File processing failed"}, + ), + ], + ), + pytest.raises( + HomeAssistantError, + match="File `context.txt` processing failed, reason: File processing failed", + ), + ): + await hass.services.async_call( + "google_generative_ai_conversation", + "generate_content", + { + "prompt": "Describe this image from my doorbell camera", + "filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"], + }, + blocking=True, + return_response=True, + ) + + @pytest.mark.usefixtures("mock_init_component") async def test_generate_content_service_error( hass: HomeAssistant,