mirror of
https://github.com/home-assistant/core.git
synced 2025-07-09 14:27:07 +00:00
* Added unit tests * Addressed review comments * Fixed tests * PR comments
This commit is contained in:
parent
212c3ddcca
commit
0edfbded23
@ -2,11 +2,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
|
||||
from google.genai import Client
|
||||
from google.genai.errors import APIError, ClientError
|
||||
from google.genai.types import File, FileState
|
||||
from requests.exceptions import Timeout
|
||||
import voluptuous as vol
|
||||
|
||||
@ -32,6 +34,8 @@ from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_PROMPT,
|
||||
DOMAIN,
|
||||
FILE_POLLING_INTERVAL_SECONDS,
|
||||
LOGGER,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
TIMEOUT_MILLIS,
|
||||
)
|
||||
@ -91,8 +95,40 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
)
|
||||
prompt_parts.append(uploaded_file)
|
||||
|
||||
async def wait_for_file_processing(uploaded_file: File) -> None:
|
||||
"""Wait for file processing to complete."""
|
||||
while True:
|
||||
uploaded_file = await client.aio.files.get(
|
||||
name=uploaded_file.name,
|
||||
config={"http_options": {"timeout": TIMEOUT_MILLIS}},
|
||||
)
|
||||
if uploaded_file.state not in (
|
||||
FileState.STATE_UNSPECIFIED,
|
||||
FileState.PROCESSING,
|
||||
):
|
||||
break
|
||||
LOGGER.debug(
|
||||
"Waiting for file `%s` to be processed, current state: %s",
|
||||
uploaded_file.name,
|
||||
uploaded_file.state,
|
||||
)
|
||||
await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS)
|
||||
|
||||
if uploaded_file.state == FileState.FAILED:
|
||||
raise HomeAssistantError(
|
||||
f"File `{uploaded_file.name}` processing failed, reason: {uploaded_file.error.message}"
|
||||
)
|
||||
|
||||
await hass.async_add_executor_job(append_files_to_prompt)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(wait_for_file_processing(part))
|
||||
for part in prompt_parts
|
||||
if isinstance(part, File) and part.state != FileState.ACTIVE
|
||||
]
|
||||
async with asyncio.timeout(TIMEOUT_MILLIS / 1000):
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
try:
|
||||
response = await client.aio.models.generate_content(
|
||||
model=RECOMMENDED_CHAT_MODEL, contents=prompt_parts
|
||||
|
@ -26,3 +26,4 @@ CONF_USE_GOOGLE_SEARCH_TOOL = "enable_google_search_tool"
|
||||
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL = False
|
||||
|
||||
TIMEOUT_MILLIS = 10000
|
||||
FILE_POLLING_INTERVAL_SECONDS = 0.05
|
||||
|
@ -1,4 +1,21 @@
|
||||
# serializer version: 1
|
||||
# name: test_generate_content_file_processing_succeeds
|
||||
list([
|
||||
tuple(
|
||||
'',
|
||||
tuple(
|
||||
),
|
||||
dict({
|
||||
'contents': list([
|
||||
'Describe this image from my doorbell camera',
|
||||
File(name='doorbell_snapshot.jpg', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.ACTIVE: 'ACTIVE'>, source=None, video_metadata=None, error=None),
|
||||
File(name='context.txt', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.PROCESSING: 'PROCESSING'>, source=None, video_metadata=None, error=None),
|
||||
]),
|
||||
'model': 'models/gemini-2.0-flash',
|
||||
}),
|
||||
),
|
||||
])
|
||||
# ---
|
||||
# name: test_generate_content_service_with_image
|
||||
list([
|
||||
tuple(
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, mock_open, patch
|
||||
|
||||
from google.genai.types import File, FileState
|
||||
import pytest
|
||||
from requests.exceptions import Timeout
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
@ -91,6 +92,117 @@ async def test_generate_content_service_with_image(
|
||||
assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_generate_content_file_processing_succeeds(
|
||||
hass: HomeAssistant, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
"""Test generate content service."""
|
||||
stubbed_generated_content = (
|
||||
"A mail carrier is at your front door delivering a package"
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"google.genai.models.AsyncModels.generate_content",
|
||||
return_value=Mock(
|
||||
text=stubbed_generated_content,
|
||||
prompt_feedback=None,
|
||||
candidates=[Mock()],
|
||||
),
|
||||
) as mock_generate,
|
||||
patch("pathlib.Path.exists", return_value=True),
|
||||
patch.object(hass.config, "is_allowed_path", return_value=True),
|
||||
patch("builtins.open", mock_open(read_data="this is an image")),
|
||||
patch("mimetypes.guess_type", return_value=["image/jpeg"]),
|
||||
patch(
|
||||
"google.genai.files.Files.upload",
|
||||
side_effect=[
|
||||
File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE),
|
||||
File(name="context.txt", state=FileState.PROCESSING),
|
||||
],
|
||||
),
|
||||
patch(
|
||||
"google.genai.files.AsyncFiles.get",
|
||||
side_effect=[
|
||||
File(name="context.txt", state=FileState.PROCESSING),
|
||||
File(name="context.txt", state=FileState.ACTIVE),
|
||||
],
|
||||
),
|
||||
):
|
||||
response = await hass.services.async_call(
|
||||
"google_generative_ai_conversation",
|
||||
"generate_content",
|
||||
{
|
||||
"prompt": "Describe this image from my doorbell camera",
|
||||
"filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
|
||||
},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
assert response == {
|
||||
"text": stubbed_generated_content,
|
||||
}
|
||||
assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_generate_content_file_processing_fails(
|
||||
hass: HomeAssistant, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
"""Test generate content service."""
|
||||
stubbed_generated_content = (
|
||||
"A mail carrier is at your front door delivering a package"
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"google.genai.models.AsyncModels.generate_content",
|
||||
return_value=Mock(
|
||||
text=stubbed_generated_content,
|
||||
prompt_feedback=None,
|
||||
candidates=[Mock()],
|
||||
),
|
||||
),
|
||||
patch("pathlib.Path.exists", return_value=True),
|
||||
patch.object(hass.config, "is_allowed_path", return_value=True),
|
||||
patch("builtins.open", mock_open(read_data="this is an image")),
|
||||
patch("mimetypes.guess_type", return_value=["image/jpeg"]),
|
||||
patch(
|
||||
"google.genai.files.Files.upload",
|
||||
side_effect=[
|
||||
File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE),
|
||||
File(name="context.txt", state=FileState.PROCESSING),
|
||||
],
|
||||
),
|
||||
patch(
|
||||
"google.genai.files.AsyncFiles.get",
|
||||
side_effect=[
|
||||
File(name="context.txt", state=FileState.PROCESSING),
|
||||
File(
|
||||
name="context.txt",
|
||||
state=FileState.FAILED,
|
||||
error={"message": "File processing failed"},
|
||||
),
|
||||
],
|
||||
),
|
||||
pytest.raises(
|
||||
HomeAssistantError,
|
||||
match="File `context.txt` processing failed, reason: File processing failed",
|
||||
),
|
||||
):
|
||||
await hass.services.async_call(
|
||||
"google_generative_ai_conversation",
|
||||
"generate_content",
|
||||
{
|
||||
"prompt": "Describe this image from my doorbell camera",
|
||||
"filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
|
||||
},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_generate_content_service_error(
|
||||
hass: HomeAssistant,
|
||||
|
Loading…
x
Reference in New Issue
Block a user