Preserve reasoning during tool calls for openai_conversation (#143699)

Preserve reasoning after tool calls for openai_conversation
Denis Shulyaka, 2025-04-26 04:12:23 +03:00 (committed by GitHub)
parent 4c9cd70f65
commit 7074331461
2 changed files with 25 additions and 4 deletions
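
Note on the mechanics behind this change (an editor's sketch, not part of the commit): with the OpenAI Responses API, a reasoning model emits "reasoning" output items ahead of each "function_call" item. If the follow-up request that delivers the "function_call_output" does not replay those reasoning items in its input, the model loses its chain of thought across the tool call. A minimal standalone sketch of the round trip; the model name, tool schema, and prompt are hypothetical:

# Sketch: replaying reasoning items across a tool call.
# Assumptions: model name, tool schema, and prompt are made up;
# error handling and streaming are omitted for brevity.
import asyncio
import json

from openai import AsyncOpenAI

TOOLS = [
    {
        "type": "function",
        "name": "turn_on_light",  # hypothetical tool
        "description": "Turn on a light by name",
        "parameters": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        },
    }
]

async def round_trip() -> None:
    client = AsyncOpenAI()
    messages: list = [{"role": "user", "content": "Turn on the kitchen light"}]
    response = await client.responses.create(
        model="o4-mini", input=messages, tools=TOOLS
    )
    for item in response.output:
        # Replay every output item (reasoning, message, function_call)
        # verbatim, minus the output-only "status" field.
        as_param = item.model_dump()
        as_param.pop("status", None)
        messages.append(as_param)
        if item.type == "function_call":
            messages.append(
                {
                    "type": "function_call_output",
                    "call_id": item.call_id,
                    "output": json.dumps({"success": True}),
                }
            )
    # Second request: the reasoning item rides along with the tool result.
    await client.responses.create(model="o4-mini", input=messages, tools=TOOLS)

asyncio.run(round_trip())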


@@ -2,7 +2,7 @@

 from collections.abc import AsyncGenerator, Callable
 import json
-from typing import Any, Literal
+from typing import Any, Literal, cast

 import openai
 from openai._streaming import AsyncStream
@@ -19,7 +19,11 @@ from openai.types.responses import (
     ResponseIncompleteEvent,
     ResponseInputParam,
     ResponseOutputItemAddedEvent,
+    ResponseOutputItemDoneEvent,
     ResponseOutputMessage,
+    ResponseOutputMessageParam,
+    ResponseReasoningItem,
+    ResponseReasoningItemParam,
     ResponseStreamEvent,
     ResponseTextDeltaEvent,
     ToolParam,
@@ -127,6 +131,7 @@ def _convert_content_to_param(
 async def _transform_stream(
     chat_log: conversation.ChatLog,
     result: AsyncStream[ResponseStreamEvent],
+    messages: ResponseInputParam,
 ) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
     """Transform an OpenAI delta stream into HA format."""
     async for event in result:
@@ -137,6 +142,15 @@ async def _transform_stream(
                 yield {"role": event.item.role}
             elif isinstance(event.item, ResponseFunctionToolCall):
                 current_tool_call = event.item
+        elif isinstance(event, ResponseOutputItemDoneEvent):
+            item = event.item.model_dump()
+            item.pop("status", None)
+            if isinstance(event.item, ResponseReasoningItem):
+                messages.append(cast(ResponseReasoningItemParam, item))
+            elif isinstance(event.item, ResponseOutputMessage):
+                messages.append(cast(ResponseOutputMessageParam, item))
+            elif isinstance(event.item, ResponseFunctionToolCall):
+                messages.append(cast(ResponseFunctionToolCallParam, item))
         elif isinstance(event, ResponseTextDeltaEvent):
             yield {"content": event.delta}
         elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
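
Note: _transform_stream now mirrors every finished output item (reasoning, message, function call) into the caller's messages list as soon as its ResponseOutputItemDoneEvent arrives, so the next request replays the conversation exactly as the model produced it. (ResponseFunctionToolCallParam is not among the added imports because, presumably, it was already imported for _convert_content_to_param.) The cast(...) calls only narrow the model_dump() result to the matching *Param TypedDict for the type checker; the dict passes through unchanged at runtime. A small sketch of the conversion, with a hypothetical item id:

# Sketch: how a finished reasoning item becomes an input param.
# Assumes an SDK version where ResponseReasoningItem carries only the
# id/summary/type/status fields; the id is hypothetical.
from openai.types.responses import ResponseReasoningItem

item = ResponseReasoningItem(id="rs_A", summary=[], type="reasoning")

as_param = item.model_dump()  # {"id": "rs_A", "summary": [], "type": "reasoning", "status": None}
as_param.pop("status", None)  # "status" is output-only and invalid on input
# as_param can now be appended verbatim to the next request's input.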
@@ -314,7 +328,6 @@ class OpenAIConversationEntity(
                 "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
                 "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
                 "user": chat_log.conversation_id,
-                "store": False,
                 "stream": True,
             }
             if tools:
@@ -326,6 +339,8 @@ class OpenAIConversationEntity(
                         CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
                     )
                 }
+            else:
+                model_args["store"] = False

             try:
                 result = await client.responses.create(**model_args)
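
Note: the unconditional "store": False removed above is what blocked reasoning preservation. A replayed reasoning item is referenced by its id, and the server can only resolve that id if the response it came from was stored, so storage is now left enabled on the reasoning branch and disabled only for non-reasoning models. Roughly, the two resulting request shapes (a sketch with hypothetical option values):

# Sketch of the request arguments after this change (values hypothetical).
messages: list = []

reasoning_model_args = {
    "model": "o4-mini",
    "input": messages,
    "stream": True,
    "reasoning": {"effort": "low"},
    # "store" left at its default so reasoning item ids stay resolvable.
}

non_reasoning_model_args = {
    "model": "gpt-4o-mini",
    "input": messages,
    "stream": True,
    "store": False,  # nothing is replayed by id, so opt out of storage
}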
@@ -337,9 +352,10 @@ class OpenAIConversationEntity(
                 raise HomeAssistantError("Error talking to OpenAI") from err

             async for content in chat_log.async_add_delta_content_stream(
-                user_input.agent_id, _transform_stream(chat_log, result)
+                user_input.agent_id, _transform_stream(chat_log, result, messages)
             ):
-                messages.extend(_convert_content_to_param(content))
+                if not isinstance(content, conversation.AssistantContent):
+                    messages.extend(_convert_content_to_param(content))

             if not chat_log.unresponded_tool_results:
                 break
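
Note: the new isinstance guard prevents double-adding. Assistant output already reaches messages in API-native form via _transform_stream (which also keeps the reasoning items HA's chat log does not model), so only non-assistant content, in practice tool results, still goes through _convert_content_to_param. A hedged sketch of what that conversion amounts to for a tool result; the helper below is illustrative only, not the integration's actual code:

# Illustrative only: the real logic lives in _convert_content_to_param,
# which this diff does not touch.
import json

def tool_result_to_param(call_id: str, result: dict) -> dict:
    # A Responses API tool answer is a function_call_output item that
    # echoes the call_id of the function_call it responds to.
    return {
        "type": "function_call_output",
        "call_id": call_id,
        "output": json.dumps(result),
    }

param = tool_result_to_param("call_A", {"success": True})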


@@ -586,6 +586,11 @@ async def test_function_call(
         agent_id="conversation.openai",
     )
+    assert mock_create_stream.call_args.kwargs["input"][2] == {
+        "id": "rs_A",
+        "summary": [],
+        "type": "reasoning",
+    }

     assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
     # Don't test the prompt, as it's not deterministic
     assert mock_chat_log.content[1:] == snapshot
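
The added assertion is the regression test for the whole change: on the follow-up request that carries the tool result (call_args returns the most recent call to the mocked client), the input list must still contain the reasoning item the mocked stream emitted, with its id ("rs_A") passed through verbatim and the output-only status field stripped.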