OpenAI: Extract file attachment logic (#148288)

This commit is contained in:
Paulus Schoutsen 2025-07-08 08:01:49 +02:00 committed by GitHub
parent 73730e3eb3
commit 6d0891e970
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 74 additions and 57 deletions

View File

@ -2,8 +2,6 @@
from __future__ import annotations
import base64
from mimetypes import guess_file_type
from pathlib import Path
import openai
@ -11,8 +9,6 @@ from openai.types.images_response import ImagesResponse
from openai.types.responses import (
EasyInputMessageParam,
Response,
ResponseInputFileParam,
ResponseInputImageParam,
ResponseInputMessageContentListParam,
ResponseInputParam,
ResponseInputTextParam,
@ -58,6 +54,7 @@ from .const import (
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
)
from .entity import async_prepare_files_for_prompt
SERVICE_GENERATE_IMAGE = "generate_image"
SERVICE_GENERATE_CONTENT = "generate_content"
@ -68,15 +65,6 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient]
def encode_file(file_path: str) -> tuple[str, str]:
"""Return base64 version of file contents."""
mime_type, _ = guess_file_type(file_path)
if mime_type is None:
mime_type = "application/octet-stream"
with open(file_path, "rb") as image_file:
return (mime_type, base64.b64encode(image_file.read()).decode("utf-8"))
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up OpenAI Conversation."""
await async_migrate_integration(hass)
@ -146,41 +134,20 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
ResponseInputTextParam(type="input_text", text=call.data[CONF_PROMPT])
]
def append_files_to_content() -> None:
for filename in call.data[CONF_FILENAMES]:
if filenames := call.data.get(CONF_FILENAMES):
for filename in filenames:
if not hass.config.is_allowed_path(filename):
raise HomeAssistantError(
f"Cannot read `{filename}`, no access to path; "
"`allowlist_external_dirs` may need to be adjusted in "
"`configuration.yaml`"
)
if not Path(filename).exists():
raise HomeAssistantError(f"`{filename}` does not exist")
mime_type, base64_file = encode_file(filename)
if "image/" in mime_type:
content.append(
ResponseInputImageParam(
type="input_image",
image_url=f"data:{mime_type};base64,{base64_file}",
detail="auto",
)
)
elif "application/pdf" in mime_type:
content.append(
ResponseInputFileParam(
type="input_file",
filename=filename,
file_data=f"data:{mime_type};base64,{base64_file}",
)
)
else:
raise HomeAssistantError(
"Only images and PDF are supported by the OpenAI API,"
f"`{filename}` is not an image file or PDF"
)
if CONF_FILENAMES in call.data:
await hass.async_add_executor_job(append_files_to_content)
content.extend(
await async_prepare_files_for_prompt(
hass, [Path(filename) for filename in filenames]
)
)
messages: ResponseInputParam = [
EasyInputMessageParam(type="message", role="user", content=content)

View File

@ -1,8 +1,13 @@
"""Base entity for OpenAI."""
from __future__ import annotations
import base64
from collections.abc import AsyncGenerator, Callable
import json
from typing import Any, Literal, cast
from mimetypes import guess_file_type
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, cast
import openai
from openai._streaming import AsyncStream
@ -17,6 +22,9 @@ from openai.types.responses import (
ResponseFunctionToolCall,
ResponseFunctionToolCallParam,
ResponseIncompleteEvent,
ResponseInputFileParam,
ResponseInputImageParam,
ResponseInputMessageContentListParam,
ResponseInputParam,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
@ -35,11 +43,11 @@ from voluptuous_openapi import convert
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigSubentry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, llm
from homeassistant.helpers.entity import Entity
from . import OpenAIConfigEntry
from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
@ -63,6 +71,10 @@ from .const import (
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
)
if TYPE_CHECKING:
from . import OpenAIConfigEntry
# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
@ -312,3 +324,50 @@ class OpenAIBaseLLMEntity(Entity):
if not chat_log.unresponded_tool_results:
break
async def async_prepare_files_for_prompt(
hass: HomeAssistant, files: list[Path]
) -> ResponseInputMessageContentListParam:
"""Append files to a prompt.
Caller needs to ensure that the files are allowed.
"""
def append_files_to_content() -> ResponseInputMessageContentListParam:
content: ResponseInputMessageContentListParam = []
for file_path in files:
if not file_path.exists():
raise HomeAssistantError(f"`{file_path}` does not exist")
mime_type, _ = guess_file_type(file_path)
if not mime_type or not mime_type.startswith(("image/", "application/pdf")):
raise HomeAssistantError(
"Only images and PDF are supported by the OpenAI API,"
f"`{file_path}` is not an image file or PDF"
)
base64_file = base64.b64encode(file_path.read_bytes()).decode("utf-8")
if mime_type.startswith("image/"):
content.append(
ResponseInputImageParam(
type="input_image",
image_url=f"data:{mime_type};base64,{base64_file}",
detail="auto",
)
)
elif mime_type.startswith("application/pdf"):
content.append(
ResponseInputFileParam(
type="input_file",
filename=str(file_path),
file_data=f"data:{mime_type};base64,{base64_file}",
)
)
return content
return await hass.async_add_executor_job(append_files_to_content)

View File

@ -1,6 +1,6 @@
"""Tests for the OpenAI integration."""
from unittest.mock import AsyncMock, mock_open, patch
from unittest.mock import AsyncMock, Mock, mock_open, patch
import httpx
from openai import (
@ -16,7 +16,7 @@ import pytest
from syrupy.assertion import SnapshotAssertion
from syrupy.filters import props
from homeassistant.components.openai_conversation import CONF_CHAT_MODEL, CONF_FILENAMES
from homeassistant.components.openai_conversation import CONF_CHAT_MODEL
from homeassistant.components.openai_conversation.const import DOMAIN
from homeassistant.config_entries import ConfigSubentryData
from homeassistant.core import HomeAssistant
@ -394,7 +394,7 @@ async def test_generate_content_service(
patch(
"base64.b64encode", side_effect=[b"BASE64IMAGE1", b"BASE64IMAGE2"]
) as mock_b64encode,
patch("builtins.open", mock_open(read_data="ABC")) as mock_file,
patch("pathlib.Path.read_bytes", Mock(return_value=b"ABC")) as mock_file,
patch("pathlib.Path.exists", return_value=True),
patch.object(hass.config, "is_allowed_path", return_value=True),
):
@ -434,15 +434,13 @@ async def test_generate_content_service(
assert len(mock_create.mock_calls) == 1
assert mock_create.mock_calls[0][2] == expected_args
assert mock_b64encode.call_count == number_of_files
for idx, file in enumerate(service_data[CONF_FILENAMES]):
assert mock_file.call_args_list[idx][0][0] == file
assert mock_file.call_count == number_of_files
@pytest.mark.parametrize(
(
"service_data",
"error",
"number_of_files",
"exists_side_effect",
"is_allowed_side_effect",
),
@ -450,7 +448,6 @@ async def test_generate_content_service(
(
{"prompt": "Picture of a dog", "filenames": ["/a/b/c.jpg"]},
"`/a/b/c.jpg` does not exist",
0,
[False],
[True],
),
@ -460,14 +457,12 @@ async def test_generate_content_service(
"filenames": ["/a/b/c.jpg", "d/e/f.png"],
},
"Cannot read `d/e/f.png`, no access to path; `allowlist_external_dirs` may need to be adjusted in `configuration.yaml`",
1,
[True, True],
[True, False],
),
(
{"prompt": "Not a picture of a dog", "filenames": ["/a/b/c.mov"]},
"Only images and PDF are supported by the OpenAI API,`/a/b/c.mov` is not an image file or PDF",
1,
[True],
[True],
),
@ -479,7 +474,6 @@ async def test_generate_content_service_invalid(
mock_init_component,
service_data,
error,
number_of_files,
exists_side_effect,
is_allowed_side_effect,
) -> None:
@ -491,9 +485,7 @@ async def test_generate_content_service_invalid(
"openai.resources.responses.AsyncResponses.create",
new_callable=AsyncMock,
) as mock_create,
patch(
"base64.b64encode", side_effect=[b"BASE64IMAGE1", b"BASE64IMAGE2"]
) as mock_b64encode,
patch("base64.b64encode", side_effect=[b"BASE64IMAGE1", b"BASE64IMAGE2"]),
patch("builtins.open", mock_open(read_data="ABC")),
patch("pathlib.Path.exists", side_effect=exists_side_effect),
patch.object(
@ -509,7 +501,6 @@ async def test_generate_content_service_invalid(
return_response=True,
)
assert len(mock_create.mock_calls) == 0
assert mock_b64encode.call_count == number_of_files
@pytest.mark.usefixtures("mock_init_component")