From 707433146198b7700af440cb0f8cc59e6b58b590 Mon Sep 17 00:00:00 2001
From: Denis Shulyaka
Date: Sat, 26 Apr 2025 04:12:23 +0300
Subject: [PATCH] Preserve reasoning during tool calls for openai_conversation
 (#143699)

Preserve reasoning after tool calls for openai_conversation
---
 .../openai_conversation/conversation.py       | 24 ++++++++++++++++++++----
 .../openai_conversation/test_conversation.py  |  5 +++++
 2 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py
index 026e18f3ce1..67e79e270d7 100644
--- a/homeassistant/components/openai_conversation/conversation.py
+++ b/homeassistant/components/openai_conversation/conversation.py
@@ -2,7 +2,7 @@
 
 from collections.abc import AsyncGenerator, Callable
 import json
-from typing import Any, Literal
+from typing import Any, Literal, cast
 
 import openai
 from openai._streaming import AsyncStream
@@ -19,7 +19,11 @@ from openai.types.responses import (
     ResponseIncompleteEvent,
     ResponseInputParam,
     ResponseOutputItemAddedEvent,
+    ResponseOutputItemDoneEvent,
     ResponseOutputMessage,
+    ResponseOutputMessageParam,
+    ResponseReasoningItem,
+    ResponseReasoningItemParam,
     ResponseStreamEvent,
     ResponseTextDeltaEvent,
     ToolParam,
@@ -127,6 +131,7 @@ def _convert_content_to_param(
 async def _transform_stream(
     chat_log: conversation.ChatLog,
     result: AsyncStream[ResponseStreamEvent],
+    messages: ResponseInputParam,
 ) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
     """Transform an OpenAI delta stream into HA format."""
     async for event in result:
@@ -137,6 +142,15 @@ async def _transform_stream(
                 yield {"role": event.item.role}
             elif isinstance(event.item, ResponseFunctionToolCall):
                 current_tool_call = event.item
+        elif isinstance(event, ResponseOutputItemDoneEvent):
+            item = event.item.model_dump()
+            item.pop("status", None)
+            if isinstance(event.item, ResponseReasoningItem):
+                messages.append(cast(ResponseReasoningItemParam, item))
+            elif isinstance(event.item, ResponseOutputMessage):
+                messages.append(cast(ResponseOutputMessageParam, item))
+            elif isinstance(event.item, ResponseFunctionToolCall):
+                messages.append(cast(ResponseFunctionToolCallParam, item))
         elif isinstance(event, ResponseTextDeltaEvent):
             yield {"content": event.delta}
         elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
@@ -314,7 +328,6 @@ class OpenAIConversationEntity(
                 "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
                 "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
                 "user": chat_log.conversation_id,
-                "store": False,
                 "stream": True,
             }
             if tools:
@@ -326,6 +339,8 @@ class OpenAIConversationEntity(
                         CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
                     )
                 }
+            else:
+                model_args["store"] = False
 
             try:
                 result = await client.responses.create(**model_args)
@@ -337,9 +352,10 @@ class OpenAIConversationEntity(
                 raise HomeAssistantError("Error talking to OpenAI") from err
 
             async for content in chat_log.async_add_delta_content_stream(
-                user_input.agent_id, _transform_stream(chat_log, result)
+                user_input.agent_id, _transform_stream(chat_log, result, messages)
             ):
-                messages.extend(_convert_content_to_param(content))
+                if not isinstance(content, conversation.AssistantContent):
+                    messages.extend(_convert_content_to_param(content))
 
             if not chat_log.unresponded_tool_results:
                 break
diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py
index d6f09e0f30e..269590b483a 100644
--- a/tests/components/openai_conversation/test_conversation.py
+++ b/tests/components/openai_conversation/test_conversation.py
@@ -586,6 +586,11 @@ async def test_function_call(
         agent_id="conversation.openai",
     )
 
+    assert mock_create_stream.call_args.kwargs["input"][2] == {
+        "id": "rs_A",
+        "summary": [],
+        "type": "reasoning",
+    }
     assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
     # Don't test the prompt, as it's not deterministic
     assert mock_chat_log.content[1:] == snapshot
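Note on the replay mechanics, as a minimal sketch rather than the integration's
code: the new ResponseOutputItemDoneEvent branch dumps each finished output
item, drops the "status" field (not needed when the item is sent back as
input), and appends the result to the same "input" list reused for the
follow-up request, which is how the reasoning item survives the tool-call
round trip. The helper name and sample ids below are hypothetical; the dict
shapes follow the test expectation above.

    def _as_input_param(done_item: dict) -> dict:
        """Copy a finished output item, dropping its transient "status" field."""
        param = dict(done_item)
        param.pop("status", None)
        return param


    # Items as they might arrive via ResponseOutputItemDoneEvent during a
    # tool-calling turn (ids and arguments are illustrative).
    reasoning_item = {
        "id": "rs_A",
        "summary": [],
        "type": "reasoning",
        "status": "completed",
    }
    tool_call_item = {
        "id": "fc_1",
        "call_id": "call_1",
        "name": "test_tool",
        "arguments": '{"param1": "value1"}',
        "type": "function_call",
        "status": "completed",
    }

    messages: list[dict] = []
    messages.append(_as_input_param(reasoning_item))
    messages.append(_as_input_param(tool_call_item))

    # The reasoning item is now part of the next request's input, matching
    # the shape asserted at input index 2 in the test above.
    assert messages[0] == {"id": "rs_A", "summary": [], "type": "reasoning"}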