Fixes #140182 by checking file status before sending the prompt. (#144131)

* Added unit tests

* Addressed review comments

* Fixed tests

* PR comments
This commit is contained in:
Ivan Lopez Hernandez 2025-05-05 23:45:39 -07:00 committed by GitHub
parent 212c3ddcca
commit 0edfbded23
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 166 additions and 0 deletions

View File

@ -2,11 +2,13 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import mimetypes import mimetypes
from pathlib import Path from pathlib import Path
from google.genai import Client from google.genai import Client
from google.genai.errors import APIError, ClientError from google.genai.errors import APIError, ClientError
from google.genai.types import File, FileState
from requests.exceptions import Timeout from requests.exceptions import Timeout
import voluptuous as vol import voluptuous as vol
@ -32,6 +34,8 @@ from .const import (
CONF_CHAT_MODEL, CONF_CHAT_MODEL,
CONF_PROMPT, CONF_PROMPT,
DOMAIN, DOMAIN,
FILE_POLLING_INTERVAL_SECONDS,
LOGGER,
RECOMMENDED_CHAT_MODEL, RECOMMENDED_CHAT_MODEL,
TIMEOUT_MILLIS, TIMEOUT_MILLIS,
) )
@ -91,8 +95,40 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
) )
prompt_parts.append(uploaded_file) prompt_parts.append(uploaded_file)
async def wait_for_file_processing(uploaded_file: File) -> None:
    """Poll the Files API until the uploaded file leaves the processing state.

    Re-fetches the file record on each iteration and sleeps between polls.
    Raises HomeAssistantError if processing finished in the FAILED state.
    """
    # States that mean the backend is still working on the file.
    pending_states = (FileState.STATE_UNSPECIFIED, FileState.PROCESSING)
    still_pending = True
    while still_pending:
        uploaded_file = await client.aio.files.get(
            name=uploaded_file.name,
            config={"http_options": {"timeout": TIMEOUT_MILLIS}},
        )
        still_pending = uploaded_file.state in pending_states
        if still_pending:
            LOGGER.debug(
                "Waiting for file `%s` to be processed, current state: %s",
                uploaded_file.name,
                uploaded_file.state,
            )
            await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS)
    # Terminal state reached — surface a processing failure to the caller.
    if uploaded_file.state == FileState.FAILED:
        raise HomeAssistantError(
            f"File `{uploaded_file.name}` processing failed, reason: {uploaded_file.error.message}"
        )
