Add OpenAI AI Task entity (#148295)

This commit is contained in:
Paulus Schoutsen
2025-07-10 23:08:56 +02:00
committed by GitHub
parent f0a636949a
commit 0e09a47476
14 changed files with 1152 additions and 463 deletions

View File

@@ -1,41 +1,15 @@
"""Tests for the OpenAI integration."""
from collections.abc import Generator
from unittest.mock import AsyncMock, patch
import httpx
from openai import AuthenticationError, RateLimitError
from openai.types import ResponseFormatText
from openai.types.responses import (
Response,
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseError,
ResponseErrorEvent,
ResponseFailedEvent,
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseFunctionWebSearch,
ResponseIncompleteEvent,
ResponseInProgressEvent,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
ResponseOutputText,
ResponseReasoningItem,
ResponseStreamEvent,
ResponseTextConfig,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent,
ResponseWebSearchCallSearchingEvent,
)
from openai.types.responses.response import IncompleteDetails
from openai.types.responses.response_function_web_search import ActionSearch
import pytest
from syrupy.assertion import SnapshotAssertion
@@ -55,6 +29,13 @@ from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component
from . import (
create_function_tool_call_item,
create_message_item,
create_reasoning_item,
create_web_search_item,
)
from tests.common import MockConfigEntry
from tests.components.conversation import (
MockChatLog,
@@ -62,97 +43,6 @@ from tests.components.conversation import (
)
@pytest.fixture
def mock_create_stream() -> Generator[AsyncMock]:
"""Mock stream response."""
async def mock_generator(events, **kwargs):
response = Response(
id="resp_A",
created_at=1700000000,
error=None,
incomplete_details=None,
instructions=kwargs.get("instructions"),
metadata=kwargs.get("metadata", {}),
model=kwargs.get("model", "gpt-4o-mini"),
object="response",
output=[],
parallel_tool_calls=kwargs.get("parallel_tool_calls", True),
temperature=kwargs.get("temperature", 1.0),
tool_choice=kwargs.get("tool_choice", "auto"),
tools=kwargs.get("tools"),
top_p=kwargs.get("top_p", 1.0),
max_output_tokens=kwargs.get("max_output_tokens", 100000),
previous_response_id=kwargs.get("previous_response_id"),
reasoning=kwargs.get("reasoning"),
status="in_progress",
text=kwargs.get(
"text", ResponseTextConfig(format=ResponseFormatText(type="text"))
),
truncation=kwargs.get("truncation", "disabled"),
usage=None,
user=kwargs.get("user"),
store=kwargs.get("store", True),
)
yield ResponseCreatedEvent(
response=response,
sequence_number=0,
type="response.created",
)
yield ResponseInProgressEvent(
response=response,
sequence_number=0,
type="response.in_progress",
)
response.status = "completed"
for value in events:
if isinstance(value, ResponseOutputItemDoneEvent):
response.output.append(value.item)
elif isinstance(value, IncompleteDetails):
response.status = "incomplete"
response.incomplete_details = value
break
if isinstance(value, ResponseError):
response.status = "failed"
response.error = value
break
yield value
if isinstance(value, ResponseErrorEvent):
return
if response.status == "incomplete":
yield ResponseIncompleteEvent(
response=response,
sequence_number=0,
type="response.incomplete",
)
elif response.status == "failed":
yield ResponseFailedEvent(
response=response,
sequence_number=0,
type="response.failed",
)
else:
yield ResponseCompletedEvent(
response=response,
sequence_number=0,
type="response.completed",
)
with patch(
"openai.resources.responses.AsyncResponses.create",
AsyncMock(),
) as mock_create:
mock_create.side_effect = lambda **kwargs: mock_generator(
mock_create.return_value.pop(0), **kwargs
)
yield mock_create
async def test_entity(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
@@ -347,225 +237,6 @@ async def test_conversation_agent(
assert agent.supported_languages == "*"
def create_message_item(
id: str, text: str | list[str], output_index: int
) -> list[ResponseStreamEvent]:
"""Create a message item."""
if isinstance(text, str):
text = [text]
content = ResponseOutputText(annotations=[], text="", type="output_text")
events = [
ResponseOutputItemAddedEvent(
item=ResponseOutputMessage(
id=id,
content=[],
type="message",
role="assistant",
status="in_progress",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
),
ResponseContentPartAddedEvent(
content_index=0,
item_id=id,
output_index=output_index,
part=content,
sequence_number=0,
type="response.content_part.added",
),
]
content.text = "".join(text)
events.extend(
ResponseTextDeltaEvent(
content_index=0,
delta=delta,
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.output_text.delta",
)
for delta in text
)
events.extend(
[
ResponseTextDoneEvent(
content_index=0,
item_id=id,
output_index=output_index,
text="".join(text),
sequence_number=0,
type="response.output_text.done",
),
ResponseContentPartDoneEvent(
content_index=0,
item_id=id,
output_index=output_index,
part=content,
sequence_number=0,
type="response.content_part.done",
),
ResponseOutputItemDoneEvent(
item=ResponseOutputMessage(
id=id,
content=[content],
role="assistant",
status="completed",
type="message",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
),
]
)
return events
def create_function_tool_call_item(
id: str, arguments: str | list[str], call_id: str, name: str, output_index: int
) -> list[ResponseStreamEvent]:
"""Create a function tool call item."""
if isinstance(arguments, str):
arguments = [arguments]
events = [
ResponseOutputItemAddedEvent(
item=ResponseFunctionToolCall(
id=id,
arguments="",
call_id=call_id,
name=name,
type="function_call",
status="in_progress",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
)
]
events.extend(
ResponseFunctionCallArgumentsDeltaEvent(
delta=delta,
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.function_call_arguments.delta",
)
for delta in arguments
)
events.append(
ResponseFunctionCallArgumentsDoneEvent(
arguments="".join(arguments),
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.function_call_arguments.done",
)
)
events.append(
ResponseOutputItemDoneEvent(
item=ResponseFunctionToolCall(
id=id,
arguments="".join(arguments),
call_id=call_id,
name=name,
type="function_call",
status="completed",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
)
)
return events
def create_reasoning_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
"""Create a reasoning item."""
return [
ResponseOutputItemAddedEvent(
item=ResponseReasoningItem(
id=id,
summary=[],
type="reasoning",
status=None,
encrypted_content="AAA",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
),
ResponseOutputItemDoneEvent(
item=ResponseReasoningItem(
id=id,
summary=[],
type="reasoning",
status=None,
encrypted_content="AAABBB",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
),
]
def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
"""Create a web search call item."""
return [
ResponseOutputItemAddedEvent(
item=ResponseFunctionWebSearch(
id=id,
status="in_progress",
action=ActionSearch(query="query", type="search"),
type="web_search_call",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
),
ResponseWebSearchCallInProgressEvent(
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.web_search_call.in_progress",
),
ResponseWebSearchCallSearchingEvent(
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.web_search_call.searching",
),
ResponseWebSearchCallCompletedEvent(
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.web_search_call.completed",
),
ResponseOutputItemDoneEvent(
item=ResponseFunctionWebSearch(
id=id,
status="completed",
action=ActionSearch(query="query", type="search"),
type="web_search_call",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
),
]
async def test_function_call(
hass: HomeAssistant,
mock_config_entry_with_reasoning_model: MockConfigEntry,