OpenAI: Extract file attachment logic (#148288)

Paulus Schoutsen 2025-07-08 08:01:49 +02:00 committed by GitHub
parent 73730e3eb3
commit 6d0891e970
3 changed files with 74 additions and 57 deletions

homeassistant/components/openai_conversation/__init__.py

@@ -2,8 +2,6 @@
 
 from __future__ import annotations
 
-import base64
-from mimetypes import guess_file_type
 from pathlib import Path
 
 import openai
@@ -11,8 +9,6 @@ from openai.types.images_response import ImagesResponse
 from openai.types.responses import (
     EasyInputMessageParam,
     Response,
-    ResponseInputFileParam,
-    ResponseInputImageParam,
     ResponseInputMessageContentListParam,
     ResponseInputParam,
     ResponseInputTextParam,
@@ -58,6 +54,7 @@ from .const import (
     RECOMMENDED_TEMPERATURE,
     RECOMMENDED_TOP_P,
 )
+from .entity import async_prepare_files_for_prompt
 
 SERVICE_GENERATE_IMAGE = "generate_image"
 SERVICE_GENERATE_CONTENT = "generate_content"
@@ -68,15 +65,6 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
 type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient]
 
 
-def encode_file(file_path: str) -> tuple[str, str]:
-    """Return base64 version of file contents."""
-    mime_type, _ = guess_file_type(file_path)
-    if mime_type is None:
-        mime_type = "application/octet-stream"
-    with open(file_path, "rb") as image_file:
-        return (mime_type, base64.b64encode(image_file.read()).decode("utf-8"))
-
-
 async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
     """Set up OpenAI Conversation."""
     await async_migrate_integration(hass)
@@ -146,41 +134,20 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
             ResponseInputTextParam(type="input_text", text=call.data[CONF_PROMPT])
         ]
 
-        def append_files_to_content() -> None:
-            for filename in call.data[CONF_FILENAMES]:
+        if filenames := call.data.get(CONF_FILENAMES):
+            for filename in filenames:
                 if not hass.config.is_allowed_path(filename):
                     raise HomeAssistantError(
                         f"Cannot read `{filename}`, no access to path; "
                         "`allowlist_external_dirs` may need to be adjusted in "
                         "`configuration.yaml`"
                     )
-                if not Path(filename).exists():
-                    raise HomeAssistantError(f"`{filename}` does not exist")
-                mime_type, base64_file = encode_file(filename)
-                if "image/" in mime_type:
-                    content.append(
-                        ResponseInputImageParam(
-                            type="input_image",
-                            image_url=f"data:{mime_type};base64,{base64_file}",
-                            detail="auto",
-                        )
-                    )
-                elif "application/pdf" in mime_type:
-                    content.append(
-                        ResponseInputFileParam(
-                            type="input_file",
-                            filename=filename,
-                            file_data=f"data:{mime_type};base64,{base64_file}",
-                        )
-                    )
-                else:
-                    raise HomeAssistantError(
-                        "Only images and PDF are supported by the OpenAI API,"
-                        f"`{filename}` is not an image file or PDF"
-                    )
 
-        if CONF_FILENAMES in call.data:
-            await hass.async_add_executor_job(append_files_to_content)
+            content.extend(
+                await async_prepare_files_for_prompt(
+                    hass, [Path(filename) for filename in filenames]
+                )
+            )
 
         messages: ResponseInputParam = [
             EasyInputMessageParam(type="message", role="user", content=content)
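
For reference, a minimal sketch of calling the changed generate_content service with attachments (the prompt and paths are illustrative; the listed files must resolve under allowlist_external_dirs, otherwise the handler raises the HomeAssistantError shown above):

    await hass.services.async_call(
        "openai_conversation",
        "generate_content",
        {
            "prompt": "Describe these attachments",
            "filenames": ["/media/photo.jpg", "/media/report.pdf"],
        },
        blocking=True,
        return_response=True,
    )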

homeassistant/components/openai_conversation/entity.py

@@ -1,8 +1,13 @@
 """Base entity for OpenAI."""
 
+from __future__ import annotations
+
+import base64
 from collections.abc import AsyncGenerator, Callable
 import json
-from typing import Any, Literal, cast
+from mimetypes import guess_file_type
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Literal, cast
 
 import openai
 from openai._streaming import AsyncStream
@@ -17,6 +22,9 @@ from openai.types.responses import (
     ResponseFunctionToolCall,
     ResponseFunctionToolCallParam,
     ResponseIncompleteEvent,
+    ResponseInputFileParam,
+    ResponseInputImageParam,
+    ResponseInputMessageContentListParam,
     ResponseInputParam,
     ResponseOutputItemAddedEvent,
     ResponseOutputItemDoneEvent,
@@ -35,11 +43,11 @@ from voluptuous_openapi import convert
 from homeassistant.components import conversation
 from homeassistant.config_entries import ConfigSubentry
+from homeassistant.core import HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import device_registry as dr, llm
 from homeassistant.helpers.entity import Entity
 
-from . import OpenAIConfigEntry
 from .const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
@@ -63,6 +71,10 @@ from .const import (
     RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
 )
 
+if TYPE_CHECKING:
+    from . import OpenAIConfigEntry
+
+
 # Max number of back and forth with the LLM to generate a response
 MAX_TOOL_ITERATIONS = 10
 
@@ -312,3 +324,50 @@ class OpenAIBaseLLMEntity(Entity):
 
             if not chat_log.unresponded_tool_results:
                 break
+
+
+async def async_prepare_files_for_prompt(
+    hass: HomeAssistant, files: list[Path]
+) -> ResponseInputMessageContentListParam:
+    """Append files to a prompt.
+
+    Caller needs to ensure that the files are allowed.
+    """
+
+    def append_files_to_content() -> ResponseInputMessageContentListParam:
+        content: ResponseInputMessageContentListParam = []
+
+        for file_path in files:
+            if not file_path.exists():
+                raise HomeAssistantError(f"`{file_path}` does not exist")
+
+            mime_type, _ = guess_file_type(file_path)
+
+            if not mime_type or not mime_type.startswith(("image/", "application/pdf")):
+                raise HomeAssistantError(
+                    "Only images and PDF are supported by the OpenAI API,"
+                    f"`{file_path}` is not an image file or PDF"
+                )
+
+            base64_file = base64.b64encode(file_path.read_bytes()).decode("utf-8")
+
+            if mime_type.startswith("image/"):
+                content.append(
+                    ResponseInputImageParam(
+                        type="input_image",
+                        image_url=f"data:{mime_type};base64,{base64_file}",
+                        detail="auto",
+                    )
+                )
+            elif mime_type.startswith("application/pdf"):
+                content.append(
+                    ResponseInputFileParam(
+                        type="input_file",
+                        filename=str(file_path),
+                        file_data=f"data:{mime_type};base64,{base64_file}",
+                    )
+                )
+
+        return content
+
+    return await hass.async_add_executor_job(append_files_to_content)
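
A short usage sketch of the new helper (the build_user_message wrapper is hypothetical; only async_prepare_files_for_prompt itself comes from this commit). It mirrors the call site in __init__.py: the blocking existence check, MIME detection, and base64 encoding run in an executor thread, and the helper returns input_image/input_file content parts ready to attach to a user message:

    from pathlib import Path

    from openai.types.responses import EasyInputMessageParam, ResponseInputTextParam

    from homeassistant.components.openai_conversation.entity import (
        async_prepare_files_for_prompt,
    )


    async def build_user_message(hass, prompt: str, filenames: list[str]):
        """Build a Responses API user message from a prompt plus local files."""
        content = [ResponseInputTextParam(type="input_text", text=prompt)]
        # Caller is responsible for allowlist checks; the helper only verifies
        # existence and MIME type before encoding the files.
        content.extend(
            await async_prepare_files_for_prompt(
                hass, [Path(filename) for filename in filenames]
            )
        )
        return EasyInputMessageParam(type="message", role="user", content=content)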

tests/components/openai_conversation/test_init.py

@@ -1,6 +1,6 @@
 """Tests for the OpenAI integration."""
 
-from unittest.mock import AsyncMock, mock_open, patch
+from unittest.mock import AsyncMock, Mock, mock_open, patch
 
 import httpx
 from openai import (
@@ -16,7 +16,7 @@ import pytest
 from syrupy.assertion import SnapshotAssertion
 from syrupy.filters import props
 
-from homeassistant.components.openai_conversation import CONF_CHAT_MODEL, CONF_FILENAMES
+from homeassistant.components.openai_conversation import CONF_CHAT_MODEL
 from homeassistant.components.openai_conversation.const import DOMAIN
 from homeassistant.config_entries import ConfigSubentryData
 from homeassistant.core import HomeAssistant
@@ -394,7 +394,7 @@ async def test_generate_content_service(
         patch(
             "base64.b64encode", side_effect=[b"BASE64IMAGE1", b"BASE64IMAGE2"]
         ) as mock_b64encode,
-        patch("builtins.open", mock_open(read_data="ABC")) as mock_file,
+        patch("pathlib.Path.read_bytes", Mock(return_value=b"ABC")) as mock_file,
         patch("pathlib.Path.exists", return_value=True),
         patch.object(hass.config, "is_allowed_path", return_value=True),
     ):
@@ -434,15 +434,13 @@ async def test_generate_content_service(
     assert len(mock_create.mock_calls) == 1
     assert mock_create.mock_calls[0][2] == expected_args
     assert mock_b64encode.call_count == number_of_files
-    for idx, file in enumerate(service_data[CONF_FILENAMES]):
-        assert mock_file.call_args_list[idx][0][0] == file
+    assert mock_file.call_count == number_of_files
 
 
 @pytest.mark.parametrize(
     (
         "service_data",
         "error",
-        "number_of_files",
         "exists_side_effect",
         "is_allowed_side_effect",
     ),
@@ -450,7 +448,6 @@
         (
             {"prompt": "Picture of a dog", "filenames": ["/a/b/c.jpg"]},
             "`/a/b/c.jpg` does not exist",
-            0,
             [False],
             [True],
         ),
@@ -460,14 +457,12 @@
                 "filenames": ["/a/b/c.jpg", "d/e/f.png"],
             },
             "Cannot read `d/e/f.png`, no access to path; `allowlist_external_dirs` may need to be adjusted in `configuration.yaml`",
-            1,
             [True, True],
             [True, False],
         ),
         (
             {"prompt": "Not a picture of a dog", "filenames": ["/a/b/c.mov"]},
             "Only images and PDF are supported by the OpenAI API,`/a/b/c.mov` is not an image file or PDF",
-            1,
             [True],
             [True],
         ),
@@ -479,7 +474,6 @@ async def test_generate_content_service_invalid(
     mock_init_component,
     service_data,
     error,
-    number_of_files,
     exists_side_effect,
     is_allowed_side_effect,
 ) -> None:
@@ -491,9 +485,7 @@ async def test_generate_content_service_invalid(
             "openai.resources.responses.AsyncResponses.create",
             new_callable=AsyncMock,
         ) as mock_create,
-        patch(
-            "base64.b64encode", side_effect=[b"BASE64IMAGE1", b"BASE64IMAGE2"]
-        ) as mock_b64encode,
+        patch("base64.b64encode", side_effect=[b"BASE64IMAGE1", b"BASE64IMAGE2"]),
         patch("builtins.open", mock_open(read_data="ABC")),
         patch("pathlib.Path.exists", side_effect=exists_side_effect),
         patch.object(
@@ -509,7 +501,6 @@ async def test_generate_content_service_invalid(
             return_response=True,
         )
     assert len(mock_create.mock_calls) == 0
-    assert mock_b64encode.call_count == number_of_files
 
 
 @pytest.mark.usefixtures("mock_init_component")
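
A standalone sketch of the mocking strategy the updated test relies on (the snippet itself is illustrative, not part of the test suite): because the helper now reads attachments via Path.read_bytes() instead of open(), the test patches pathlib.Path.read_bytes and asserts a call count rather than inspecting builtins.open call arguments:

    from pathlib import Path
    from unittest.mock import Mock, patch

    with (
        patch("pathlib.Path.read_bytes", Mock(return_value=b"ABC")) as mock_read,
        patch("pathlib.Path.exists", return_value=True),
    ):
        # Any Path instance now returns the canned bytes without touching disk.
        assert Path("/a/b/c.jpg").read_bytes() == b"ABC"
        assert Path("/a/b/c.jpg").exists()
        assert mock_read.call_count == 1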