Preserve reasoning during tool calls for openai_conversation (#143699)

Denis Shulyaka 2025-04-26 04:12:23 +03:00 committed by GitHub
parent 4c9cd70f65
commit 7074331461
2 changed files with 25 additions and 4 deletions

homeassistant/components/openai_conversation/conversation.py

@@ -2,7 +2,7 @@
 from collections.abc import AsyncGenerator, Callable
 import json
-from typing import Any, Literal
+from typing import Any, Literal, cast

 import openai
 from openai._streaming import AsyncStream
@@ -19,7 +19,11 @@ from openai.types.responses import (
     ResponseIncompleteEvent,
     ResponseInputParam,
     ResponseOutputItemAddedEvent,
+    ResponseOutputItemDoneEvent,
+    ResponseOutputMessage,
+    ResponseOutputMessageParam,
     ResponseReasoningItem,
+    ResponseReasoningItemParam,
     ResponseStreamEvent,
     ResponseTextDeltaEvent,
     ToolParam,
@@ -127,6 +131,7 @@ def _convert_content_to_param(
 async def _transform_stream(
     chat_log: conversation.ChatLog,
     result: AsyncStream[ResponseStreamEvent],
+    messages: ResponseInputParam,
 ) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
     """Transform an OpenAI delta stream into HA format."""
     async for event in result:
@@ -137,6 +142,15 @@
                 yield {"role": event.item.role}
             elif isinstance(event.item, ResponseFunctionToolCall):
                 current_tool_call = event.item
+        elif isinstance(event, ResponseOutputItemDoneEvent):
+            item = event.item.model_dump()
+            item.pop("status", None)
+            if isinstance(event.item, ResponseReasoningItem):
+                messages.append(cast(ResponseReasoningItemParam, item))
+            elif isinstance(event.item, ResponseOutputMessage):
+                messages.append(cast(ResponseOutputMessageParam, item))
+            elif isinstance(event.item, ResponseFunctionToolCall):
+                messages.append(cast(ResponseFunctionToolCallParam, item))
         elif isinstance(event, ResponseTextDeltaEvent):
             yield {"content": event.delta}
         elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
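
The new ResponseOutputItemDoneEvent branch is the heart of the change: every finished output item, whether reasoning, assistant message, or function call, is copied verbatim into the `messages` list that seeds the follow-up request, instead of being rebuilt from Home Assistant's chat log later. A rough sketch of what `messages` can hold mid-turn as a result; the event handling above is from the diff, but every literal payload below is invented for illustration:

    # Illustrative mid-turn contents of `messages` after the model has
    # finished a reasoning item and a tool call, before the follow-up
    # request is sent. All ids, names, and arguments are made up.
    messages = [
        {"type": "message", "role": "user", "content": "Turn on the kitchen light"},
        # Appended by the ResponseOutputItemDoneEvent branch:
        {"id": "rs_A", "summary": [], "type": "reasoning"},
        {
            "id": "fc_1",
            "type": "function_call",
            "call_id": "call_1",
            "name": "HassTurnOn",
            "arguments": '{"name": "kitchen light"}',
        },
        # Appended from the chat log once the tool has actually run:
        {
            "type": "function_call_output",
            "call_id": "call_1",
            "output": '{"success": true}',
        },
    ]

Because the reasoning item travels back unchanged, a reasoning model can resume its own chain of thought when it sees the tool output.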
@@ -314,7 +328,6 @@
             "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
             "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
             "user": chat_log.conversation_id,
-            "store": False,
             "stream": True,
         }
         if tools:
@@ -326,6 +339,8 @@
                     CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
                 )
             }
+        else:
+            model_args["store"] = False

         try:
             result = await client.responses.create(**model_args)
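
This pairs with the `"store": False` removal in the previous hunk: storage is now disabled only for non-reasoning models. The likely rationale, hedged from how the Responses API behaves, is that a reasoning item such as `rs_A` carries no content of its own and is resolved server-side by id, which only works if the original response was stored. A minimal sketch of the branching, where the `startswith("o")` check and the effort value are assumptions, not code from this diff:

    def build_model_args(model: str, messages: list[dict]) -> dict:
        """Sketch: shape the request options the way the diff above does."""
        model_args = {"model": model, "input": messages, "stream": True}
        if model.startswith("o"):  # assumed reasoning-model check
            model_args["reasoning"] = {"effort": "medium"}
            # No "store" key: the API default (store=True) lets the next
            # request resolve reasoning items like "rs_A" by id.
        else:
            model_args["store"] = False  # nothing will be referenced later
        return model_args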
@@ -337,8 +352,9 @@
                 raise HomeAssistantError("Error talking to OpenAI") from err

             async for content in chat_log.async_add_delta_content_stream(
-                user_input.agent_id, _transform_stream(chat_log, result)
+                user_input.agent_id, _transform_stream(chat_log, result, messages)
             ):
-                messages.extend(_convert_content_to_param(content))
+                if not isinstance(content, conversation.AssistantContent):
+                    messages.extend(_convert_content_to_param(content))

             if not chat_log.unresponded_tool_results:
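
The new `isinstance` guard prevents double bookkeeping: assistant output already landed in `messages` verbatim inside `_transform_stream`, and round-tripping it again through `_convert_content_to_param` would not only duplicate it but also lose the reasoning items, since Home Assistant's `AssistantContent` has no field to carry them. Only the remaining chat-log entries, such as tool results, still need conversion. A sketch of that asymmetry, with invented payloads:

    # What the raw-stream path preserves for one assistant turn:
    raw_items = [
        {"id": "rs_A", "summary": [], "type": "reasoning"},  # kept verbatim
        {
            "id": "fc_1",
            "type": "function_call",
            "call_id": "call_1",
            "name": "HassTurnOn",
            "arguments": "{}",
        },
    ]
    # What a rebuild from AssistantContent could recover at best:
    rebuilt_items = [
        {
            "id": "fc_1",
            "type": "function_call",
            "call_id": "call_1",
            "name": "HassTurnOn",
            "arguments": "{}",
        },
        # No reasoning item: AssistantContent has nowhere to store one.
    ]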

tests/components/openai_conversation/test_conversation.py

@@ -586,6 +586,11 @@
         agent_id="conversation.openai",
     )

+    assert mock_create_stream.call_args.kwargs["input"][2] == {
+        "id": "rs_A",
+        "summary": [],
+        "type": "reasoning",
+    }
     assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
     # Don't test the prompt, as it's not deterministic
     assert mock_chat_log.content[1:] == snapshot
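
The new assertion pins the preserved reasoning item at index 2 of the second request's input. A hedged reconstruction of what the whole `input` list plausibly looks like at that point; only index 2 is actually asserted by the test, and the rest is inferred from `_convert_content_to_param` and the sketches above:

    # Hypothetical full input of the follow-up responses.create call.
    second_call_input = [
        {"type": "message", "role": "system", "content": "<prompt>"},   # 0
        {"type": "message", "role": "user", "content": "<user text>"},  # 1
        {"id": "rs_A", "summary": [], "type": "reasoning"},             # 2: asserted
        # The function_call and function_call_output items would follow.
    ]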