mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Preserve reasoning during tool calls for openai_conversation (#143699)
Preserve reasoning after tool calls for openai_conversation
This commit is contained in:
parent
4c9cd70f65
commit
7074331461
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from collections.abc import AsyncGenerator, Callable
|
||||||
import json
|
import json
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, cast
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from openai._streaming import AsyncStream
|
from openai._streaming import AsyncStream
|
||||||
@ -19,7 +19,11 @@ from openai.types.responses import (
|
|||||||
ResponseIncompleteEvent,
|
ResponseIncompleteEvent,
|
||||||
ResponseInputParam,
|
ResponseInputParam,
|
||||||
ResponseOutputItemAddedEvent,
|
ResponseOutputItemAddedEvent,
|
||||||
|
ResponseOutputItemDoneEvent,
|
||||||
ResponseOutputMessage,
|
ResponseOutputMessage,
|
||||||
|
ResponseOutputMessageParam,
|
||||||
|
ResponseReasoningItem,
|
||||||
|
ResponseReasoningItemParam,
|
||||||
ResponseStreamEvent,
|
ResponseStreamEvent,
|
||||||
ResponseTextDeltaEvent,
|
ResponseTextDeltaEvent,
|
||||||
ToolParam,
|
ToolParam,
|
||||||
@ -127,6 +131,7 @@ def _convert_content_to_param(
|
|||||||
async def _transform_stream(
|
async def _transform_stream(
|
||||||
chat_log: conversation.ChatLog,
|
chat_log: conversation.ChatLog,
|
||||||
result: AsyncStream[ResponseStreamEvent],
|
result: AsyncStream[ResponseStreamEvent],
|
||||||
|
messages: ResponseInputParam,
|
||||||
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
||||||
"""Transform an OpenAI delta stream into HA format."""
|
"""Transform an OpenAI delta stream into HA format."""
|
||||||
async for event in result:
|
async for event in result:
|
||||||
@ -137,6 +142,15 @@ async def _transform_stream(
|
|||||||
yield {"role": event.item.role}
|
yield {"role": event.item.role}
|
||||||
elif isinstance(event.item, ResponseFunctionToolCall):
|
elif isinstance(event.item, ResponseFunctionToolCall):
|
||||||
current_tool_call = event.item
|
current_tool_call = event.item
|
||||||
|
elif isinstance(event, ResponseOutputItemDoneEvent):
|
||||||
|
item = event.item.model_dump()
|
||||||
|
item.pop("status", None)
|
||||||
|
if isinstance(event.item, ResponseReasoningItem):
|
||||||
|
messages.append(cast(ResponseReasoningItemParam, item))
|
||||||
|
elif isinstance(event.item, ResponseOutputMessage):
|
||||||
|
messages.append(cast(ResponseOutputMessageParam, item))
|
||||||
|
elif isinstance(event.item, ResponseFunctionToolCall):
|
||||||
|
messages.append(cast(ResponseFunctionToolCallParam, item))
|
||||||
elif isinstance(event, ResponseTextDeltaEvent):
|
elif isinstance(event, ResponseTextDeltaEvent):
|
||||||
yield {"content": event.delta}
|
yield {"content": event.delta}
|
||||||
elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
|
elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
|
||||||
@ -314,7 +328,6 @@ class OpenAIConversationEntity(
|
|||||||
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||||
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||||
"user": chat_log.conversation_id,
|
"user": chat_log.conversation_id,
|
||||||
"store": False,
|
|
||||||
"stream": True,
|
"stream": True,
|
||||||
}
|
}
|
||||||
if tools:
|
if tools:
|
||||||
@ -326,6 +339,8 @@ class OpenAIConversationEntity(
|
|||||||
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
|
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
model_args["store"] = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await client.responses.create(**model_args)
|
result = await client.responses.create(**model_args)
|
||||||
@ -337,9 +352,10 @@ class OpenAIConversationEntity(
|
|||||||
raise HomeAssistantError("Error talking to OpenAI") from err
|
raise HomeAssistantError("Error talking to OpenAI") from err
|
||||||
|
|
||||||
async for content in chat_log.async_add_delta_content_stream(
|
async for content in chat_log.async_add_delta_content_stream(
|
||||||
user_input.agent_id, _transform_stream(chat_log, result)
|
user_input.agent_id, _transform_stream(chat_log, result, messages)
|
||||||
):
|
):
|
||||||
messages.extend(_convert_content_to_param(content))
|
if not isinstance(content, conversation.AssistantContent):
|
||||||
|
messages.extend(_convert_content_to_param(content))
|
||||||
|
|
||||||
if not chat_log.unresponded_tool_results:
|
if not chat_log.unresponded_tool_results:
|
||||||
break
|
break
|
||||||
|
@ -586,6 +586,11 @@ async def test_function_call(
|
|||||||
agent_id="conversation.openai",
|
agent_id="conversation.openai",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert mock_create_stream.call_args.kwargs["input"][2] == {
|
||||||
|
"id": "rs_A",
|
||||||
|
"summary": [],
|
||||||
|
"type": "reasoning",
|
||||||
|
}
|
||||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
# Don't test the prompt, as it's not deterministic
|
# Don't test the prompt, as it's not deterministic
|
||||||
assert mock_chat_log.content[1:] == snapshot
|
assert mock_chat_log.content[1:] == snapshot
|
||||||
|
Loading…
x
Reference in New Issue
Block a user