mirror of
https://github.com/home-assistant/core.git
synced 2025-11-12 20:40:18 +00:00
Add OpenAI AI Task entity (#148295)
This commit is contained in:
@@ -1,41 +1,15 @@
|
||||
"""Tests for the OpenAI integration."""
|
||||
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import httpx
|
||||
from openai import AuthenticationError, RateLimitError
|
||||
from openai.types import ResponseFormatText
|
||||
from openai.types.responses import (
|
||||
Response,
|
||||
ResponseCompletedEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseError,
|
||||
ResponseErrorEvent,
|
||||
ResponseFailedEvent,
|
||||
ResponseFunctionCallArgumentsDeltaEvent,
|
||||
ResponseFunctionCallArgumentsDoneEvent,
|
||||
ResponseFunctionToolCall,
|
||||
ResponseFunctionWebSearch,
|
||||
ResponseIncompleteEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseOutputMessage,
|
||||
ResponseOutputText,
|
||||
ResponseReasoningItem,
|
||||
ResponseStreamEvent,
|
||||
ResponseTextConfig,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
ResponseWebSearchCallCompletedEvent,
|
||||
ResponseWebSearchCallInProgressEvent,
|
||||
ResponseWebSearchCallSearchingEvent,
|
||||
)
|
||||
from openai.types.responses.response import IncompleteDetails
|
||||
from openai.types.responses.response_function_web_search import ActionSearch
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
@@ -55,6 +29,13 @@ from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import (
|
||||
create_function_tool_call_item,
|
||||
create_message_item,
|
||||
create_reasoning_item,
|
||||
create_web_search_item,
|
||||
)
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
from tests.components.conversation import (
|
||||
MockChatLog,
|
||||
@@ -62,97 +43,6 @@ from tests.components.conversation import (
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_create_stream() -> Generator[AsyncMock]:
|
||||
"""Mock stream response."""
|
||||
|
||||
async def mock_generator(events, **kwargs):
|
||||
response = Response(
|
||||
id="resp_A",
|
||||
created_at=1700000000,
|
||||
error=None,
|
||||
incomplete_details=None,
|
||||
instructions=kwargs.get("instructions"),
|
||||
metadata=kwargs.get("metadata", {}),
|
||||
model=kwargs.get("model", "gpt-4o-mini"),
|
||||
object="response",
|
||||
output=[],
|
||||
parallel_tool_calls=kwargs.get("parallel_tool_calls", True),
|
||||
temperature=kwargs.get("temperature", 1.0),
|
||||
tool_choice=kwargs.get("tool_choice", "auto"),
|
||||
tools=kwargs.get("tools"),
|
||||
top_p=kwargs.get("top_p", 1.0),
|
||||
max_output_tokens=kwargs.get("max_output_tokens", 100000),
|
||||
previous_response_id=kwargs.get("previous_response_id"),
|
||||
reasoning=kwargs.get("reasoning"),
|
||||
status="in_progress",
|
||||
text=kwargs.get(
|
||||
"text", ResponseTextConfig(format=ResponseFormatText(type="text"))
|
||||
),
|
||||
truncation=kwargs.get("truncation", "disabled"),
|
||||
usage=None,
|
||||
user=kwargs.get("user"),
|
||||
store=kwargs.get("store", True),
|
||||
)
|
||||
yield ResponseCreatedEvent(
|
||||
response=response,
|
||||
sequence_number=0,
|
||||
type="response.created",
|
||||
)
|
||||
yield ResponseInProgressEvent(
|
||||
response=response,
|
||||
sequence_number=0,
|
||||
type="response.in_progress",
|
||||
)
|
||||
response.status = "completed"
|
||||
|
||||
for value in events:
|
||||
if isinstance(value, ResponseOutputItemDoneEvent):
|
||||
response.output.append(value.item)
|
||||
elif isinstance(value, IncompleteDetails):
|
||||
response.status = "incomplete"
|
||||
response.incomplete_details = value
|
||||
break
|
||||
if isinstance(value, ResponseError):
|
||||
response.status = "failed"
|
||||
response.error = value
|
||||
break
|
||||
|
||||
yield value
|
||||
|
||||
if isinstance(value, ResponseErrorEvent):
|
||||
return
|
||||
|
||||
if response.status == "incomplete":
|
||||
yield ResponseIncompleteEvent(
|
||||
response=response,
|
||||
sequence_number=0,
|
||||
type="response.incomplete",
|
||||
)
|
||||
elif response.status == "failed":
|
||||
yield ResponseFailedEvent(
|
||||
response=response,
|
||||
sequence_number=0,
|
||||
type="response.failed",
|
||||
)
|
||||
else:
|
||||
yield ResponseCompletedEvent(
|
||||
response=response,
|
||||
sequence_number=0,
|
||||
type="response.completed",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"openai.resources.responses.AsyncResponses.create",
|
||||
AsyncMock(),
|
||||
) as mock_create:
|
||||
mock_create.side_effect = lambda **kwargs: mock_generator(
|
||||
mock_create.return_value.pop(0), **kwargs
|
||||
)
|
||||
|
||||
yield mock_create
|
||||
|
||||
|
||||
async def test_entity(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
@@ -347,225 +237,6 @@ async def test_conversation_agent(
|
||||
assert agent.supported_languages == "*"
|
||||
|
||||
|
||||
def create_message_item(
|
||||
id: str, text: str | list[str], output_index: int
|
||||
) -> list[ResponseStreamEvent]:
|
||||
"""Create a message item."""
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
|
||||
content = ResponseOutputText(annotations=[], text="", type="output_text")
|
||||
events = [
|
||||
ResponseOutputItemAddedEvent(
|
||||
item=ResponseOutputMessage(
|
||||
id=id,
|
||||
content=[],
|
||||
type="message",
|
||||
role="assistant",
|
||||
status="in_progress",
|
||||
),
|
||||
output_index=output_index,
|
||||
sequence_number=0,
|
||||
type="response.output_item.added",
|
||||
),
|
||||
ResponseContentPartAddedEvent(
|
||||
content_index=0,
|
||||
item_id=id,
|
||||
output_index=output_index,
|
||||
part=content,
|
||||
sequence_number=0,
|
||||
type="response.content_part.added",
|
||||
),
|
||||
]
|
||||
|
||||
content.text = "".join(text)
|
||||
events.extend(
|
||||
ResponseTextDeltaEvent(
|
||||
content_index=0,
|
||||
delta=delta,
|
||||
item_id=id,
|
||||
output_index=output_index,
|
||||
sequence_number=0,
|
||||
type="response.output_text.delta",
|
||||
)
|
||||
for delta in text
|
||||
)
|
||||
|
||||
events.extend(
|
||||
[
|
||||
ResponseTextDoneEvent(
|
||||
content_index=0,
|
||||
item_id=id,
|
||||
output_index=output_index,
|
||||
text="".join(text),
|
||||
sequence_number=0,
|
||||
type="response.output_text.done",
|
||||
),
|
||||
ResponseContentPartDoneEvent(
|
||||
content_index=0,
|
||||
item_id=id,
|
||||
output_index=output_index,
|
||||
part=content,
|
||||
sequence_number=0,
|
||||
type="response.content_part.done",
|
||||
),
|
||||
ResponseOutputItemDoneEvent(
|
||||
item=ResponseOutputMessage(
|
||||
id=id,
|
||||
content=[content],
|
||||
role="assistant",
|
||||
status="completed",
|
||||
type="message",
|
||||
),
|
||||
output_index=output_index,
|
||||
sequence_number=0,
|
||||
type="response.output_item.done",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def create_function_tool_call_item(
|
||||
id: str, arguments: str | list[str], call_id: str, name: str, output_index: int
|
||||
) -> list[ResponseStreamEvent]:
|
||||
"""Create a function tool call item."""
|
||||
if isinstance(arguments, str):
|
||||
arguments = [arguments]
|
||||
|
||||
events = [
|
||||
ResponseOutputItemAddedEvent(
|
||||
item=ResponseFunctionToolCall(
|
||||
id=id,
|
||||
arguments="",
|
||||
call_id=call_id,
|
||||
name=name,
|
||||
type="function_call",
|
||||
status="in_progress",
|
||||
),
|
||||
output_index=output_index,
|
||||
sequence_number=0,
|
||||
type="response.output_item.added",
|
||||
)
|
||||
]
|
||||
|
||||
events.extend(
|
||||
ResponseFunctionCallArgumentsDeltaEvent(
|
||||
delta=delta,
|
||||
item_id=id,
|
||||
output_index=output_index,
|
||||
sequence_number=0,
|
||||
type="response.function_call_arguments.delta",
|
||||
)
|
||||
for delta in arguments
|
||||
)
|
||||
|
||||
events.append(
|
||||
ResponseFunctionCallArgumentsDoneEvent(
|
||||
arguments="".join(arguments),
|
||||
item_id=id,
|
||||
output_index=output_index,
|
||||
sequence_number=0,
|
||||
type="response.function_call_arguments.done",
|
||||
)
|
||||
)
|
||||
|
||||
events.append(
|
||||
ResponseOutputItemDoneEvent(
|
||||
item=ResponseFunctionToolCall(
|
||||
id=id,
|
||||
arguments="".join(arguments),
|
||||
call_id=call_id,
|
||||
name=name,
|
||||
type="function_call",
|
||||
status="completed",
|
||||
),
|
||||
output_index=output_index,
|
||||
sequence_number=0,
|
||||
type="response.output_item.done",
|
||||
)
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def create_reasoning_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
|
||||
"""Create a reasoning item."""
|
||||
return [
|
||||
ResponseOutputItemAddedEvent(
|
||||
item=ResponseReasoningItem(
|
||||
id=id,
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
status=None,
|
||||
encrypted_content="AAA",
|
||||
),
|
||||
output_index=output_index,
|
||||
sequence_number=0,
|
||||
type="response.output_item.added",
|
||||
),
|
||||
ResponseOutputItemDoneEvent(
|
||||
item=ResponseReasoningItem(
|
||||
id=id,
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
status=None,
|
||||
encrypted_content="AAABBB",
|
||||
),
|
||||
output_index=output_index,
|
||||
sequence_number=0,
|
||||
type="response.output_item.done",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
|
||||
"""Create a web search call item."""
|
||||
return [
|
||||
ResponseOutputItemAddedEvent(
|
||||
item=ResponseFunctionWebSearch(
|
||||
id=id,
|
||||
status="in_progress",
|
||||
action=ActionSearch(query="query", type="search"),
|
||||
type="web_search_call",
|
||||
),
|
||||
output_index=output_index,
|
||||
sequence_number=0,
|
||||
type="response.output_item.added",
|
||||
),
|
||||
ResponseWebSearchCallInProgressEvent(
|
||||
item_id=id,
|
||||
output_index=output_index,
|
||||
sequence_number=0,
|
||||
type="response.web_search_call.in_progress",
|
||||
),
|
||||
ResponseWebSearchCallSearchingEvent(
|
||||
item_id=id,
|
||||
output_index=output_index,
|
||||
sequence_number=0,
|
||||
type="response.web_search_call.searching",
|
||||
),
|
||||
ResponseWebSearchCallCompletedEvent(
|
||||
item_id=id,
|
||||
output_index=output_index,
|
||||
sequence_number=0,
|
||||
type="response.web_search_call.completed",
|
||||
),
|
||||
ResponseOutputItemDoneEvent(
|
||||
item=ResponseFunctionWebSearch(
|
||||
id=id,
|
||||
status="completed",
|
||||
action=ActionSearch(query="query", type="search"),
|
||||
type="web_search_call",
|
||||
),
|
||||
output_index=output_index,
|
||||
sequence_number=0,
|
||||
type="response.output_item.done",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def test_function_call(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_with_reasoning_model: MockConfigEntry,
|
||||
|
||||
Reference in New Issue
Block a user