mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
OpenAI Responses API (#140713)
This commit is contained in:
parent
214d14b06b
commit
bb7b5b9ccb
@ -7,21 +7,15 @@ from mimetypes import guess_file_type
|
||||
from pathlib import Path
|
||||
|
||||
import openai
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_content_part_image_param import (
|
||||
ChatCompletionContentPartImageParam,
|
||||
ImageURL,
|
||||
)
|
||||
from openai.types.chat.chat_completion_content_part_param import (
|
||||
ChatCompletionContentPartParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_content_part_text_param import (
|
||||
ChatCompletionContentPartTextParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_user_message_param import (
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
from openai.types.images_response import ImagesResponse
|
||||
from openai.types.responses import (
|
||||
EasyInputMessageParam,
|
||||
Response,
|
||||
ResponseInputImageParam,
|
||||
ResponseInputMessageContentListParam,
|
||||
ResponseInputParam,
|
||||
ResponseInputTextParam,
|
||||
)
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
@ -44,10 +38,18 @@ from homeassistant.helpers.typing import ConfigType
|
||||
from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_FILENAMES,
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_PROMPT,
|
||||
CONF_REASONING_EFFORT,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOP_P,
|
||||
DOMAIN,
|
||||
LOGGER,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_MAX_TOKENS,
|
||||
RECOMMENDED_REASONING_EFFORT,
|
||||
RECOMMENDED_TEMPERATURE,
|
||||
RECOMMENDED_TOP_P,
|
||||
)
|
||||
|
||||
SERVICE_GENERATE_IMAGE = "generate_image"
|
||||
@ -112,17 +114,14 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
translation_placeholders={"config_entry": entry_id},
|
||||
)
|
||||
|
||||
model: str = entry.data.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
||||
model: str = entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
||||
client: openai.AsyncClient = entry.runtime_data
|
||||
|
||||
prompt_parts: list[ChatCompletionContentPartParam] = [
|
||||
ChatCompletionContentPartTextParam(
|
||||
type="text",
|
||||
text=call.data[CONF_PROMPT],
|
||||
)
|
||||
content: ResponseInputMessageContentListParam = [
|
||||
ResponseInputTextParam(type="input_text", text=call.data[CONF_PROMPT])
|
||||
]
|
||||
|
||||
def append_files_to_prompt() -> None:
|
||||
def append_files_to_content() -> None:
|
||||
for filename in call.data[CONF_FILENAMES]:
|
||||
if not hass.config.is_allowed_path(filename):
|
||||
raise HomeAssistantError(
|
||||
@ -138,46 +137,52 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"Only images are supported by the OpenAI API,"
|
||||
f"`{filename}` is not an image file"
|
||||
)
|
||||
prompt_parts.append(
|
||||
ChatCompletionContentPartImageParam(
|
||||
type="image_url",
|
||||
image_url=ImageURL(
|
||||
url=f"data:{mime_type};base64,{base64_file}"
|
||||
),
|
||||
content.append(
|
||||
ResponseInputImageParam(
|
||||
type="input_image",
|
||||
file_id=filename,
|
||||
image_url=f"data:{mime_type};base64,{base64_file}",
|
||||
detail="auto",
|
||||
)
|
||||
)
|
||||
|
||||
if CONF_FILENAMES in call.data:
|
||||
await hass.async_add_executor_job(append_files_to_prompt)
|
||||
await hass.async_add_executor_job(append_files_to_content)
|
||||
|
||||
messages: list[ChatCompletionUserMessageParam] = [
|
||||
ChatCompletionUserMessageParam(
|
||||
role="user",
|
||||
content=prompt_parts,
|
||||
)
|
||||
messages: ResponseInputParam = [
|
||||
EasyInputMessageParam(type="message", role="user", content=content)
|
||||
]
|
||||
|
||||
try:
|
||||
response: ChatCompletion = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
n=1,
|
||||
response_format={
|
||||
"type": "json_object",
|
||||
},
|
||||
)
|
||||
model_args = {
|
||||
"model": model,
|
||||
"input": messages,
|
||||
"max_output_tokens": entry.options.get(
|
||||
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
|
||||
),
|
||||
"top_p": entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||
"temperature": entry.options.get(
|
||||
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
|
||||
),
|
||||
"user": call.context.user_id,
|
||||
"store": False,
|
||||
}
|
||||
|
||||
if model.startswith("o"):
|
||||
model_args["reasoning"] = {
|
||||
"effort": entry.options.get(
|
||||
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
|
||||
)
|
||||
}
|
||||
|
||||
response: Response = await client.responses.create(**model_args)
|
||||
|
||||
except openai.OpenAIError as err:
|
||||
raise HomeAssistantError(f"Error generating content: {err}") from err
|
||||
except FileNotFoundError as err:
|
||||
raise HomeAssistantError(f"Error generating content: {err}") from err
|
||||
|
||||
response_text: str = ""
|
||||
for response_choice in response.choices:
|
||||
if response_choice.message.content is not None:
|
||||
response_text += response_choice.message.content.strip()
|
||||
|
||||
return {"text": response_text}
|
||||
return {"text": response.output_text}
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN,
|
||||
|
@ -2,21 +2,25 @@
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
import json
|
||||
from typing import Any, Literal, cast
|
||||
from typing import Any, Literal
|
||||
|
||||
import openai
|
||||
from openai._streaming import AsyncStream
|
||||
from openai._types import NOT_GIVEN
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionToolParam,
|
||||
from openai.types.responses import (
|
||||
EasyInputMessageParam,
|
||||
FunctionToolParam,
|
||||
ResponseFunctionCallArgumentsDeltaEvent,
|
||||
ResponseFunctionCallArgumentsDoneEvent,
|
||||
ResponseFunctionToolCall,
|
||||
ResponseFunctionToolCallParam,
|
||||
ResponseInputParam,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputMessage,
|
||||
ResponseStreamEvent,
|
||||
ResponseTextDeltaEvent,
|
||||
ToolParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import Function
|
||||
from openai.types.shared_params import FunctionDefinition
|
||||
from openai.types.responses.response_input_param import FunctionCallOutput
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import assist_pipeline, conversation
|
||||
@ -60,123 +64,81 @@ async def async_setup_entry(
|
||||
|
||||
def _format_tool(
|
||||
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
|
||||
) -> ChatCompletionToolParam:
|
||||
) -> FunctionToolParam:
|
||||
"""Format tool specification."""
|
||||
tool_spec = FunctionDefinition(
|
||||
return FunctionToolParam(
|
||||
type="function",
|
||||
name=tool.name,
|
||||
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
|
||||
description=tool.description,
|
||||
strict=False,
|
||||
)
|
||||
if tool.description:
|
||||
tool_spec["description"] = tool.description
|
||||
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||
|
||||
|
||||
def _convert_content_to_param(
|
||||
content: conversation.Content,
|
||||
) -> ChatCompletionMessageParam:
|
||||
) -> ResponseInputParam:
|
||||
"""Convert any native chat message for this agent to the native format."""
|
||||
if content.role == "tool_result":
|
||||
assert type(content) is conversation.ToolResultContent
|
||||
return ChatCompletionToolMessageParam(
|
||||
role="tool",
|
||||
tool_call_id=content.tool_call_id,
|
||||
content=json.dumps(content.tool_result),
|
||||
)
|
||||
if content.role != "assistant" or not content.tool_calls:
|
||||
role: Literal["system", "user", "assistant", "developer"] = content.role
|
||||
messages: ResponseInputParam = []
|
||||
if isinstance(content, conversation.ToolResultContent):
|
||||
return [
|
||||
FunctionCallOutput(
|
||||
type="function_call_output",
|
||||
call_id=content.tool_call_id,
|
||||
output=json.dumps(content.tool_result),
|
||||
)
|
||||
]
|
||||
|
||||
if content.content:
|
||||
role: Literal["user", "assistant", "system", "developer"] = content.role
|
||||
if role == "system":
|
||||
role = "developer"
|
||||
return cast(
|
||||
ChatCompletionMessageParam,
|
||||
{"role": content.role, "content": content.content},
|
||||
messages.append(
|
||||
EasyInputMessageParam(type="message", role=role, content=content.content)
|
||||
)
|
||||
|
||||
# Handle the Assistant content including tool calls.
|
||||
assert type(content) is conversation.AssistantContent
|
||||
return ChatCompletionAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=content.content,
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCallParam(
|
||||
id=tool_call.id,
|
||||
function=Function(
|
||||
arguments=json.dumps(tool_call.tool_args),
|
||||
name=tool_call.tool_name,
|
||||
),
|
||||
type="function",
|
||||
if isinstance(content, conversation.AssistantContent) and content.tool_calls:
|
||||
messages.extend(
|
||||
# https://github.com/openai/openai-python/issues/2205
|
||||
ResponseFunctionToolCallParam( # type: ignore[typeddict-item]
|
||||
type="function_call",
|
||||
name=tool_call.tool_name,
|
||||
arguments=json.dumps(tool_call.tool_args),
|
||||
call_id=tool_call.id,
|
||||
)
|
||||
for tool_call in content.tool_calls
|
||||
],
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
async def _transform_stream(
|
||||
result: AsyncStream[ChatCompletionChunk],
|
||||
result: AsyncStream[ResponseStreamEvent],
|
||||
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
||||
"""Transform an OpenAI delta stream into HA format."""
|
||||
current_tool_call: dict | None = None
|
||||
async for event in result:
|
||||
LOGGER.debug("Received event: %s", event)
|
||||
|
||||
async for chunk in result:
|
||||
LOGGER.debug("Received chunk: %s", chunk)
|
||||
choice = chunk.choices[0]
|
||||
|
||||
if choice.finish_reason:
|
||||
if current_tool_call:
|
||||
yield {
|
||||
"tool_calls": [
|
||||
llm.ToolInput(
|
||||
id=current_tool_call["id"],
|
||||
tool_name=current_tool_call["tool_name"],
|
||||
tool_args=json.loads(current_tool_call["tool_args"]),
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
break
|
||||
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
# We can yield delta messages not continuing or starting tool calls
|
||||
if current_tool_call is None and not delta.tool_calls:
|
||||
yield { # type: ignore[misc]
|
||||
key: value
|
||||
for key in ("role", "content")
|
||||
if (value := getattr(delta, key)) is not None
|
||||
}
|
||||
continue
|
||||
|
||||
# When doing tool calls, we should always have a tool call
|
||||
# object or we have gotten stopped above with a finish_reason set.
|
||||
if (
|
||||
not delta.tool_calls
|
||||
or not (delta_tool_call := delta.tool_calls[0])
|
||||
or not delta_tool_call.function
|
||||
):
|
||||
raise ValueError("Expected delta with tool call")
|
||||
|
||||
if current_tool_call and delta_tool_call.index == current_tool_call["index"]:
|
||||
current_tool_call["tool_args"] += delta_tool_call.function.arguments or ""
|
||||
continue
|
||||
|
||||
# We got tool call with new index, so we need to yield the previous
|
||||
if current_tool_call:
|
||||
if isinstance(event, ResponseOutputItemAddedEvent):
|
||||
if isinstance(event.item, ResponseOutputMessage):
|
||||
yield {"role": event.item.role}
|
||||
elif isinstance(event.item, ResponseFunctionToolCall):
|
||||
current_tool_call = event.item
|
||||
elif isinstance(event, ResponseTextDeltaEvent):
|
||||
yield {"content": event.delta}
|
||||
elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
|
||||
current_tool_call.arguments += event.delta
|
||||
elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
|
||||
current_tool_call.status = "completed"
|
||||
yield {
|
||||
"tool_calls": [
|
||||
llm.ToolInput(
|
||||
id=current_tool_call["id"],
|
||||
tool_name=current_tool_call["tool_name"],
|
||||
tool_args=json.loads(current_tool_call["tool_args"]),
|
||||
id=current_tool_call.call_id,
|
||||
tool_name=current_tool_call.name,
|
||||
tool_args=json.loads(current_tool_call.arguments),
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
current_tool_call = {
|
||||
"index": delta_tool_call.index,
|
||||
"id": delta_tool_call.id,
|
||||
"tool_name": delta_tool_call.function.name,
|
||||
"tool_args": delta_tool_call.function.arguments or "",
|
||||
}
|
||||
|
||||
|
||||
class OpenAIConversationEntity(
|
||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||
@ -241,7 +203,7 @@ class OpenAIConversationEntity(
|
||||
except conversation.ConverseError as err:
|
||||
return err.as_conversation_result()
|
||||
|
||||
tools: list[ChatCompletionToolParam] | None = None
|
||||
tools: list[ToolParam] | None = None
|
||||
if chat_log.llm_api:
|
||||
tools = [
|
||||
_format_tool(tool, chat_log.llm_api.custom_serializer)
|
||||
@ -249,7 +211,11 @@ class OpenAIConversationEntity(
|
||||
]
|
||||
|
||||
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
||||
messages = [_convert_content_to_param(content) for content in chat_log.content]
|
||||
messages = [
|
||||
m
|
||||
for content in chat_log.content
|
||||
for m in _convert_content_to_param(content)
|
||||
]
|
||||
|
||||
client = self.entry.runtime_data
|
||||
|
||||
@ -257,24 +223,28 @@ class OpenAIConversationEntity(
|
||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||
model_args = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"tools": tools or NOT_GIVEN,
|
||||
"max_completion_tokens": options.get(
|
||||
"input": messages,
|
||||
"max_output_tokens": options.get(
|
||||
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
|
||||
),
|
||||
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||
"user": chat_log.conversation_id,
|
||||
"store": False,
|
||||
"stream": True,
|
||||
}
|
||||
if tools:
|
||||
model_args["tools"] = tools
|
||||
|
||||
if model.startswith("o"):
|
||||
model_args["reasoning_effort"] = options.get(
|
||||
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
|
||||
)
|
||||
model_args["reasoning"] = {
|
||||
"effort": options.get(
|
||||
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
|
||||
)
|
||||
}
|
||||
|
||||
try:
|
||||
result = await client.chat.completions.create(**model_args)
|
||||
result = await client.responses.create(**model_args)
|
||||
except openai.RateLimitError as err:
|
||||
LOGGER.error("Rate limited by OpenAI: %s", err)
|
||||
raise HomeAssistantError("Rate limited or insufficient funds") from err
|
||||
@ -282,14 +252,10 @@ class OpenAIConversationEntity(
|
||||
LOGGER.error("Error talking to OpenAI: %s", err)
|
||||
raise HomeAssistantError("Error talking to OpenAI") from err
|
||||
|
||||
messages.extend(
|
||||
[
|
||||
_convert_content_to_param(content)
|
||||
async for content in chat_log.async_add_delta_content_stream(
|
||||
user_input.agent_id, _transform_stream(result)
|
||||
)
|
||||
]
|
||||
)
|
||||
async for content in chat_log.async_add_delta_content_stream(
|
||||
user_input.agent_id, _transform_stream(result)
|
||||
):
|
||||
messages.extend(_convert_content_to_param(content))
|
||||
|
||||
if not chat_log.unresponded_tool_results:
|
||||
break
|
||||
|
@ -3,14 +3,28 @@
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from httpx import Response
|
||||
import httpx
|
||||
from openai import AuthenticationError, RateLimitError
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk,
|
||||
Choice,
|
||||
ChoiceDelta,
|
||||
ChoiceDeltaToolCall,
|
||||
ChoiceDeltaToolCallFunction,
|
||||
from openai.types import ResponseFormatText
|
||||
from openai.types.responses import (
|
||||
Response,
|
||||
ResponseCompletedEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseFunctionCallArgumentsDeltaEvent,
|
||||
ResponseFunctionCallArgumentsDoneEvent,
|
||||
ResponseFunctionToolCall,
|
||||
ResponseInProgressEvent,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseOutputMessage,
|
||||
ResponseOutputText,
|
||||
ResponseReasoningItem,
|
||||
ResponseStreamEvent,
|
||||
ResponseTextConfig,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
)
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
@ -28,40 +42,65 @@ from tests.components.conversation import (
|
||||
mock_chat_log, # noqa: F401
|
||||
)
|
||||
|
||||
ASSIST_RESPONSE_FINISH = (
|
||||
# Assistant message
|
||||
ChatCompletionChunk(
|
||||
id="chatcmpl-B",
|
||||
created=1700000000,
|
||||
model="gpt-4-1106-preview",
|
||||
object="chat.completion.chunk",
|
||||
choices=[Choice(index=0, delta=ChoiceDelta(content="Cool"))],
|
||||
),
|
||||
# Finish stream
|
||||
ChatCompletionChunk(
|
||||
id="chatcmpl-B",
|
||||
created=1700000000,
|
||||
model="gpt-4-1106-preview",
|
||||
object="chat.completion.chunk",
|
||||
choices=[Choice(index=0, finish_reason="stop", delta=ChoiceDelta())],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_create_stream() -> Generator[AsyncMock]:
|
||||
"""Mock stream response."""
|
||||
|
||||
async def mock_generator(stream):
|
||||
for value in stream:
|
||||
async def mock_generator(events, **kwargs):
|
||||
response = Response(
|
||||
id="resp_A",
|
||||
created_at=1700000000,
|
||||
error=None,
|
||||
incomplete_details=None,
|
||||
instructions=kwargs.get("instructions"),
|
||||
metadata=kwargs.get("metadata", {}),
|
||||
model=kwargs.get("model", "gpt-4o-mini"),
|
||||
object="response",
|
||||
output=[],
|
||||
parallel_tool_calls=kwargs.get("parallel_tool_calls", True),
|
||||
temperature=kwargs.get("temperature", 1.0),
|
||||
tool_choice=kwargs.get("tool_choice", "auto"),
|
||||
tools=kwargs.get("tools"),
|
||||
top_p=kwargs.get("top_p", 1.0),
|
||||
max_output_tokens=kwargs.get("max_output_tokens", 100000),
|
||||
previous_response_id=kwargs.get("previous_response_id"),
|
||||
reasoning=kwargs.get("reasoning"),
|
||||
status="in_progress",
|
||||
text=kwargs.get(
|
||||
"text", ResponseTextConfig(format=ResponseFormatText(type="text"))
|
||||
),
|
||||
truncation=kwargs.get("truncation", "disabled"),
|
||||
usage=None,
|
||||
user=kwargs.get("user"),
|
||||
store=kwargs.get("store", True),
|
||||
)
|
||||
yield ResponseCreatedEvent(
|
||||
response=response,
|
||||
type="response.created",
|
||||
)
|
||||
yield ResponseInProgressEvent(
|
||||
response=response,
|
||||
type="response.in_progress",
|
||||
)
|
||||
|
||||
for value in events:
|
||||
if isinstance(value, ResponseOutputItemDoneEvent):
|
||||
response.output.append(value.item)
|
||||
yield value
|
||||
|
||||
response.status = "completed"
|
||||
yield ResponseCompletedEvent(
|
||||
response=response,
|
||||
type="response.completed",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||
"openai.resources.responses.AsyncResponses.create",
|
||||
AsyncMock(),
|
||||
) as mock_create:
|
||||
mock_create.side_effect = lambda **kwargs: mock_generator(
|
||||
mock_create.return_value.pop(0)
|
||||
mock_create.return_value.pop(0), **kwargs
|
||||
)
|
||||
|
||||
yield mock_create
|
||||
@ -99,13 +138,17 @@ async def test_entity(
|
||||
[
|
||||
(
|
||||
RateLimitError(
|
||||
response=Response(status_code=429, request=""), body=None, message=None
|
||||
response=httpx.Response(status_code=429, request=""),
|
||||
body=None,
|
||||
message=None,
|
||||
),
|
||||
"Rate limited or insufficient funds",
|
||||
),
|
||||
(
|
||||
AuthenticationError(
|
||||
response=Response(status_code=401, request=""), body=None, message=None
|
||||
response=httpx.Response(status_code=401, request=""),
|
||||
body=None,
|
||||
message=None,
|
||||
),
|
||||
"Error talking to OpenAI",
|
||||
),
|
||||
@ -120,7 +163,7 @@ async def test_error_handling(
|
||||
) -> None:
|
||||
"""Test that we handle errors when calling completion API."""
|
||||
with patch(
|
||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||
"openai.resources.responses.AsyncResponses.create",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=exception,
|
||||
):
|
||||
@ -144,6 +187,165 @@ async def test_conversation_agent(
|
||||
assert agent.supported_languages == "*"
|
||||
|
||||
|
||||
def create_message_item(
|
||||
id: str, text: str | list[str], output_index: int
|
||||
) -> list[ResponseStreamEvent]:
|
||||
"""Create a message item."""
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
|
||||
content = ResponseOutputText(annotations=[], text="", type="output_text")
|
||||
events = [
|
||||
ResponseOutputItemAddedEvent(
|
||||
item=ResponseOutputMessage(
|
||||
id=id,
|
||||
content=[],
|
||||
type="message",
|
||||
role="assistant",
|
||||
status="in_progress",
|
||||
),
|
||||
output_index=output_index,
|
||||
type="response.output_item.added",
|
||||
),
|
||||
ResponseContentPartAddedEvent(
|
||||
content_index=0,
|
||||
item_id=id,
|
||||
output_index=output_index,
|
||||
part=content,
|
||||
type="response.content_part.added",
|
||||
),
|
||||
]
|
||||
|
||||
content.text = "".join(text)
|
||||
events.extend(
|
||||
ResponseTextDeltaEvent(
|
||||
content_index=0,
|
||||
delta=delta,
|
||||
item_id=id,
|
||||
output_index=output_index,
|
||||
type="response.output_text.delta",
|
||||
)
|
||||
for delta in text
|
||||
)
|
||||
|
||||
events.extend(
|
||||
[
|
||||
ResponseTextDoneEvent(
|
||||
content_index=0,
|
||||
item_id=id,
|
||||
output_index=output_index,
|
||||
text="".join(text),
|
||||
type="response.output_text.done",
|
||||
),
|
||||
ResponseContentPartDoneEvent(
|
||||
content_index=0,
|
||||
item_id=id,
|
||||
output_index=output_index,
|
||||
part=content,
|
||||
type="response.content_part.done",
|
||||
),
|
||||
ResponseOutputItemDoneEvent(
|
||||
item=ResponseOutputMessage(
|
||||
id=id,
|
||||
content=[content],
|
||||
role="assistant",
|
||||
status="completed",
|
||||
type="message",
|
||||
),
|
||||
output_index=output_index,
|
||||
type="response.output_item.done",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def create_function_tool_call_item(
|
||||
id: str, arguments: str | list[str], call_id: str, name: str, output_index: int
|
||||
) -> list[ResponseStreamEvent]:
|
||||
"""Create a function tool call item."""
|
||||
if isinstance(arguments, str):
|
||||
arguments = [arguments]
|
||||
|
||||
events = [
|
||||
ResponseOutputItemAddedEvent(
|
||||
item=ResponseFunctionToolCall(
|
||||
id=id,
|
||||
arguments="",
|
||||
call_id=call_id,
|
||||
name=name,
|
||||
type="function_call",
|
||||
status="in_progress",
|
||||
),
|
||||
output_index=output_index,
|
||||
type="response.output_item.added",
|
||||
)
|
||||
]
|
||||
|
||||
events.extend(
|
||||
ResponseFunctionCallArgumentsDeltaEvent(
|
||||
delta=delta,
|
||||
item_id=id,
|
||||
output_index=output_index,
|
||||
type="response.function_call_arguments.delta",
|
||||
)
|
||||
for delta in arguments
|
||||
)
|
||||
|
||||
events.append(
|
||||
ResponseFunctionCallArgumentsDoneEvent(
|
||||
arguments="".join(arguments),
|
||||
item_id=id,
|
||||
output_index=output_index,
|
||||
type="response.function_call_arguments.done",
|
||||
)
|
||||
)
|
||||
|
||||
events.append(
|
||||
ResponseOutputItemDoneEvent(
|
||||
item=ResponseFunctionToolCall(
|
||||
id=id,
|
||||
arguments="".join(arguments),
|
||||
call_id=call_id,
|
||||
name=name,
|
||||
type="function_call",
|
||||
status="completed",
|
||||
),
|
||||
output_index=output_index,
|
||||
type="response.output_item.done",
|
||||
)
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def create_reasoning_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
|
||||
"""Create a reasoning item."""
|
||||
return [
|
||||
ResponseOutputItemAddedEvent(
|
||||
item=ResponseReasoningItem(
|
||||
id=id,
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
status=None,
|
||||
),
|
||||
output_index=output_index,
|
||||
type="response.output_item.added",
|
||||
),
|
||||
ResponseOutputItemDoneEvent(
|
||||
item=ResponseReasoningItem(
|
||||
id=id,
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
status=None,
|
||||
),
|
||||
output_index=output_index,
|
||||
type="response.output_item.done",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def test_function_call(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_with_assist: MockConfigEntry,
|
||||
@ -156,111 +358,27 @@ async def test_function_call(
|
||||
mock_create_stream.return_value = [
|
||||
# Initial conversation
|
||||
(
|
||||
# Wait for the model to think
|
||||
*create_reasoning_item(id="rs_A", output_index=0),
|
||||
# First tool call
|
||||
ChatCompletionChunk(
|
||||
id="chatcmpl-A",
|
||||
created=1700000000,
|
||||
model="gpt-4-1106-preview",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
Choice(
|
||||
index=0,
|
||||
delta=ChoiceDelta(
|
||||
tool_calls=[
|
||||
ChoiceDeltaToolCall(
|
||||
id="call_call_1",
|
||||
index=0,
|
||||
function=ChoiceDeltaToolCallFunction(
|
||||
name="test_tool",
|
||||
arguments=None,
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatCompletionChunk(
|
||||
id="chatcmpl-A",
|
||||
created=1700000000,
|
||||
model="gpt-4-1106-preview",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
Choice(
|
||||
index=0,
|
||||
delta=ChoiceDelta(
|
||||
tool_calls=[
|
||||
ChoiceDeltaToolCall(
|
||||
index=0,
|
||||
function=ChoiceDeltaToolCallFunction(
|
||||
name=None,
|
||||
arguments='{"para',
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatCompletionChunk(
|
||||
id="chatcmpl-A",
|
||||
created=1700000000,
|
||||
model="gpt-4-1106-preview",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
Choice(
|
||||
index=0,
|
||||
delta=ChoiceDelta(
|
||||
tool_calls=[
|
||||
ChoiceDeltaToolCall(
|
||||
index=0,
|
||||
function=ChoiceDeltaToolCallFunction(
|
||||
name=None,
|
||||
arguments='m1":"call1"}',
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
],
|
||||
*create_function_tool_call_item(
|
||||
id="fc_1",
|
||||
arguments=['{"para', 'm1":"call1"}'],
|
||||
call_id="call_call_1",
|
||||
name="test_tool",
|
||||
output_index=1,
|
||||
),
|
||||
# Second tool call
|
||||
ChatCompletionChunk(
|
||||
id="chatcmpl-A",
|
||||
created=1700000000,
|
||||
model="gpt-4-1106-preview",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
Choice(
|
||||
index=0,
|
||||
delta=ChoiceDelta(
|
||||
tool_calls=[
|
||||
ChoiceDeltaToolCall(
|
||||
id="call_call_2",
|
||||
index=1,
|
||||
function=ChoiceDeltaToolCallFunction(
|
||||
name="test_tool",
|
||||
arguments='{"param1":"call2"}',
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
# Finish stream
|
||||
ChatCompletionChunk(
|
||||
id="chatcmpl-A",
|
||||
created=1700000000,
|
||||
model="gpt-4-1106-preview",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
Choice(index=0, finish_reason="tool_calls", delta=ChoiceDelta())
|
||||
],
|
||||
*create_function_tool_call_item(
|
||||
id="fc_2",
|
||||
arguments='{"param1":"call2"}',
|
||||
call_id="call_call_2",
|
||||
name="test_tool",
|
||||
output_index=2,
|
||||
),
|
||||
),
|
||||
# Response after tool responses
|
||||
ASSIST_RESPONSE_FINISH,
|
||||
create_message_item(id="msg_A", text="Cool", output_index=0),
|
||||
]
|
||||
mock_chat_log.mock_tool_results(
|
||||
{
|
||||
@ -288,99 +406,27 @@ async def test_function_call(
|
||||
(
|
||||
"Test function call started with missing arguments",
|
||||
(
|
||||
ChatCompletionChunk(
|
||||
id="chatcmpl-A",
|
||||
created=1700000000,
|
||||
model="gpt-4-1106-preview",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
Choice(
|
||||
index=0,
|
||||
delta=ChoiceDelta(
|
||||
tool_calls=[
|
||||
ChoiceDeltaToolCall(
|
||||
id="call_call_1",
|
||||
index=0,
|
||||
function=ChoiceDeltaToolCallFunction(
|
||||
name="test_tool",
|
||||
arguments=None,
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatCompletionChunk(
|
||||
id="chatcmpl-B",
|
||||
created=1700000000,
|
||||
model="gpt-4-1106-preview",
|
||||
object="chat.completion.chunk",
|
||||
choices=[Choice(index=0, delta=ChoiceDelta(content="Cool"))],
|
||||
*create_function_tool_call_item(
|
||||
id="fc_1",
|
||||
arguments=[],
|
||||
call_id="call_call_1",
|
||||
name="test_tool",
|
||||
output_index=0,
|
||||
),
|
||||
*create_message_item(id="msg_A", text="Cool", output_index=1),
|
||||
),
|
||||
),
|
||||
(
|
||||
"Test invalid JSON",
|
||||
(
|
||||
ChatCompletionChunk(
|
||||
id="chatcmpl-A",
|
||||
created=1700000000,
|
||||
model="gpt-4-1106-preview",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
Choice(
|
||||
index=0,
|
||||
delta=ChoiceDelta(
|
||||
tool_calls=[
|
||||
ChoiceDeltaToolCall(
|
||||
id="call_call_1",
|
||||
index=0,
|
||||
function=ChoiceDeltaToolCallFunction(
|
||||
name="test_tool",
|
||||
arguments=None,
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatCompletionChunk(
|
||||
id="chatcmpl-A",
|
||||
created=1700000000,
|
||||
model="gpt-4-1106-preview",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
Choice(
|
||||
index=0,
|
||||
delta=ChoiceDelta(
|
||||
tool_calls=[
|
||||
ChoiceDeltaToolCall(
|
||||
index=0,
|
||||
function=ChoiceDeltaToolCallFunction(
|
||||
name=None,
|
||||
arguments='{"para',
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatCompletionChunk(
|
||||
id="chatcmpl-B",
|
||||
created=1700000000,
|
||||
model="gpt-4-1106-preview",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
Choice(
|
||||
index=0,
|
||||
delta=ChoiceDelta(content="Cool"),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
],
|
||||
*create_function_tool_call_item(
|
||||
id="fc_1",
|
||||
arguments=['{"para'],
|
||||
call_id="call_call_1",
|
||||
name="test_tool",
|
||||
output_index=0,
|
||||
),
|
||||
*create_message_item(id="msg_A", text="Cool", output_index=1),
|
||||
),
|
||||
),
|
||||
],
|
||||
@ -392,7 +438,7 @@ async def test_function_call_invalid(
|
||||
mock_create_stream: AsyncMock,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
description: str,
|
||||
messages: tuple[ChatCompletionChunk],
|
||||
messages: tuple[ResponseStreamEvent],
|
||||
) -> None:
|
||||
"""Test function call containing invalid data."""
|
||||
mock_create_stream.return_value = [messages]
|
||||
@ -432,7 +478,9 @@ async def test_assist_api_tools_conversion(
|
||||
hass.states.async_set(f"{component}.test", "on")
|
||||
async_expose_entity(hass, "conversation", f"{component}.test", True)
|
||||
|
||||
mock_create_stream.return_value = [ASSIST_RESPONSE_FINISH]
|
||||
mock_create_stream.return_value = [
|
||||
create_message_item(id="msg_A", text="Cool", output_index=0)
|
||||
]
|
||||
|
||||
await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id="conversation.openai"
|
||||
|
@ -2,17 +2,16 @@
|
||||
|
||||
from unittest.mock import AsyncMock, mock_open, patch
|
||||
|
||||
from httpx import Request, Response
|
||||
import httpx
|
||||
from openai import (
|
||||
APIConnectionError,
|
||||
AuthenticationError,
|
||||
BadRequestError,
|
||||
RateLimitError,
|
||||
)
|
||||
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
from openai.types.image import Image
|
||||
from openai.types.images_response import ImagesResponse
|
||||
from openai.types.responses import Response, ResponseOutputMessage, ResponseOutputText
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.openai_conversation import CONF_FILENAMES
|
||||
@ -117,8 +116,8 @@ async def test_generate_image_service_error(
|
||||
patch(
|
||||
"openai.resources.images.AsyncImages.generate",
|
||||
side_effect=RateLimitError(
|
||||
response=Response(
|
||||
status_code=500, request=Request(method="GET", url="")
|
||||
response=httpx.Response(
|
||||
status_code=500, request=httpx.Request(method="GET", url="")
|
||||
),
|
||||
body=None,
|
||||
message="Reason",
|
||||
@ -202,13 +201,13 @@ async def test_invalid_config_entry(
|
||||
("side_effect", "error"),
|
||||
[
|
||||
(
|
||||
APIConnectionError(request=Request(method="GET", url="test")),
|
||||
APIConnectionError(request=httpx.Request(method="GET", url="test")),
|
||||
"Connection error",
|
||||
),
|
||||
(
|
||||
AuthenticationError(
|
||||
response=Response(
|
||||
status_code=500, request=Request(method="GET", url="test")
|
||||
response=httpx.Response(
|
||||
status_code=500, request=httpx.Request(method="GET", url="test")
|
||||
),
|
||||
body=None,
|
||||
message="",
|
||||
@ -217,8 +216,8 @@ async def test_invalid_config_entry(
|
||||
),
|
||||
(
|
||||
BadRequestError(
|
||||
response=Response(
|
||||
status_code=500, request=Request(method="GET", url="test")
|
||||
response=httpx.Response(
|
||||
status_code=500, request=httpx.Request(method="GET", url="test")
|
||||
),
|
||||
body=None,
|
||||
message="",
|
||||
@ -250,11 +249,11 @@ async def test_init_error(
|
||||
(
|
||||
{"prompt": "Picture of a dog", "filenames": []},
|
||||
{
|
||||
"messages": [
|
||||
"input": [
|
||||
{
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"type": "input_text",
|
||||
"text": "Picture of a dog",
|
||||
},
|
||||
],
|
||||
@ -266,18 +265,18 @@ async def test_init_error(
|
||||
(
|
||||
{"prompt": "Picture of a dog", "filenames": ["/a/b/c.jpg"]},
|
||||
{
|
||||
"messages": [
|
||||
"input": [
|
||||
{
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"type": "input_text",
|
||||
"text": "Picture of a dog",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "",
|
||||
},
|
||||
"type": "input_image",
|
||||
"image_url": "",
|
||||
"detail": "auto",
|
||||
"file_id": "/a/b/c.jpg",
|
||||
},
|
||||
],
|
||||
},
|
||||
@ -291,24 +290,24 @@ async def test_init_error(
|
||||
"filenames": ["/a/b/c.jpg", "d/e/f.jpg"],
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
"input": [
|
||||
{
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"type": "input_text",
|
||||
"text": "Picture of a dog",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "",
|
||||
},
|
||||
"type": "input_image",
|
||||
"image_url": "",
|
||||
"detail": "auto",
|
||||
"file_id": "/a/b/c.jpg",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "",
|
||||
},
|
||||
"type": "input_image",
|
||||
"image_url": "",
|
||||
"detail": "auto",
|
||||
"file_id": "d/e/f.jpg",
|
||||
},
|
||||
],
|
||||
},
|
||||
@ -329,13 +328,17 @@ async def test_generate_content_service(
|
||||
"""Test generate content service."""
|
||||
service_data["config_entry"] = mock_config_entry.entry_id
|
||||
expected_args["model"] = "gpt-4o-mini"
|
||||
expected_args["n"] = 1
|
||||
expected_args["response_format"] = {"type": "json_object"}
|
||||
expected_args["messages"][0]["role"] = "user"
|
||||
expected_args["max_output_tokens"] = 150
|
||||
expected_args["top_p"] = 1.0
|
||||
expected_args["temperature"] = 1.0
|
||||
expected_args["user"] = None
|
||||
expected_args["store"] = False
|
||||
expected_args["input"][0]["type"] = "message"
|
||||
expected_args["input"][0]["role"] = "user"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||
"openai.resources.responses.AsyncResponses.create",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_create,
|
||||
patch(
|
||||
@ -345,19 +348,27 @@ async def test_generate_content_service(
|
||||
patch("pathlib.Path.exists", return_value=True),
|
||||
patch.object(hass.config, "is_allowed_path", return_value=True),
|
||||
):
|
||||
mock_create.return_value = ChatCompletion(
|
||||
id="",
|
||||
model="",
|
||||
created=1700000000,
|
||||
object="chat.completion",
|
||||
choices=[
|
||||
Choice(
|
||||
index=0,
|
||||
finish_reason="stop",
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content="This is the response",
|
||||
),
|
||||
mock_create.return_value = Response(
|
||||
object="response",
|
||||
id="resp_A",
|
||||
created_at=1700000000,
|
||||
model="gpt-4o-mini",
|
||||
parallel_tool_calls=True,
|
||||
tool_choice="auto",
|
||||
tools=[],
|
||||
output=[
|
||||
ResponseOutputMessage(
|
||||
type="message",
|
||||
id="msg_A",
|
||||
content=[
|
||||
ResponseOutputText(
|
||||
type="output_text",
|
||||
text="This is the response",
|
||||
annotations=[],
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
status="completed",
|
||||
)
|
||||
],
|
||||
)
|
||||
@ -427,7 +438,7 @@ async def test_generate_content_service_invalid(
|
||||
|
||||
with (
|
||||
patch(
|
||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||
"openai.resources.responses.AsyncResponses.create",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_create,
|
||||
patch(
|
||||
@ -459,10 +470,10 @@ async def test_generate_content_service_error(
|
||||
"""Test generate content service handles errors."""
|
||||
with (
|
||||
patch(
|
||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||
"openai.resources.responses.AsyncResponses.create",
|
||||
side_effect=RateLimitError(
|
||||
response=Response(
|
||||
status_code=417, request=Request(method="GET", url="")
|
||||
response=httpx.Response(
|
||||
status_code=417, request=httpx.Request(method="GET", url="")
|
||||
),
|
||||
body=None,
|
||||
message="Reason",
|
||||
|
Loading…
x
Reference in New Issue
Block a user