Extract files_to_prompt from Gemini action (#148203)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Allen Porter <allen.porter@gmail.com>
This commit is contained in:
Paulus Schoutsen 2025-07-06 15:15:38 +02:00 committed by GitHub
parent 075efb469a
commit 8cb9cadce9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 92 additions and 54 deletions

View File

@ -2,15 +2,12 @@
from __future__ import annotations
import asyncio
from functools import partial
import mimetypes
from pathlib import Path
from types import MappingProxyType
from google.genai import Client
from google.genai.errors import APIError, ClientError
from google.genai.types import File, FileState
from requests.exceptions import Timeout
import voluptuous as vol
@ -42,13 +39,13 @@ from .const import (
DEFAULT_TITLE,
DEFAULT_TTS_NAME,
DOMAIN,
FILE_POLLING_INTERVAL_SECONDS,
LOGGER,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_TTS_OPTIONS,
TIMEOUT_MILLIS,
)
from .entity import async_prepare_files_for_prompt
SERVICE_GENERATE_CONTENT = "generate_content"
CONF_IMAGE_FILENAME = "image_filename"
@ -92,58 +89,22 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
client = config_entry.runtime_data
def append_files_to_prompt():
image_filenames = call.data[CONF_IMAGE_FILENAME]
filenames = call.data[CONF_FILENAMES]
for filename in set(image_filenames + filenames):
files = call.data[CONF_IMAGE_FILENAME] + call.data[CONF_FILENAMES]
if files:
for filename in files:
if not hass.config.is_allowed_path(filename):
raise HomeAssistantError(
f"Cannot read `{filename}`, no access to path; "
"`allowlist_external_dirs` may need to be adjusted in "
"`configuration.yaml`"
)
if not Path(filename).exists():
raise HomeAssistantError(f"`{filename}` does not exist")
mimetype = mimetypes.guess_type(filename)[0]
with open(filename, "rb") as file:
uploaded_file = client.files.upload(
file=file, config={"mime_type": mimetype}
)
prompt_parts.append(uploaded_file)
async def wait_for_file_processing(uploaded_file: File) -> None:
"""Wait for file processing to complete."""
while True:
uploaded_file = await client.aio.files.get(
name=uploaded_file.name,
config={"http_options": {"timeout": TIMEOUT_MILLIS}},
prompt_parts.extend(
await async_prepare_files_for_prompt(
hass, client, [Path(filename) for filename in files]
)
if uploaded_file.state not in (
FileState.STATE_UNSPECIFIED,
FileState.PROCESSING,
):
break
LOGGER.debug(
"Waiting for file `%s` to be processed, current state: %s",
uploaded_file.name,
uploaded_file.state,
)
await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS)
if uploaded_file.state == FileState.FAILED:
raise HomeAssistantError(
f"File `{uploaded_file.name}` processing failed, reason: {uploaded_file.error.message}"
)
await hass.async_add_executor_job(append_files_to_prompt)
tasks = [
asyncio.create_task(wait_for_file_processing(part))
for part in prompt_parts
if isinstance(part, File) and part.state != FileState.ACTIVE
]
async with asyncio.timeout(TIMEOUT_MILLIS / 1000):
await asyncio.gather(*tasks)
)
try:
response = await client.aio.models.generate_content(

View File

@ -2,15 +2,21 @@
from __future__ import annotations
import asyncio
import codecs
from collections.abc import AsyncGenerator, Callable
from dataclasses import replace
import mimetypes
from pathlib import Path
from typing import Any, cast
from google.genai import Client
from google.genai.errors import APIError, ClientError
from google.genai.types import (
AutomaticFunctionCallingConfig,
Content,
File,
FileState,
FunctionDeclaration,
GenerateContentConfig,
GenerateContentResponse,
@ -26,6 +32,7 @@ from voluptuous_openapi import convert
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, llm
from homeassistant.helpers.entity import Entity
@ -42,6 +49,7 @@ from .const import (
CONF_TOP_P,
CONF_USE_GOOGLE_SEARCH_TOOL,
DOMAIN,
FILE_POLLING_INTERVAL_SECONDS,
LOGGER,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
@ -49,6 +57,7 @@ from .const import (
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_K,
RECOMMENDED_TOP_P,
TIMEOUT_MILLIS,
)
# Max number of back and forth with the LLM to generate a response
@ -494,3 +503,68 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
),
],
)
async def async_prepare_files_for_prompt(
    hass: HomeAssistant, client: Client, files: list[Path]
) -> list[File]:
    """Upload files and wait until they can be used in a prompt.

    Caller needs to ensure that the files are allowed.

    Returns the `File` objects exactly as returned by the upload call, in the
    same order as `files`. The refreshed state fetched while polling is only
    used to decide when processing has finished; it is not written back into
    the returned objects.

    Raises:
        HomeAssistantError: A file does not exist, or the service reports
            that processing of an uploaded file failed.
        TimeoutError: Processing did not finish within `TIMEOUT_MILLIS`.
    """

    def upload_files() -> list[File]:
        """Upload all files; blocking, so it is run in the executor."""
        prompt_parts: list[File] = []
        for filename in files:
            if not filename.exists():
                raise HomeAssistantError(f"`{filename}` does not exist")
            # guess_type may return None; the value is passed through as-is.
            mimetype = mimetypes.guess_type(filename)[0]
            prompt_parts.append(
                client.files.upload(
                    file=filename,
                    config={
                        "mime_type": mimetype,
                        "display_name": filename.name,
                    },
                )
            )
        return prompt_parts

    async def wait_for_file_processing(uploaded_file: File) -> None:
        """Poll one uploaded file until it leaves the processing states."""
        first = True
        while uploaded_file.state in (
            FileState.STATE_UNSPECIFIED,
            FileState.PROCESSING,
        ):
            # The first refresh happens immediately; only later rounds log
            # and back off before asking again.
            if first:
                first = False
            else:
                LOGGER.debug(
                    "Waiting for file `%s` to be processed, current state: %s",
                    uploaded_file.name,
                    uploaded_file.state,
                )
                await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS)

            uploaded_file = await client.aio.files.get(
                name=uploaded_file.name,
                config={"http_options": {"timeout": TIMEOUT_MILLIS}},
            )

        if uploaded_file.state == FileState.FAILED:
            # `error` may be absent on the returned object; do not let an
            # AttributeError hide the real failure.
            error = uploaded_file.error
            reason = error.message if error is not None else "unknown"
            raise HomeAssistantError(
                f"File `{uploaded_file.name}` processing failed, reason: {reason}"
            )

    prompt_parts = await hass.async_add_executor_job(upload_files)

    tasks = [
        asyncio.create_task(wait_for_file_processing(part))
        for part in prompt_parts
        if part.state != FileState.ACTIVE
    ]
    try:
        async with asyncio.timeout(TIMEOUT_MILLIS / 1000):
            await asyncio.gather(*tasks)
    finally:
        # gather() propagates the first failure but leaves the other tasks
        # running; cancel them so they do not keep polling in the background.
        # Cancelling an already finished task is a no-op.
        for task in tasks:
            task.cancel()

    return prompt_parts

