mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 14:17:45 +00:00
OpenAI: Add attachment support to AI task (#148676)
This commit is contained in:
parent
23a8442abe
commit
611f86cf8c
@ -39,7 +39,10 @@ class OpenAITaskEntity(
|
||||
):
|
||||
"""OpenAI AI Task entity."""
|
||||
|
||||
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||
_attr_supported_features = (
|
||||
ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||
| ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
|
||||
)
|
||||
|
||||
async def _async_generate_data(
|
||||
self,
|
||||
|
@ -345,6 +345,26 @@ class OpenAIBaseLLMEntity(Entity):
|
||||
for content in chat_log.content
|
||||
for m in _convert_content_to_param(content)
|
||||
]
|
||||
|
||||
last_content = chat_log.content[-1]
|
||||
|
||||
# Handle attachments by adding them to the last user message
|
||||
if last_content.role == "user" and last_content.attachments:
|
||||
files = await async_prepare_files_for_prompt(
|
||||
self.hass,
|
||||
[a.path for a in last_content.attachments],
|
||||
)
|
||||
last_message = messages[-1]
|
||||
assert (
|
||||
last_message["type"] == "message"
|
||||
and last_message["role"] == "user"
|
||||
and isinstance(last_message["content"], str)
|
||||
)
|
||||
last_message["content"] = [
|
||||
{"type": "input_text", "text": last_message["content"]}, # type: ignore[list-item]
|
||||
*files, # type: ignore[list-item]
|
||||
]
|
||||
|
||||
if structure and structure_name:
|
||||
model_args["text"] = {
|
||||
"format": {
|
||||
|
@ -1,11 +1,12 @@
|
||||
"""Test AI Task platform of OpenAI Conversation integration."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import ai_task
|
||||
from homeassistant.components import ai_task, media_source
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import entity_registry as er, selector
|
||||
@ -122,3 +123,86 @@ async def test_generate_invalid_structured_data(
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_generate_data_with_attachments(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_create_stream: AsyncMock,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test AI Task data generation with attachments."""
|
||||
entity_id = "ai_task.openai_ai_task"
|
||||
|
||||
# Mock the OpenAI response stream
|
||||
mock_create_stream.return_value = [
|
||||
create_message_item(id="msg_A", text="Hi there!", output_index=0)
|
||||
]
|
||||
|
||||
# Test with attachments
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.media_source.async_resolve_media",
|
||||
side_effect=[
|
||||
media_source.PlayMedia(
|
||||
url="http://example.com/doorbell_snapshot.jpg",
|
||||
mime_type="image/jpeg",
|
||||
path=Path("doorbell_snapshot.jpg"),
|
||||
),
|
||||
media_source.PlayMedia(
|
||||
url="http://example.com/context.txt",
|
||||
mime_type="text/plain",
|
||||
path=Path("context.txt"),
|
||||
),
|
||||
],
|
||||
),
|
||||
patch("pathlib.Path.exists", return_value=True),
|
||||
# patch.object(hass.config, "is_allowed_path", return_value=True),
|
||||
patch(
|
||||
"homeassistant.components.openai_conversation.entity.guess_file_type",
|
||||
return_value=("image/jpeg", None),
|
||||
),
|
||||
patch("pathlib.Path.read_bytes", return_value=b"fake_image_data"),
|
||||
):
|
||||
result = await ai_task.async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=entity_id,
|
||||
instructions="Test prompt",
|
||||
attachments=[
|
||||
{"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
|
||||
{"media_content_id": "media-source://media/context.txt"},
|
||||
],
|
||||
)
|
||||
|
||||
assert result.data == "Hi there!"
|
||||
|
||||
# Verify that the create stream was called with the correct parameters
|
||||
# The last call should have the user message with attachments
|
||||
call_args = mock_create_stream.call_args
|
||||
assert call_args is not None
|
||||
|
||||
# Check that the input includes the attachments
|
||||
input_messages = call_args[1]["input"]
|
||||
assert len(input_messages) > 0
|
||||
|
||||
# Find the user message with attachments
|
||||
user_message_with_attachments = input_messages[-2]
|
||||
|
||||
assert user_message_with_attachments is not None
|
||||
assert isinstance(user_message_with_attachments["content"], list)
|
||||
assert len(user_message_with_attachments["content"]) == 3 # Text + attachments
|
||||
assert user_message_with_attachments["content"] == [
|
||||
{"type": "input_text", "text": "Test prompt"},
|
||||
{
|
||||
"detail": "auto",
|
||||
"image_url": "data:image/jpeg;base64,ZmFrZV9pbWFnZV9kYXRh",
|
||||
"type": "input_image",
|
||||
},
|
||||
{
|
||||
"detail": "auto",
|
||||
"image_url": "data:image/jpeg;base64,ZmFrZV9pbWFnZV9kYXRh",
|
||||
"type": "input_image",
|
||||
},
|
||||
]
|
||||
|
Loading…
x
Reference in New Issue
Block a user