Mirror of https://github.com/home-assistant/core.git (synced 2025-07-28 15:47:12 +00:00)

OpenAI: Extract file attachment logic (#148288)

commit 6d0891e970
parent 73730e3eb3
homeassistant/components/openai_conversation/__init__.py

@@ -2,8 +2,6 @@
 
 from __future__ import annotations
 
-import base64
-from mimetypes import guess_file_type
 from pathlib import Path
 
 import openai
@@ -11,8 +9,6 @@ from openai.types.images_response import ImagesResponse
 from openai.types.responses import (
     EasyInputMessageParam,
     Response,
-    ResponseInputFileParam,
-    ResponseInputImageParam,
     ResponseInputMessageContentListParam,
     ResponseInputParam,
     ResponseInputTextParam,
@@ -58,6 +54,7 @@ from .const import (
     RECOMMENDED_TEMPERATURE,
     RECOMMENDED_TOP_P,
 )
+from .entity import async_prepare_files_for_prompt
 
 SERVICE_GENERATE_IMAGE = "generate_image"
 SERVICE_GENERATE_CONTENT = "generate_content"
@@ -68,15 +65,6 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
 type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient]
 
 
-def encode_file(file_path: str) -> tuple[str, str]:
-    """Return base64 version of file contents."""
-    mime_type, _ = guess_file_type(file_path)
-    if mime_type is None:
-        mime_type = "application/octet-stream"
-    with open(file_path, "rb") as image_file:
-        return (mime_type, base64.b64encode(image_file.read()).decode("utf-8"))
-
-
 async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
     """Set up OpenAI Conversation."""
     await async_migrate_integration(hass)
@@ -146,41 +134,20 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
             ResponseInputTextParam(type="input_text", text=call.data[CONF_PROMPT])
         ]
 
-        def append_files_to_content() -> None:
-            for filename in call.data[CONF_FILENAMES]:
+        if filenames := call.data.get(CONF_FILENAMES):
+            for filename in filenames:
                 if not hass.config.is_allowed_path(filename):
                     raise HomeAssistantError(
                         f"Cannot read `{filename}`, no access to path; "
                         "`allowlist_external_dirs` may need to be adjusted in "
                         "`configuration.yaml`"
                     )
-                if not Path(filename).exists():
-                    raise HomeAssistantError(f"`{filename}` does not exist")
-                mime_type, base64_file = encode_file(filename)
-                if "image/" in mime_type:
-                    content.append(
-                        ResponseInputImageParam(
-                            type="input_image",
-                            image_url=f"data:{mime_type};base64,{base64_file}",
-                            detail="auto",
-                        )
-                    )
-                elif "application/pdf" in mime_type:
-                    content.append(
-                        ResponseInputFileParam(
-                            type="input_file",
-                            filename=filename,
-                            file_data=f"data:{mime_type};base64,{base64_file}",
-                        )
-                    )
-                else:
-                    raise HomeAssistantError(
-                        "Only images and PDF are supported by the OpenAI API,"
-                        f"`{filename}` is not an image file or PDF"
-                    )
 
-        if CONF_FILENAMES in call.data:
-            await hass.async_add_executor_job(append_files_to_content)
+            content.extend(
+                await async_prepare_files_for_prompt(
+                    hass, [Path(filename) for filename in filenames]
+                )
+            )
 
         messages: ResponseInputParam = [
             EasyInputMessageParam(type="message", role="user", content=content)
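For context on how this code path is exercised, here is a hedged sketch of calling the generate_content service with an attachment. The domain, service name, and data keys come from this diff and the tests further down; the concrete file path is purely illustrative and must sit inside allowlist_external_dirs.

# Hedged usage sketch (not part of this commit). Assumes an async context with
# a running `hass` instance; the file path is illustrative only.
result = await hass.services.async_call(
    "openai_conversation",  # integration domain
    "generate_content",     # SERVICE_GENERATE_CONTENT
    {
        "prompt": "Describe the attached file",
        "filenames": ["/config/www/photo.jpg"],  # must be an allowed path
    },
    blocking=True,
    return_response=True,
)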
homeassistant/components/openai_conversation/entity.py

@@ -1,8 +1,13 @@
 """Base entity for OpenAI."""
 
+from __future__ import annotations
+
+import base64
 from collections.abc import AsyncGenerator, Callable
 import json
-from typing import Any, Literal, cast
+from mimetypes import guess_file_type
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Literal, cast
 
 import openai
 from openai._streaming import AsyncStream
@@ -17,6 +22,9 @@ from openai.types.responses import (
     ResponseFunctionToolCall,
     ResponseFunctionToolCallParam,
     ResponseIncompleteEvent,
+    ResponseInputFileParam,
+    ResponseInputImageParam,
+    ResponseInputMessageContentListParam,
     ResponseInputParam,
     ResponseOutputItemAddedEvent,
     ResponseOutputItemDoneEvent,
@@ -35,11 +43,11 @@ from voluptuous_openapi import convert
 
 from homeassistant.components import conversation
 from homeassistant.config_entries import ConfigSubentry
+from homeassistant.core import HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import device_registry as dr, llm
 from homeassistant.helpers.entity import Entity
 
-from . import OpenAIConfigEntry
 from .const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
@@ -63,6 +71,10 @@ from .const import (
     RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
 )
 
+if TYPE_CHECKING:
+    from . import OpenAIConfigEntry
+
+
 # Max number of back and forth with the LLM to generate a response
 MAX_TOOL_ITERATIONS = 10
 
@@ -312,3 +324,50 @@ class OpenAIBaseLLMEntity(Entity):
 
             if not chat_log.unresponded_tool_results:
                 break
+
+
+async def async_prepare_files_for_prompt(
+    hass: HomeAssistant, files: list[Path]
+) -> ResponseInputMessageContentListParam:
+    """Append files to a prompt.
+
+    Caller needs to ensure that the files are allowed.
+    """
+
+    def append_files_to_content() -> ResponseInputMessageContentListParam:
+        content: ResponseInputMessageContentListParam = []
+
+        for file_path in files:
+            if not file_path.exists():
+                raise HomeAssistantError(f"`{file_path}` does not exist")
+
+            mime_type, _ = guess_file_type(file_path)
+
+            if not mime_type or not mime_type.startswith(("image/", "application/pdf")):
+                raise HomeAssistantError(
+                    "Only images and PDF are supported by the OpenAI API,"
+                    f"`{file_path}` is not an image file or PDF"
+                )
+
+            base64_file = base64.b64encode(file_path.read_bytes()).decode("utf-8")
+
+            if mime_type.startswith("image/"):
+                content.append(
+                    ResponseInputImageParam(
+                        type="input_image",
+                        image_url=f"data:{mime_type};base64,{base64_file}",
+                        detail="auto",
+                    )
+                )
+            elif mime_type.startswith("application/pdf"):
+                content.append(
+                    ResponseInputFileParam(
+                        type="input_file",
+                        filename=str(file_path),
+                        file_data=f"data:{mime_type};base64,{base64_file}",
+                    )
+                )
+
+        return content
+
+    return await hass.async_add_executor_job(append_files_to_content)
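A minimal sketch of calling the new helper directly, mirroring what the refactored service handler does. The import path and file names are assumptions for illustration, and the caller must already have verified the paths with hass.config.is_allowed_path(), as the docstring requires.

# Illustrative only: paths are hypothetical and the caller is responsible for
# the allowlist check; HomeAssistantError is raised for missing files or for
# MIME types other than image/* and application/pdf.
from pathlib import Path

from openai.types.responses import ResponseInputTextParam

from homeassistant.components.openai_conversation.entity import (
    async_prepare_files_for_prompt,
)
from homeassistant.core import HomeAssistant


async def build_user_content(hass: HomeAssistant, prompt: str):
    """Build message content with a text prompt plus two attachments."""
    content = [ResponseInputTextParam(type="input_text", text=prompt)]
    # The helper runs the blocking read/encode work in an executor and returns
    # input_image / input_file parts ready to extend the message content.
    content.extend(
        await async_prepare_files_for_prompt(
            hass, [Path("/config/www/floorplan.png"), Path("/config/www/manual.pdf")]
        )
    )
    return content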
tests/components/openai_conversation/test_init.py

@@ -1,6 +1,6 @@
 """Tests for the OpenAI integration."""
 
-from unittest.mock import AsyncMock, mock_open, patch
+from unittest.mock import AsyncMock, Mock, mock_open, patch
 
 import httpx
 from openai import (
@@ -16,7 +16,7 @@ import pytest
 from syrupy.assertion import SnapshotAssertion
 from syrupy.filters import props
 
-from homeassistant.components.openai_conversation import CONF_CHAT_MODEL, CONF_FILENAMES
+from homeassistant.components.openai_conversation import CONF_CHAT_MODEL
 from homeassistant.components.openai_conversation.const import DOMAIN
 from homeassistant.config_entries import ConfigSubentryData
 from homeassistant.core import HomeAssistant
@@ -394,7 +394,7 @@ async def test_generate_content_service(
         patch(
             "base64.b64encode", side_effect=[b"BASE64IMAGE1", b"BASE64IMAGE2"]
         ) as mock_b64encode,
-        patch("builtins.open", mock_open(read_data="ABC")) as mock_file,
+        patch("pathlib.Path.read_bytes", Mock(return_value=b"ABC")) as mock_file,
         patch("pathlib.Path.exists", return_value=True),
         patch.object(hass.config, "is_allowed_path", return_value=True),
     ):
@@ -434,15 +434,13 @@ async def test_generate_content_service(
     assert len(mock_create.mock_calls) == 1
     assert mock_create.mock_calls[0][2] == expected_args
     assert mock_b64encode.call_count == number_of_files
-    for idx, file in enumerate(service_data[CONF_FILENAMES]):
-        assert mock_file.call_args_list[idx][0][0] == file
+    assert mock_file.call_count == number_of_files
 
 
 @pytest.mark.parametrize(
     (
         "service_data",
         "error",
-        "number_of_files",
         "exists_side_effect",
         "is_allowed_side_effect",
     ),
@@ -450,7 +448,6 @@ async def test_generate_content_service(
         (
             {"prompt": "Picture of a dog", "filenames": ["/a/b/c.jpg"]},
             "`/a/b/c.jpg` does not exist",
-            0,
             [False],
             [True],
         ),
@@ -460,14 +457,12 @@
                 "filenames": ["/a/b/c.jpg", "d/e/f.png"],
             },
             "Cannot read `d/e/f.png`, no access to path; `allowlist_external_dirs` may need to be adjusted in `configuration.yaml`",
-            1,
             [True, True],
             [True, False],
         ),
         (
             {"prompt": "Not a picture of a dog", "filenames": ["/a/b/c.mov"]},
             "Only images and PDF are supported by the OpenAI API,`/a/b/c.mov` is not an image file or PDF",
-            1,
             [True],
             [True],
         ),
@@ -479,7 +474,6 @@ async def test_generate_content_service_invalid(
     mock_init_component,
     service_data,
     error,
-    number_of_files,
    exists_side_effect,
    is_allowed_side_effect,
 ) -> None:
@@ -491,9 +485,7 @@ async def test_generate_content_service_invalid(
             "openai.resources.responses.AsyncResponses.create",
             new_callable=AsyncMock,
         ) as mock_create,
-        patch(
-            "base64.b64encode", side_effect=[b"BASE64IMAGE1", b"BASE64IMAGE2"]
-        ) as mock_b64encode,
+        patch("base64.b64encode", side_effect=[b"BASE64IMAGE1", b"BASE64IMAGE2"]),
         patch("builtins.open", mock_open(read_data="ABC")),
         patch("pathlib.Path.exists", side_effect=exists_side_effect),
         patch.object(
@@ -509,7 +501,6 @@ async def test_generate_content_service_invalid(
             return_response=True,
         )
     assert len(mock_create.mock_calls) == 0
-    assert mock_b64encode.call_count == number_of_files
 
 
 @pytest.mark.usefixtures("mock_init_component")
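The tests now stub pathlib.Path.read_bytes rather than builtins.open, since the helper reads attachments through Path.read_bytes(). A standalone sketch of that patching pattern, kept outside the real test fixtures:

# Hedged illustration of the mocking change above; not taken verbatim from the
# test suite.
from pathlib import Path
from unittest.mock import Mock, patch

with (
    patch("pathlib.Path.read_bytes", Mock(return_value=b"ABC")) as mock_read,
    patch("pathlib.Path.exists", return_value=True),
):
    # Any Path instance now reports it exists and returns the fake bytes.
    assert Path("/a/b/c.jpg").read_bytes() == b"ABC"

assert mock_read.call_count == 1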