View File

@ -122,8 +122,8 @@
dict({
'contents': list([
'Describe this image from my doorbell camera',
b'some file',
b'some file',
File(name='doorbell_snapshot.jpg', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.ACTIVE: 'ACTIVE'>, source=None, video_metadata=None, error=None),
File(name='context.txt', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.ACTIVE: 'ACTIVE'>, source=None, video_metadata=None, error=None),
]),
'model': 'models/gemini-2.5-flash',
}),

View File

@ -80,7 +80,10 @@ async def test_generate_content_service_with_image(
) as mock_generate,
patch(
"google.genai.files.Files.upload",
return_value=b"some file",
side_effect=[
File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE),
File(name="context.txt", state=FileState.ACTIVE),
],
),
patch("pathlib.Path.exists", return_value=True),
patch.object(hass.config, "is_allowed_path", return_value=True),
@ -92,7 +95,7 @@ async def test_generate_content_service_with_image(
"generate_content",
{
"prompt": "Describe this image from my doorbell camera",
"filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
"filenames": ["doorbell_snapshot.jpg", "context.txt"],
},
blocking=True,
return_response=True,
@ -146,7 +149,7 @@ async def test_generate_content_file_processing_succeeds(
"generate_content",
{
"prompt": "Describe this image from my doorbell camera",
"filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
"filenames": ["doorbell_snapshot.jpg", "context.txt"],
},
blocking=True,
return_response=True,
@ -208,7 +211,7 @@ async def test_generate_content_file_processing_fails(
"generate_content",
{
"prompt": "Describe this image from my doorbell camera",
"filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
"filenames": ["doorbell_snapshot.jpg", "context.txt"],
},
blocking=True,
return_response=True,