mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 12:47:08 +00:00
Preserve reasoning during tool calls for openai_conversation (#143699)
Preserve reasoning after tool calls for openai_conversation
This commit is contained in:
parent
4c9cd70f65
commit
7074331461
@ -2,7 +2,7 @@
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
import json
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import openai
|
||||
from openai._streaming import AsyncStream
|
||||
@ -19,7 +19,11 @@ from openai.types.responses import (
|
||||
ResponseIncompleteEvent,
|
||||
ResponseInputParam,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseOutputMessage,
|
||||
ResponseOutputMessageParam,
|
||||
ResponseReasoningItem,
|
||||
ResponseReasoningItemParam,
|
||||
ResponseStreamEvent,
|
||||
ResponseTextDeltaEvent,
|
||||
ToolParam,
|
||||
@ -127,6 +131,7 @@ def _convert_content_to_param(
|
||||
async def _transform_stream(
|
||||
chat_log: conversation.ChatLog,
|
||||
result: AsyncStream[ResponseStreamEvent],
|
||||
messages: ResponseInputParam,
|
||||
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
||||
"""Transform an OpenAI delta stream into HA format."""
|
||||
async for event in result:
|
||||
@ -137,6 +142,15 @@ async def _transform_stream(
|
||||
yield {"role": event.item.role}
|
||||
elif isinstance(event.item, ResponseFunctionToolCall):
|
||||
current_tool_call = event.item
|
||||
elif isinstance(event, ResponseOutputItemDoneEvent):
|
||||
item = event.item.model_dump()
|
||||
item.pop("status", None)
|
||||
if isinstance(event.item, ResponseReasoningItem):
|
||||
messages.append(cast(ResponseReasoningItemParam, item))
|
||||
elif isinstance(event.item, ResponseOutputMessage):
|
||||
messages.append(cast(ResponseOutputMessageParam, item))
|
||||
elif isinstance(event.item, ResponseFunctionToolCall):
|
||||
messages.append(cast(ResponseFunctionToolCallParam, item))
|
||||
elif isinstance(event, ResponseTextDeltaEvent):
|
||||
yield {"content": event.delta}
|
||||
elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
|
||||
@ -314,7 +328,6 @@ class OpenAIConversationEntity(
|
||||
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||
"user": chat_log.conversation_id,
|
||||
"store": False,
|
||||
"stream": True,
|
||||
}
|
||||
if tools:
|
||||
@ -326,6 +339,8 @@ class OpenAIConversationEntity(
|
||||
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
|
||||
)
|
||||
}
|
||||
else:
|
||||
model_args["store"] = False
|
||||
|
||||
try:
|
||||
result = await client.responses.create(**model_args)
|
||||
@ -337,8 +352,9 @@ class OpenAIConversationEntity(
|
||||
raise HomeAssistantError("Error talking to OpenAI") from err
|
||||
|
||||
async for content in chat_log.async_add_delta_content_stream(
|
||||
user_input.agent_id, _transform_stream(chat_log, result)
|
||||
user_input.agent_id, _transform_stream(chat_log, result, messages)
|
||||
):
|
||||
if not isinstance(content, conversation.AssistantContent):
|
||||
messages.extend(_convert_content_to_param(content))
|
||||
|
||||
if not chat_log.unresponded_tool_results:
|
||||
|
@ -586,6 +586,11 @@ async def test_function_call(
|
||||
agent_id="conversation.openai",
|
||||
)
|
||||
|
||||
assert mock_create_stream.call_args.kwargs["input"][2] == {
|
||||
"id": "rs_A",
|
||||
"summary": [],
|
||||
"type": "reasoning",
|
||||
}
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
# Don't test the prompt, as it's not deterministic
|
||||
assert mock_chat_log.content[1:] == snapshot
|
||||
|
Loading…
x
Reference in New Issue
Block a user