mirror of
https://github.com/home-assistant/core.git
synced 2025-07-27 15:17:35 +00:00
OpenAI: Add attachment support to AI task (#148676)
This commit is contained in:
parent
23a8442abe
commit
611f86cf8c
@ -39,7 +39,10 @@ class OpenAITaskEntity(
|
|||||||
):
|
):
|
||||||
"""OpenAI AI Task entity."""
|
"""OpenAI AI Task entity."""
|
||||||
|
|
||||||
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
|
_attr_supported_features = (
|
||||||
|
ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||||
|
| ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
|
||||||
|
)
|
||||||
|
|
||||||
async def _async_generate_data(
|
async def _async_generate_data(
|
||||||
self,
|
self,
|
||||||
|
@ -345,6 +345,26 @@ class OpenAIBaseLLMEntity(Entity):
|
|||||||
for content in chat_log.content
|
for content in chat_log.content
|
||||||
for m in _convert_content_to_param(content)
|
for m in _convert_content_to_param(content)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
last_content = chat_log.content[-1]
|
||||||
|
|
||||||
|
# Handle attachments by adding them to the last user message
|
||||||
|
if last_content.role == "user" and last_content.attachments:
|
||||||
|
files = await async_prepare_files_for_prompt(
|
||||||
|
self.hass,
|
||||||
|
[a.path for a in last_content.attachments],
|
||||||
|
)
|
||||||
|
last_message = messages[-1]
|
||||||
|
assert (
|
||||||
|
last_message["type"] == "message"
|
||||||
|
and last_message["role"] == "user"
|
||||||
|
and isinstance(last_message["content"], str)
|
||||||
|
)
|
||||||
|
last_message["content"] = [
|
||||||
|
{"type": "input_text", "text": last_message["content"]}, # type: ignore[list-item]
|
||||||
|
*files, # type: ignore[list-item]
|
||||||
|
]
|
||||||
|
|
||||||
if structure and structure_name:
|
if structure and structure_name:
|
||||||
model_args["text"] = {
|
model_args["text"] = {
|
||||||
"format": {
|
"format": {
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
"""Test AI Task platform of OpenAI Conversation integration."""
|
"""Test AI Task platform of OpenAI Conversation integration."""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import ai_task
|
from homeassistant.components import ai_task, media_source
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import entity_registry as er, selector
|
from homeassistant.helpers import entity_registry as er, selector
|
||||||
@ -122,3 +123,86 @@ async def test_generate_invalid_structured_data(
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
|
async def test_generate_data_with_attachments(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_create_stream: AsyncMock,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test AI Task data generation with attachments."""
|
||||||
|
entity_id = "ai_task.openai_ai_task"
|
||||||
|
|
||||||
|
# Mock the OpenAI response stream
|
||||||
|
mock_create_stream.return_value = [
|
||||||
|
create_message_item(id="msg_A", text="Hi there!", output_index=0)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test with attachments
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.media_source.async_resolve_media",
|
||||||
|
side_effect=[
|
||||||
|
media_source.PlayMedia(
|
||||||
|
url="http://example.com/doorbell_snapshot.jpg",
|
||||||
|
mime_type="image/jpeg",
|
||||||
|
path=Path("doorbell_snapshot.jpg"),
|
||||||
|
),
|
||||||
|
media_source.PlayMedia(
|
||||||
|
url="http://example.com/context.txt",
|
||||||
|
mime_type="text/plain",
|
||||||
|
path=Path("context.txt"),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
patch("pathlib.Path.exists", return_value=True),
|
||||||
|
# patch.object(hass.config, "is_allowed_path", return_value=True),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.openai_conversation.entity.guess_file_type",
|
||||||
|
return_value=("image/jpeg", None),
|
||||||
|
),
|
||||||
|
patch("pathlib.Path.read_bytes", return_value=b"fake_image_data"),
|
||||||
|
):
|
||||||
|
result = await ai_task.async_generate_data(
|
||||||
|
hass,
|
||||||
|
task_name="Test Task",
|
||||||
|
entity_id=entity_id,
|
||||||
|
instructions="Test prompt",
|
||||||
|
attachments=[
|
||||||
|
{"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
|
||||||
|
{"media_content_id": "media-source://media/context.txt"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.data == "Hi there!"
|
||||||
|
|
||||||
|
# Verify that the create stream was called with the correct parameters
|
||||||
|
# The last call should have the user message with attachments
|
||||||
|
call_args = mock_create_stream.call_args
|
||||||
|
assert call_args is not None
|
||||||
|
|
||||||
|
# Check that the input includes the attachments
|
||||||
|
input_messages = call_args[1]["input"]
|
||||||
|
assert len(input_messages) > 0
|
||||||
|
|
||||||
|
# Find the user message with attachments
|
||||||
|
user_message_with_attachments = input_messages[-2]
|
||||||
|
|
||||||
|
assert user_message_with_attachments is not None
|
||||||
|
assert isinstance(user_message_with_attachments["content"], list)
|
||||||
|
assert len(user_message_with_attachments["content"]) == 3 # Text + attachments
|
||||||
|
assert user_message_with_attachments["content"] == [
|
||||||
|
{"type": "input_text", "text": "Test prompt"},
|
||||||
|
{
|
||||||
|
"detail": "auto",
|
||||||
|
"image_url": "data:image/jpeg;base64,ZmFrZV9pbWFnZV9kYXRh",
|
||||||
|
"type": "input_image",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"detail": "auto",
|
||||||
|
"image_url": "data:image/jpeg;base64,ZmFrZV9pbWFnZV9kYXRh",
|
||||||
|
"type": "input_image",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user