mirror of
https://github.com/home-assistant/core.git
synced 2025-07-09 06:17:07 +00:00
* Added unit tests * Addressed review comments * Fixed tests * PR comments
This commit is contained in:
parent
212c3ddcca
commit
0edfbded23
@ -2,11 +2,13 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import mimetypes
|
import mimetypes
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from google.genai import Client
|
from google.genai import Client
|
||||||
from google.genai.errors import APIError, ClientError
|
from google.genai.errors import APIError, ClientError
|
||||||
|
from google.genai.types import File, FileState
|
||||||
from requests.exceptions import Timeout
|
from requests.exceptions import Timeout
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@ -32,6 +34,8 @@ from .const import (
|
|||||||
CONF_CHAT_MODEL,
|
CONF_CHAT_MODEL,
|
||||||
CONF_PROMPT,
|
CONF_PROMPT,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
FILE_POLLING_INTERVAL_SECONDS,
|
||||||
|
LOGGER,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
TIMEOUT_MILLIS,
|
TIMEOUT_MILLIS,
|
||||||
)
|
)
|
||||||
@ -91,8 +95,40 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
)
|
)
|
||||||
prompt_parts.append(uploaded_file)
|
prompt_parts.append(uploaded_file)
|
||||||
|
|
||||||
|
async def wait_for_file_processing(uploaded_file: File) -> None:
|
||||||
|
"""Wait for file processing to complete."""
|
||||||
|
while True:
|
||||||
|
uploaded_file = await client.aio.files.get(
|
||||||
|
name=uploaded_file.name,
|
||||||
|
config={"http_options": {"timeout": TIMEOUT_MILLIS}},
|
||||||
|
)
|
||||||
|
if uploaded_file.state not in (
|
||||||
|
FileState.STATE_UNSPECIFIED,
|
||||||
|
FileState.PROCESSING,
|
||||||
|
):
|
||||||
|
break
|
||||||
|
LOGGER.debug(
|
||||||
|
"Waiting for file `%s` to be processed, current state: %s",
|
||||||
|
uploaded_file.name,
|
||||||
|
uploaded_file.state,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS)
|
||||||
|
|
||||||
|
if uploaded_file.state == FileState.FAILED:
|
||||||
|
raise HomeAssistantError(
|
||||||
|
f"File `{uploaded_file.name}` processing failed, reason: {uploaded_file.error.message}"
|
||||||
|
)
|
||||||
|
|
||||||
await hass.async_add_executor_job(append_files_to_prompt)
|
await hass.async_add_executor_job(append_files_to_prompt)
|
||||||
|
|
||||||
|
tasks = [
|
||||||
|
asyncio.create_task(wait_for_file_processing(part))
|
||||||
|
for part in prompt_parts
|
||||||
|
if isinstance(part, File) and part.state != FileState.ACTIVE
|
||||||
|
]
|
||||||
|
async with asyncio.timeout(TIMEOUT_MILLIS / 1000):
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.aio.models.generate_content(
|
response = await client.aio.models.generate_content(
|
||||||
model=RECOMMENDED_CHAT_MODEL, contents=prompt_parts
|
model=RECOMMENDED_CHAT_MODEL, contents=prompt_parts
|
||||||
|
@ -26,3 +26,4 @@ CONF_USE_GOOGLE_SEARCH_TOOL = "enable_google_search_tool"
|
|||||||
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL = False
|
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL = False
|
||||||
|
|
||||||
TIMEOUT_MILLIS = 10000
|
TIMEOUT_MILLIS = 10000
|
||||||
|
FILE_POLLING_INTERVAL_SECONDS = 0.05
|
||||||
|
@ -1,4 +1,21 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
|
# name: test_generate_content_file_processing_succeeds
|
||||||
|
list([
|
||||||
|
tuple(
|
||||||
|
'',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'contents': list([
|
||||||
|
'Describe this image from my doorbell camera',
|
||||||
|
File(name='doorbell_snapshot.jpg', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.ACTIVE: 'ACTIVE'>, source=None, video_metadata=None, error=None),
|
||||||
|
File(name='context.txt', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.PROCESSING: 'PROCESSING'>, source=None, video_metadata=None, error=None),
|
||||||
|
]),
|
||||||
|
'model': 'models/gemini-2.0-flash',
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
# name: test_generate_content_service_with_image
|
# name: test_generate_content_service_with_image
|
||||||
list([
|
list([
|
||||||
tuple(
|
tuple(
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from unittest.mock import AsyncMock, Mock, mock_open, patch
|
from unittest.mock import AsyncMock, Mock, mock_open, patch
|
||||||
|
|
||||||
|
from google.genai.types import File, FileState
|
||||||
import pytest
|
import pytest
|
||||||
from requests.exceptions import Timeout
|
from requests.exceptions import Timeout
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
@ -91,6 +92,117 @@ async def test_generate_content_service_with_image(
|
|||||||
assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot
|
assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
|
async def test_generate_content_file_processing_succeeds(
|
||||||
|
hass: HomeAssistant, snapshot: SnapshotAssertion
|
||||||
|
) -> None:
|
||||||
|
"""Test generate content service."""
|
||||||
|
stubbed_generated_content = (
|
||||||
|
"A mail carrier is at your front door delivering a package"
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"google.genai.models.AsyncModels.generate_content",
|
||||||
|
return_value=Mock(
|
||||||
|
text=stubbed_generated_content,
|
||||||
|
prompt_feedback=None,
|
||||||
|
candidates=[Mock()],
|
||||||
|
),
|
||||||
|
) as mock_generate,
|
||||||
|
patch("pathlib.Path.exists", return_value=True),
|
||||||
|
patch.object(hass.config, "is_allowed_path", return_value=True),
|
||||||
|
patch("builtins.open", mock_open(read_data="this is an image")),
|
||||||
|
patch("mimetypes.guess_type", return_value=["image/jpeg"]),
|
||||||
|
patch(
|
||||||
|
"google.genai.files.Files.upload",
|
||||||
|
side_effect=[
|
||||||
|
File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE),
|
||||||
|
File(name="context.txt", state=FileState.PROCESSING),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"google.genai.files.AsyncFiles.get",
|
||||||
|
side_effect=[
|
||||||
|
File(name="context.txt", state=FileState.PROCESSING),
|
||||||
|
File(name="context.txt", state=FileState.ACTIVE),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
):
|
||||||
|
response = await hass.services.async_call(
|
||||||
|
"google_generative_ai_conversation",
|
||||||
|
"generate_content",
|
||||||
|
{
|
||||||
|
"prompt": "Describe this image from my doorbell camera",
|
||||||
|
"filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response == {
|
||||||
|
"text": stubbed_generated_content,
|
||||||
|
}
|
||||||
|
assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
|
async def test_generate_content_file_processing_fails(
|
||||||
|
hass: HomeAssistant, snapshot: SnapshotAssertion
|
||||||
|
) -> None:
|
||||||
|
"""Test generate content service."""
|
||||||
|
stubbed_generated_content = (
|
||||||
|
"A mail carrier is at your front door delivering a package"
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"google.genai.models.AsyncModels.generate_content",
|
||||||
|
return_value=Mock(
|
||||||
|
text=stubbed_generated_content,
|
||||||
|
prompt_feedback=None,
|
||||||
|
candidates=[Mock()],
|
||||||
|
),
|
||||||
|
),
|
||||||
|
patch("pathlib.Path.exists", return_value=True),
|
||||||
|
patch.object(hass.config, "is_allowed_path", return_value=True),
|
||||||
|
patch("builtins.open", mock_open(read_data="this is an image")),
|
||||||
|
patch("mimetypes.guess_type", return_value=["image/jpeg"]),
|
||||||
|
patch(
|
||||||
|
"google.genai.files.Files.upload",
|
||||||
|
side_effect=[
|
||||||
|
File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE),
|
||||||
|
File(name="context.txt", state=FileState.PROCESSING),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"google.genai.files.AsyncFiles.get",
|
||||||
|
side_effect=[
|
||||||
|
File(name="context.txt", state=FileState.PROCESSING),
|
||||||
|
File(
|
||||||
|
name="context.txt",
|
||||||
|
state=FileState.FAILED,
|
||||||
|
error={"message": "File processing failed"},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
pytest.raises(
|
||||||
|
HomeAssistantError,
|
||||||
|
match="File `context.txt` processing failed, reason: File processing failed",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
await hass.services.async_call(
|
||||||
|
"google_generative_ai_conversation",
|
||||||
|
"generate_content",
|
||||||
|
{
|
||||||
|
"prompt": "Describe this image from my doorbell camera",
|
||||||
|
"filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("mock_init_component")
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
async def test_generate_content_service_error(
|
async def test_generate_content_service_error(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user