mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 20:57:21 +00:00
Extract files_to_prompt from Gemini action (#148203)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Allen Porter <allen.porter@gmail.com>
This commit is contained in:
parent
075efb469a
commit
8cb9cadce9
@ -2,15 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from functools import partial
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from types import MappingProxyType
|
||||
|
||||
from google.genai import Client
|
||||
from google.genai.errors import APIError, ClientError
|
||||
from google.genai.types import File, FileState
|
||||
from requests.exceptions import Timeout
|
||||
import voluptuous as vol
|
||||
|
||||
@ -42,13 +39,13 @@ from .const import (
|
||||
DEFAULT_TITLE,
|
||||
DEFAULT_TTS_NAME,
|
||||
DOMAIN,
|
||||
FILE_POLLING_INTERVAL_SECONDS,
|
||||
LOGGER,
|
||||
RECOMMENDED_AI_TASK_OPTIONS,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_TTS_OPTIONS,
|
||||
TIMEOUT_MILLIS,
|
||||
)
|
||||
from .entity import async_prepare_files_for_prompt
|
||||
|
||||
SERVICE_GENERATE_CONTENT = "generate_content"
|
||||
CONF_IMAGE_FILENAME = "image_filename"
|
||||
@ -92,58 +89,22 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
|
||||
client = config_entry.runtime_data
|
||||
|
||||
def append_files_to_prompt():
|
||||
image_filenames = call.data[CONF_IMAGE_FILENAME]
|
||||
filenames = call.data[CONF_FILENAMES]
|
||||
for filename in set(image_filenames + filenames):
|
||||
files = call.data[CONF_IMAGE_FILENAME] + call.data[CONF_FILENAMES]
|
||||
|
||||
if files:
|
||||
for filename in files:
|
||||
if not hass.config.is_allowed_path(filename):
|
||||
raise HomeAssistantError(
|
||||
f"Cannot read `{filename}`, no access to path; "
|
||||
"`allowlist_external_dirs` may need to be adjusted in "
|
||||
"`configuration.yaml`"
|
||||
)
|
||||
if not Path(filename).exists():
|
||||
raise HomeAssistantError(f"`{filename}` does not exist")
|
||||
mimetype = mimetypes.guess_type(filename)[0]
|
||||
with open(filename, "rb") as file:
|
||||
uploaded_file = client.files.upload(
|
||||
file=file, config={"mime_type": mimetype}
|
||||
)
|
||||
prompt_parts.append(uploaded_file)
|
||||
|
||||
async def wait_for_file_processing(uploaded_file: File) -> None:
|
||||
"""Wait for file processing to complete."""
|
||||
while True:
|
||||
uploaded_file = await client.aio.files.get(
|
||||
name=uploaded_file.name,
|
||||
config={"http_options": {"timeout": TIMEOUT_MILLIS}},
|
||||
prompt_parts.extend(
|
||||
await async_prepare_files_for_prompt(
|
||||
hass, client, [Path(filename) for filename in files]
|
||||
)
|
||||
if uploaded_file.state not in (
|
||||
FileState.STATE_UNSPECIFIED,
|
||||
FileState.PROCESSING,
|
||||
):
|
||||
break
|
||||
LOGGER.debug(
|
||||
"Waiting for file `%s` to be processed, current state: %s",
|
||||
uploaded_file.name,
|
||||
uploaded_file.state,
|
||||
)
|
||||
await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS)
|
||||
|
||||
if uploaded_file.state == FileState.FAILED:
|
||||
raise HomeAssistantError(
|
||||
f"File `{uploaded_file.name}` processing failed, reason: {uploaded_file.error.message}"
|
||||
)
|
||||
|
||||
await hass.async_add_executor_job(append_files_to_prompt)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(wait_for_file_processing(part))
|
||||
for part in prompt_parts
|
||||
if isinstance(part, File) and part.state != FileState.ACTIVE
|
||||
]
|
||||
async with asyncio.timeout(TIMEOUT_MILLIS / 1000):
|
||||
await asyncio.gather(*tasks)
|
||||
)
|
||||
|
||||
try:
|
||||
response = await client.aio.models.generate_content(
|
||||
|
@ -2,15 +2,21 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import codecs
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from dataclasses import replace
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
from google.genai import Client
|
||||
from google.genai.errors import APIError, ClientError
|
||||
from google.genai.types import (
|
||||
AutomaticFunctionCallingConfig,
|
||||
Content,
|
||||
File,
|
||||
FileState,
|
||||
FunctionDeclaration,
|
||||
GenerateContentConfig,
|
||||
GenerateContentResponse,
|
||||
@ -26,6 +32,7 @@ from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import device_registry as dr, llm
|
||||
from homeassistant.helpers.entity import Entity
|
||||
@ -42,6 +49,7 @@ from .const import (
|
||||
CONF_TOP_P,
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
DOMAIN,
|
||||
FILE_POLLING_INTERVAL_SECONDS,
|
||||
LOGGER,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
@ -49,6 +57,7 @@ from .const import (
|
||||
RECOMMENDED_TEMPERATURE,
|
||||
RECOMMENDED_TOP_K,
|
||||
RECOMMENDED_TOP_P,
|
||||
TIMEOUT_MILLIS,
|
||||
)
|
||||
|
||||
# Max number of back and forth with the LLM to generate a response
|
||||
@ -494,3 +503,68 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def async_prepare_files_for_prompt(
|
||||
hass: HomeAssistant, client: Client, files: list[Path]
|
||||
) -> list[File]:
|
||||
"""Append files to a prompt.
|
||||
|
||||
Caller needs to ensure that the files are allowed.
|
||||
"""
|
||||
|
||||
def upload_files() -> list[File]:
|
||||
prompt_parts: list[File] = []
|
||||
for filename in files:
|
||||
if not filename.exists():
|
||||
raise HomeAssistantError(f"`{filename}` does not exist")
|
||||
mimetype = mimetypes.guess_type(filename)[0]
|
||||
prompt_parts.append(
|
||||
client.files.upload(
|
||||
file=filename,
|
||||
config={
|
||||
"mime_type": mimetype,
|
||||
"display_name": filename.name,
|
||||
},
|
||||
)
|
||||
)
|
||||
return prompt_parts
|
||||
|
||||
async def wait_for_file_processing(uploaded_file: File) -> None:
|
||||
"""Wait for file processing to complete."""
|
||||
first = True
|
||||
while uploaded_file.state in (
|
||||
FileState.STATE_UNSPECIFIED,
|
||||
FileState.PROCESSING,
|
||||
):
|
||||
if first:
|
||||
first = False
|
||||
else:
|
||||
LOGGER.debug(
|
||||
"Waiting for file `%s` to be processed, current state: %s",
|
||||
uploaded_file.name,
|
||||
uploaded_file.state,
|
||||
)
|
||||
await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS)
|
||||
|
||||
uploaded_file = await client.aio.files.get(
|
||||
name=uploaded_file.name,
|
||||
config={"http_options": {"timeout": TIMEOUT_MILLIS}},
|
||||
)
|
||||
|
||||
if uploaded_file.state == FileState.FAILED:
|
||||
raise HomeAssistantError(
|
||||
f"File `{uploaded_file.name}` processing failed, reason: {uploaded_file.error.message}"
|
||||
)
|
||||
|
||||
prompt_parts = await hass.async_add_executor_job(upload_files)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(wait_for_file_processing(part))
|
||||
for part in prompt_parts
|
||||
if part.state != FileState.ACTIVE
|
||||
]
|
||||
async with asyncio.timeout(TIMEOUT_MILLIS / 1000):
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
return prompt_parts
|
||||
|
@ -122,8 +122,8 @@
|
||||
dict({
|
||||
'contents': list([
|
||||
'Describe this image from my doorbell camera',
|
||||
b'some file',
|
||||
b'some file',
|
||||
File(name='doorbell_snapshot.jpg', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.ACTIVE: 'ACTIVE'>, source=None, video_metadata=None, error=None),
|
||||
File(name='context.txt', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.ACTIVE: 'ACTIVE'>, source=None, video_metadata=None, error=None),
|
||||
]),
|
||||
'model': 'models/gemini-2.5-flash',
|
||||
}),
|
||||
|
@ -80,7 +80,10 @@ async def test_generate_content_service_with_image(
|
||||
) as mock_generate,
|
||||
patch(
|
||||
"google.genai.files.Files.upload",
|
||||
return_value=b"some file",
|
||||
side_effect=[
|
||||
File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE),
|
||||
File(name="context.txt", state=FileState.ACTIVE),
|
||||
],
|
||||
),
|
||||
patch("pathlib.Path.exists", return_value=True),
|
||||
patch.object(hass.config, "is_allowed_path", return_value=True),
|
||||
@ -92,7 +95,7 @@ async def test_generate_content_service_with_image(
|
||||
"generate_content",
|
||||
{
|
||||
"prompt": "Describe this image from my doorbell camera",
|
||||
"filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
|
||||
"filenames": ["doorbell_snapshot.jpg", "context.txt"],
|
||||
},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
@ -146,7 +149,7 @@ async def test_generate_content_file_processing_succeeds(
|
||||
"generate_content",
|
||||
{
|
||||
"prompt": "Describe this image from my doorbell camera",
|
||||
"filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
|
||||
"filenames": ["doorbell_snapshot.jpg", "context.txt"],
|
||||
},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
@ -208,7 +211,7 @@ async def test_generate_content_file_processing_fails(
|
||||
"generate_content",
|
||||
{
|
||||
"prompt": "Describe this image from my doorbell camera",
|
||||
"filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
|
||||
"filenames": ["doorbell_snapshot.jpg", "context.txt"],
|
||||
},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
|
Loading…
x
Reference in New Issue
Block a user