await hass.async_add_executor_job(append_files_to_prompt) await hass.async_add_executor_job(append_files_to_prompt)
tasks = [
asyncio.create_task(wait_for_file_processing(part))
for part in prompt_parts
if isinstance(part, File) and part.state != FileState.ACTIVE
]
async with asyncio.timeout(TIMEOUT_MILLIS / 1000):
await asyncio.gather(*tasks)
try: try:
response = await client.aio.models.generate_content( response = await client.aio.models.generate_content(
model=RECOMMENDED_CHAT_MODEL, contents=prompt_parts model=RECOMMENDED_CHAT_MODEL, contents=prompt_parts

View File

@ -26,3 +26,4 @@ CONF_USE_GOOGLE_SEARCH_TOOL = "enable_google_search_tool"
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL = False RECOMMENDED_USE_GOOGLE_SEARCH_TOOL = False
TIMEOUT_MILLIS = 10000 TIMEOUT_MILLIS = 10000
FILE_POLLING_INTERVAL_SECONDS = 0.05

View File

@ -1,4 +1,21 @@
# serializer version: 1 # serializer version: 1
# name: test_generate_content_file_processing_succeeds
list([
tuple(
'',
tuple(
),
dict({
'contents': list([
'Describe this image from my doorbell camera',
File(name='doorbell_snapshot.jpg', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.ACTIVE: 'ACTIVE'>, source=None, video_metadata=None, error=None),
File(name='context.txt', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.PROCESSING: 'PROCESSING'>, source=None, video_metadata=None, error=None),
]),
'model': 'models/gemini-2.0-flash',
}),
),
])
# ---
# name: test_generate_content_service_with_image # name: test_generate_content_service_with_image
list([ list([
tuple( tuple(

View File

@ -2,6 +2,7 @@
from unittest.mock import AsyncMock, Mock, mock_open, patch from unittest.mock import AsyncMock, Mock, mock_open, patch
from google.genai.types import File, FileState
import pytest import pytest
from requests.exceptions import Timeout from requests.exceptions import Timeout
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
@ -91,6 +92,117 @@ async def test_generate_content_service_with_image(
assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot
@pytest.mark.usefixtures("mock_init_component")
async def test_generate_content_file_processing_succeeds(
    hass: HomeAssistant, snapshot: SnapshotAssertion
) -> None:
    """Test that generate_content waits for uploaded files to finish processing."""
    # Canned model reply returned by the mocked generate_content call.
    stubbed_generated_content = (
        "A mail carrier is at your front door delivering a package"
    )
    with (
        patch(
            "google.genai.models.AsyncModels.generate_content",
            return_value=Mock(
                text=stubbed_generated_content,
                prompt_feedback=None,
                candidates=[Mock()],
            ),
        ) as mock_generate,
        # Bypass filesystem and allowlist checks so the fake filenames pass
        # the service's pre-upload validation.
        patch("pathlib.Path.exists", return_value=True),
        patch.object(hass.config, "is_allowed_path", return_value=True),
        patch("builtins.open", mock_open(read_data="this is an image")),
        patch("mimetypes.guess_type", return_value=["image/jpeg"]),
        # First upload is immediately ACTIVE; the second starts PROCESSING,
        # which should trigger the polling path.
        patch(
            "google.genai.files.Files.upload",
            side_effect=[
                File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE),
                File(name="context.txt", state=FileState.PROCESSING),
            ],
        ),
        # Poll sequence: still PROCESSING once, then ACTIVE — the service
        # must wait through the first poll and proceed after the second.
        patch(
            "google.genai.files.AsyncFiles.get",
            side_effect=[
                File(name="context.txt", state=FileState.PROCESSING),
                File(name="context.txt", state=FileState.ACTIVE),
            ],
        ),
    ):
        response = await hass.services.async_call(
            "google_generative_ai_conversation",
            "generate_content",
            {
                "prompt": "Describe this image from my doorbell camera",
                "filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
            },
            blocking=True,
            return_response=True,
        )

    assert response == {
        "text": stubbed_generated_content,
    }
    # Snapshot pins the exact contents (including File objects) sent to the model.
    assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot
@pytest.mark.usefixtures("mock_init_component")
async def test_generate_content_file_processing_fails(
    hass: HomeAssistant, snapshot: SnapshotAssertion
) -> None:
    """Test that generate_content raises when an uploaded file ends up FAILED."""
    # Canned model reply; never reached because file processing fails first.
    stubbed_generated_content = (
        "A mail carrier is at your front door delivering a package"
    )
    with (
        patch(
            "google.genai.models.AsyncModels.generate_content",
            return_value=Mock(
                text=stubbed_generated_content,
                prompt_feedback=None,
                candidates=[Mock()],
            ),
        ),
        # Bypass filesystem and allowlist checks so the fake filenames pass
        # the service's pre-upload validation.
        patch("pathlib.Path.exists", return_value=True),
        patch.object(hass.config, "is_allowed_path", return_value=True),
        patch("builtins.open", mock_open(read_data="this is an image")),
        patch("mimetypes.guess_type", return_value=["image/jpeg"]),
        # Second upload starts PROCESSING, forcing the service to poll.
        patch(
            "google.genai.files.Files.upload",
            side_effect=[
                File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE),
                File(name="context.txt", state=FileState.PROCESSING),
            ],
        ),
        # Poll sequence ends in FAILED with an error payload — the service
        # should surface it as a HomeAssistantError.
        patch(
            "google.genai.files.AsyncFiles.get",
            side_effect=[
                File(name="context.txt", state=FileState.PROCESSING),
                File(
                    name="context.txt",
                    state=FileState.FAILED,
                    error={"message": "File processing failed"},
                ),
            ],
        ),
        pytest.raises(
            HomeAssistantError,
            match="File `context.txt` processing failed, reason: File processing failed",
        ),
    ):
        await hass.services.async_call(
            "google_generative_ai_conversation",
            "generate_content",
            {
                "prompt": "Describe this image from my doorbell camera",
                "filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
            },
            blocking=True,
            return_response=True,
        )
@pytest.mark.usefixtures("mock_init_component") @pytest.mark.usefixtures("mock_init_component")
async def test_generate_content_service_error( async def test_generate_content_service_error(
hass: HomeAssistant, hass: HomeAssistant,