mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Raise error when max tokens reached for openai_conversation (#140214)
* Handle ResponseIncompleteEvent * Updated error text * Fix tests * Update conversation.py * ruff * More tests * Handle ResponseFailed and ResponseError --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com> Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
90623bbaff
commit
cb9692f3fb
@ -10,10 +10,13 @@ from openai.types.responses import (
|
||||
EasyInputMessageParam,
|
||||
FunctionToolParam,
|
||||
ResponseCompletedEvent,
|
||||
ResponseErrorEvent,
|
||||
ResponseFailedEvent,
|
||||
ResponseFunctionCallArgumentsDeltaEvent,
|
||||
ResponseFunctionCallArgumentsDoneEvent,
|
||||
ResponseFunctionToolCall,
|
||||
ResponseFunctionToolCallParam,
|
||||
ResponseIncompleteEvent,
|
||||
ResponseInputParam,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputMessage,
|
||||
@ -139,18 +142,57 @@ async def _transform_stream(
|
||||
)
|
||||
]
|
||||
}
|
||||
elif (
|
||||
isinstance(event, ResponseCompletedEvent)
|
||||
and (usage := event.response.usage) is not None
|
||||
):
|
||||
elif isinstance(event, ResponseCompletedEvent):
|
||||
if event.response.usage is not None:
|
||||
chat_log.async_trace(
|
||||
{
|
||||
"stats": {
|
||||
"input_tokens": usage.input_tokens,
|
||||
"output_tokens": usage.output_tokens,
|
||||
"input_tokens": event.response.usage.input_tokens,
|
||||
"output_tokens": event.response.usage.output_tokens,
|
||||
}
|
||||
}
|
||||
)
|
||||
elif isinstance(event, ResponseIncompleteEvent):
|
||||
if event.response.usage is not None:
|
||||
chat_log.async_trace(
|
||||
{
|
||||
"stats": {
|
||||
"input_tokens": event.response.usage.input_tokens,
|
||||
"output_tokens": event.response.usage.output_tokens,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if (
|
||||
event.response.incomplete_details
|
||||
and event.response.incomplete_details.reason
|
||||
):
|
||||
reason: str = event.response.incomplete_details.reason
|
||||
else:
|
||||
reason = "unknown reason"
|
||||
|
||||
if reason == "max_output_tokens":
|
||||
reason = "max output tokens reached"
|
||||
elif reason == "content_filter":
|
||||
reason = "content filter triggered"
|
||||
|
||||
raise HomeAssistantError(f"OpenAI response incomplete: {reason}")
|
||||
elif isinstance(event, ResponseFailedEvent):
|
||||
if event.response.usage is not None:
|
||||
chat_log.async_trace(
|
||||
{
|
||||
"stats": {
|
||||
"input_tokens": event.response.usage.input_tokens,
|
||||
"output_tokens": event.response.usage.output_tokens,
|
||||
}
|
||||
}
|
||||
)
|
||||
reason = "unknown reason"
|
||||
if event.response.error is not None:
|
||||
reason = event.response.error.message
|
||||
raise HomeAssistantError(f"OpenAI response failed: {reason}")
|
||||
elif isinstance(event, ResponseErrorEvent):
|
||||
raise HomeAssistantError(f"OpenAI response error: {event.message}")
|
||||
|
||||
|
||||
class OpenAIConversationEntity(
|
||||
|
@ -12,9 +12,13 @@ from openai.types.responses import (
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseError,
|
||||
ResponseErrorEvent,
|
||||
ResponseFailedEvent,
|
||||
ResponseFunctionCallArgumentsDeltaEvent,
|
||||
ResponseFunctionCallArgumentsDoneEvent,
|
||||
ResponseFunctionToolCall,
|
||||
ResponseIncompleteEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
@ -26,6 +30,7 @@ from openai.types.responses import (
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
)
|
||||
from openai.types.responses.response import IncompleteDetails
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
@ -83,13 +88,36 @@ def mock_create_stream() -> Generator[AsyncMock]:
|
||||
response=response,
|
||||
type="response.in_progress",
|
||||
)
|
||||
response.status = "completed"
|
||||
|
||||
for value in events:
|
||||
if isinstance(value, ResponseOutputItemDoneEvent):
|
||||
response.output.append(value.item)
|
||||
elif isinstance(value, IncompleteDetails):
|
||||
response.status = "incomplete"
|
||||
response.incomplete_details = value
|
||||
break
|
||||
if isinstance(value, ResponseError):
|
||||
response.status = "failed"
|
||||
response.error = value
|
||||
break
|
||||
|
||||
yield value
|
||||
|
||||
response.status = "completed"
|
||||
if isinstance(value, ResponseErrorEvent):
|
||||
return
|
||||
|
||||
if response.status == "incomplete":
|
||||
yield ResponseIncompleteEvent(
|
||||
response=response,
|
||||
type="response.incomplete",
|
||||
)
|
||||
elif response.status == "failed":
|
||||
yield ResponseFailedEvent(
|
||||
response=response,
|
||||
type="response.failed",
|
||||
)
|
||||
else:
|
||||
yield ResponseCompletedEvent(
|
||||
response=response,
|
||||
type="response.completed",
|
||||
@ -175,6 +203,123 @@ async def test_error_handling(
|
||||
assert result.response.speech["plain"]["speech"] == message, result.response.speech
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("reason", "message"),
|
||||
[
|
||||
(
|
||||
"max_output_tokens",
|
||||
"max output tokens reached",
|
||||
),
|
||||
(
|
||||
"content_filter",
|
||||
"content filter triggered",
|
||||
),
|
||||
(
|
||||
None,
|
||||
"unknown reason",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_incomplete_response(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_with_assist: MockConfigEntry,
|
||||
mock_init_component,
|
||||
mock_create_stream: AsyncMock,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
reason: str,
|
||||
message: str,
|
||||
) -> None:
|
||||
"""Test handling early model stop."""
|
||||
# Incomplete details received after some content is generated
|
||||
mock_create_stream.return_value = [
|
||||
(
|
||||
# Start message
|
||||
*create_message_item(
|
||||
id="msg_A",
|
||||
text=["Once upon", " a time, ", "there was "],
|
||||
output_index=0,
|
||||
),
|
||||
# Length limit or content filter
|
||||
IncompleteDetails(reason=reason),
|
||||
)
|
||||
]
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Please tell me a big story",
|
||||
"mock-conversation-id",
|
||||
Context(),
|
||||
agent_id="conversation.openai",
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert (
|
||||
result.response.speech["plain"]["speech"]
|
||||
== f"OpenAI response incomplete: {message}"
|
||||
), result.response.speech
|
||||
|
||||
# Incomplete details received before any content is generated
|
||||
mock_create_stream.return_value = [
|
||||
(
|
||||
# Start generating response
|
||||
*create_reasoning_item(id="rs_A", output_index=0),
|
||||
# Length limit or content filter
|
||||
IncompleteDetails(reason=reason),
|
||||
)
|
||||
]
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"please tell me a big story",
|
||||
"mock-conversation-id",
|
||||
Context(),
|
||||
agent_id="conversation.openai",
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert (
|
||||
result.response.speech["plain"]["speech"]
|
||||
== f"OpenAI response incomplete: {message}"
|
||||
), result.response.speech
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("error", "message"),
|
||||
[
|
||||
(
|
||||
ResponseError(code="rate_limit_exceeded", message="Rate limit exceeded"),
|
||||
"OpenAI response failed: Rate limit exceeded",
|
||||
),
|
||||
(
|
||||
ResponseErrorEvent(type="error", message="Some error"),
|
||||
"OpenAI response error: Some error",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_failed_response(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_with_assist: MockConfigEntry,
|
||||
mock_init_component,
|
||||
mock_create_stream: AsyncMock,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
error: ResponseError | ResponseErrorEvent,
|
||||
message: str,
|
||||
) -> None:
|
||||
"""Test handling failed and error responses."""
|
||||
mock_create_stream.return_value = [(error,)]
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"next natural number please",
|
||||
"mock-conversation-id",
|
||||
Context(),
|
||||
agent_id="conversation.openai",
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.speech["plain"]["speech"] == message, result.response.speech
|
||||
|
||||
|
||||
async def test_conversation_agent(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
|
Loading…
x
Reference in New Issue
Block a user