mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Extract files_to_prompt from Gemini action (#148203)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Allen Porter <allen.porter@gmail.com>
This commit is contained in:
parent
075efb469a
commit
8cb9cadce9
@ -2,15 +2,12 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import mimetypes
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
|
|
||||||
from google.genai import Client
|
from google.genai import Client
|
||||||
from google.genai.errors import APIError, ClientError
|
from google.genai.errors import APIError, ClientError
|
||||||
from google.genai.types import File, FileState
|
|
||||||
from requests.exceptions import Timeout
|
from requests.exceptions import Timeout
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@ -42,13 +39,13 @@ from .const import (
|
|||||||
DEFAULT_TITLE,
|
DEFAULT_TITLE,
|
||||||
DEFAULT_TTS_NAME,
|
DEFAULT_TTS_NAME,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
FILE_POLLING_INTERVAL_SECONDS,
|
|
||||||
LOGGER,
|
LOGGER,
|
||||||
RECOMMENDED_AI_TASK_OPTIONS,
|
RECOMMENDED_AI_TASK_OPTIONS,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
RECOMMENDED_TTS_OPTIONS,
|
RECOMMENDED_TTS_OPTIONS,
|
||||||
TIMEOUT_MILLIS,
|
TIMEOUT_MILLIS,
|
||||||
)
|
)
|
||||||
|
from .entity import async_prepare_files_for_prompt
|
||||||
|
|
||||||
SERVICE_GENERATE_CONTENT = "generate_content"
|
SERVICE_GENERATE_CONTENT = "generate_content"
|
||||||
CONF_IMAGE_FILENAME = "image_filename"
|
CONF_IMAGE_FILENAME = "image_filename"
|
||||||
@ -92,58 +89,22 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
|
|
||||||
client = config_entry.runtime_data
|
client = config_entry.runtime_data
|
||||||
|
|
||||||
def append_files_to_prompt():
|
files = call.data[CONF_IMAGE_FILENAME] + call.data[CONF_FILENAMES]
|
||||||
image_filenames = call.data[CONF_IMAGE_FILENAME]
|
|
||||||
filenames = call.data[CONF_FILENAMES]
|
if files:
|
||||||
for filename in set(image_filenames + filenames):
|
for filename in files:
|
||||||
if not hass.config.is_allowed_path(filename):
|
if not hass.config.is_allowed_path(filename):
|
||||||
raise HomeAssistantError(
|
raise HomeAssistantError(
|
||||||
f"Cannot read `{filename}`, no access to path; "
|
f"Cannot read `{filename}`, no access to path; "
|
||||||
"`allowlist_external_dirs` may need to be adjusted in "
|
"`allowlist_external_dirs` may need to be adjusted in "
|
||||||
"`configuration.yaml`"
|
"`configuration.yaml`"
|
||||||
)
|
)
|
||||||
if not Path(filename).exists():
|
|
||||||
raise HomeAssistantError(f"`{filename}` does not exist")
|
|
||||||
mimetype = mimetypes.guess_type(filename)[0]
|
|
||||||
with open(filename, "rb") as file:
|
|
||||||
uploaded_file = client.files.upload(
|
|
||||||
file=file, config={"mime_type": mimetype}
|
|
||||||
)
|
|
||||||
prompt_parts.append(uploaded_file)
|
|
||||||
|
|
||||||
async def wait_for_file_processing(uploaded_file: File) -> None:
|
prompt_parts.extend(
|
||||||
"""Wait for file processing to complete."""
|
await async_prepare_files_for_prompt(
|
||||||
while True:
|
hass, client, [Path(filename) for filename in files]
|
||||||
uploaded_file = await client.aio.files.get(
|
|
||||||
name=uploaded_file.name,
|
|
||||||
config={"http_options": {"timeout": TIMEOUT_MILLIS}},
|
|
||||||
)
|
)
|
||||||
if uploaded_file.state not in (
|
)
|
||||||
FileState.STATE_UNSPECIFIED,
|
|
||||||
FileState.PROCESSING,
|
|
||||||
):
|
|
||||||
break
|
|
||||||
LOGGER.debug(
|
|
||||||
"Waiting for file `%s` to be processed, current state: %s",
|
|
||||||
uploaded_file.name,
|
|
||||||
uploaded_file.state,
|
|
||||||
)
|
|
||||||
await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS)
|
|
||||||
|
|
||||||
if uploaded_file.state == FileState.FAILED:
|
|
||||||
raise HomeAssistantError(
|
|
||||||
f"File `{uploaded_file.name}` processing failed, reason: {uploaded_file.error.message}"
|
|
||||||
)
|
|
||||||
|
|
||||||
await hass.async_add_executor_job(append_files_to_prompt)
|
|
||||||
|
|
||||||
tasks = [
|
|
||||||
asyncio.create_task(wait_for_file_processing(part))
|
|
||||||
for part in prompt_parts
|
|
||||||
if isinstance(part, File) and part.state != FileState.ACTIVE
|
|
||||||
]
|
|
||||||
async with asyncio.timeout(TIMEOUT_MILLIS / 1000):
|
|
||||||
await asyncio.gather(*tasks)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.aio.models.generate_content(
|
response = await client.aio.models.generate_content(
|
||||||
|
@ -2,15 +2,21 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import codecs
|
import codecs
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from collections.abc import AsyncGenerator, Callable
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
|
import mimetypes
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from google.genai import Client
|
||||||
from google.genai.errors import APIError, ClientError
|
from google.genai.errors import APIError, ClientError
|
||||||
from google.genai.types import (
|
from google.genai.types import (
|
||||||
AutomaticFunctionCallingConfig,
|
AutomaticFunctionCallingConfig,
|
||||||
Content,
|
Content,
|
||||||
|
File,
|
||||||
|
FileState,
|
||||||
FunctionDeclaration,
|
FunctionDeclaration,
|
||||||
GenerateContentConfig,
|
GenerateContentConfig,
|
||||||
GenerateContentResponse,
|
GenerateContentResponse,
|
||||||
@ -26,6 +32,7 @@ from voluptuous_openapi import convert
|
|||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import device_registry as dr, llm
|
from homeassistant.helpers import device_registry as dr, llm
|
||||||
from homeassistant.helpers.entity import Entity
|
from homeassistant.helpers.entity import Entity
|
||||||
@ -42,6 +49,7 @@ from .const import (
|
|||||||
CONF_TOP_P,
|
CONF_TOP_P,
|
||||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
FILE_POLLING_INTERVAL_SECONDS,
|
||||||
LOGGER,
|
LOGGER,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||||
@ -49,6 +57,7 @@ from .const import (
|
|||||||
RECOMMENDED_TEMPERATURE,
|
RECOMMENDED_TEMPERATURE,
|
||||||
RECOMMENDED_TOP_K,
|
RECOMMENDED_TOP_K,
|
||||||
RECOMMENDED_TOP_P,
|
RECOMMENDED_TOP_P,
|
||||||
|
TIMEOUT_MILLIS,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Max number of back and forth with the LLM to generate a response
|
# Max number of back and forth with the LLM to generate a response
|
||||||
@ -494,3 +503,68 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_prepare_files_for_prompt(
    hass: HomeAssistant, client: Client, files: list[Path]
) -> list[File]:
    """Upload files and wait until they are ready to be used in a prompt.

    Caller needs to ensure that the files are allowed.

    Returns the uploaded files in the same order as `files`. The returned
    objects are the ones produced by the upload call; the state refreshed
    while polling is only used locally to decide when processing is done.

    Raises HomeAssistantError if a file does not exist or if processing of
    an uploaded file fails. A TimeoutError propagates if processing does not
    finish within TIMEOUT_MILLIS.
    """

    def upload_files() -> list[File]:
        """Upload every file; blocking, so it is run in the executor."""
        uploaded: list[File] = []
        for filename in files:
            if not filename.exists():
                raise HomeAssistantError(f"`{filename}` does not exist")
            mimetype = mimetypes.guess_type(filename)[0]
            uploaded.append(
                client.files.upload(
                    file=filename,
                    config={
                        "mime_type": mimetype,
                        "display_name": filename.name,
                    },
                )
            )
        return uploaded

    async def wait_for_file_processing(uploaded_file: File) -> None:
        """Wait for file processing to complete."""
        first = True
        while uploaded_file.state in (
            FileState.STATE_UNSPECIFIED,
            FileState.PROCESSING,
        ):
            # Refresh immediately the first time; only log and sleep between
            # subsequent polls.
            if first:
                first = False
            else:
                LOGGER.debug(
                    "Waiting for file `%s` to be processed, current state: %s",
                    uploaded_file.name,
                    uploaded_file.state,
                )
                await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS)

            uploaded_file = await client.aio.files.get(
                name=uploaded_file.name,
                config={"http_options": {"timeout": TIMEOUT_MILLIS}},
            )

        if uploaded_file.state == FileState.FAILED:
            # Guard against a missing error payload so the real failure is
            # reported instead of an AttributeError.
            error = uploaded_file.error
            reason = error.message if error is not None else "unknown"
            raise HomeAssistantError(
                f"File `{uploaded_file.name}` processing failed, reason: {reason}"
            )

    prompt_parts = await hass.async_add_executor_job(upload_files)

    tasks = [
        asyncio.create_task(wait_for_file_processing(part))
        for part in prompt_parts
        if part.state != FileState.ACTIVE
    ]
    try:
        async with asyncio.timeout(TIMEOUT_MILLIS / 1000):
            await asyncio.gather(*tasks)
    finally:
        # gather() does not cancel the other awaitables when one of them
        # raises, so stop any poller that is still running. Cancelling a
        # finished task is a no-op.
        for task in tasks:
            task.cancel()

    return prompt_parts
|
||||||
|
@ -122,8 +122,8 @@
|
|||||||
dict({
|
dict({
|
||||||
'contents': list([
|
'contents': list([
|
||||||
'Describe this image from my doorbell camera',
|
'Describe this image from my doorbell camera',
|
||||||
b'some file',
|
File(name='doorbell_snapshot.jpg', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.ACTIVE: 'ACTIVE'>, source=None, video_metadata=None, error=None),
|
||||||
b'some file',
|
File(name='context.txt', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.ACTIVE: 'ACTIVE'>, source=None, video_metadata=None, error=None),
|
||||||
]),
|
]),
|
||||||
'model': 'models/gemini-2.5-flash',
|
'model': 'models/gemini-2.5-flash',
|
||||||
}),
|
}),
|
||||||
|
@ -80,7 +80,10 @@ async def test_generate_content_service_with_image(
|
|||||||
) as mock_generate,
|
) as mock_generate,
|
||||||
patch(
|
patch(
|
||||||
"google.genai.files.Files.upload",
|
"google.genai.files.Files.upload",
|
||||||
return_value=b"some file",
|
side_effect=[
|
||||||
|
File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE),
|
||||||
|
File(name="context.txt", state=FileState.ACTIVE),
|
||||||
|
],
|
||||||
),
|
),
|
||||||
patch("pathlib.Path.exists", return_value=True),
|
patch("pathlib.Path.exists", return_value=True),
|
||||||
patch.object(hass.config, "is_allowed_path", return_value=True),
|
patch.object(hass.config, "is_allowed_path", return_value=True),
|
||||||
@ -92,7 +95,7 @@ async def test_generate_content_service_with_image(
|
|||||||
"generate_content",
|
"generate_content",
|
||||||
{
|
{
|
||||||
"prompt": "Describe this image from my doorbell camera",
|
"prompt": "Describe this image from my doorbell camera",
|
||||||
"filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
|
"filenames": ["doorbell_snapshot.jpg", "context.txt"],
|
||||||
},
|
},
|
||||||
blocking=True,
|
blocking=True,
|
||||||
return_response=True,
|
return_response=True,
|
||||||
@ -146,7 +149,7 @@ async def test_generate_content_file_processing_succeeds(
|
|||||||
"generate_content",
|
"generate_content",
|
||||||
{
|
{
|
||||||
"prompt": "Describe this image from my doorbell camera",
|
"prompt": "Describe this image from my doorbell camera",
|
||||||
"filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
|
"filenames": ["doorbell_snapshot.jpg", "context.txt"],
|
||||||
},
|
},
|
||||||
blocking=True,
|
blocking=True,
|
||||||
return_response=True,
|
return_response=True,
|
||||||
@ -208,7 +211,7 @@ async def test_generate_content_file_processing_fails(
|
|||||||
"generate_content",
|
"generate_content",
|
||||||
{
|
{
|
||||||
"prompt": "Describe this image from my doorbell camera",
|
"prompt": "Describe this image from my doorbell camera",
|
||||||
"filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
|
"filenames": ["doorbell_snapshot.jpg", "context.txt"],
|
||||||
},
|
},
|
||||||
blocking=True,
|
blocking=True,
|
||||||
return_response=True,
|
return_response=True,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user