From cb9692f3fb7a6b7afbe1a645dfe3f6d21f3c291d Mon Sep 17 00:00:00 2001
From: Denis Shulyaka
Date: Mon, 24 Mar 2025 17:49:34 +0300
Subject: [PATCH] Raise error when max tokens reached for openai_conversation
 (#140214)

* Handle ResponseIncompleteEvent

* Updated error text

* Fix tests

* Update conversation.py

* ruff

* More tests

* Handle ResponseFailed and ResponseError

---------

Co-authored-by: Paulus Schoutsen
Co-authored-by: Paulus Schoutsen
---
 .../openai_conversation/conversation.py       |  64 ++++++--
 .../openai_conversation/test_conversation.py  | 155 +++++++++++++++++-
 2 files changed, 203 insertions(+), 16 deletions(-)

diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py
index 32ac20b2680..873406a3999 100644
--- a/homeassistant/components/openai_conversation/conversation.py
+++ b/homeassistant/components/openai_conversation/conversation.py
@@ -10,10 +10,13 @@ from openai.types.responses import (
     EasyInputMessageParam,
     FunctionToolParam,
     ResponseCompletedEvent,
+    ResponseErrorEvent,
+    ResponseFailedEvent,
     ResponseFunctionCallArgumentsDeltaEvent,
     ResponseFunctionCallArgumentsDoneEvent,
     ResponseFunctionToolCall,
     ResponseFunctionToolCallParam,
+    ResponseIncompleteEvent,
     ResponseInputParam,
     ResponseOutputItemAddedEvent,
     ResponseOutputMessage,
@@ -139,18 +142,57 @@ async def _transform_stream(
                     )
                 ]
             }
-        elif (
-            isinstance(event, ResponseCompletedEvent)
-            and (usage := event.response.usage) is not None
-        ):
-            chat_log.async_trace(
-                {
-                    "stats": {
-                        "input_tokens": usage.input_tokens,
-                        "output_tokens": usage.output_tokens,
+        elif isinstance(event, ResponseCompletedEvent):
+            if event.response.usage is not None:
+                chat_log.async_trace(
+                    {
+                        "stats": {
+                            "input_tokens": event.response.usage.input_tokens,
+                            "output_tokens": event.response.usage.output_tokens,
+                        }
                     }
-                }
-            )
+                )
+        elif isinstance(event, ResponseIncompleteEvent):
+            if event.response.usage is not None:
+                chat_log.async_trace(
+                    {
+                        "stats": {
+                            "input_tokens": event.response.usage.input_tokens,
+                            "output_tokens": event.response.usage.output_tokens,
+                        }
+                    }
+                )
+
+            if (
+                event.response.incomplete_details
+                and event.response.incomplete_details.reason
+            ):
+                reason: str = event.response.incomplete_details.reason
+            else:
+                reason = "unknown reason"
+
+            if reason == "max_output_tokens":
+                reason = "max output tokens reached"
+            elif reason == "content_filter":
+                reason = "content filter triggered"
+
+            raise HomeAssistantError(f"OpenAI response incomplete: {reason}")
+        elif isinstance(event, ResponseFailedEvent):
+            if event.response.usage is not None:
+                chat_log.async_trace(
+                    {
+                        "stats": {
+                            "input_tokens": event.response.usage.input_tokens,
+                            "output_tokens": event.response.usage.output_tokens,
+                        }
+                    }
+                )
+            reason = "unknown reason"
+            if event.response.error is not None:
+                reason = event.response.error.message
+            raise HomeAssistantError(f"OpenAI response failed: {reason}")
+        elif isinstance(event, ResponseErrorEvent):
+            raise HomeAssistantError(f"OpenAI response error: {event.message}")
 
 
 class OpenAIConversationEntity(
diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py
index bfcacefb044..fb54c423234 100644
--- a/tests/components/openai_conversation/test_conversation.py
+++ b/tests/components/openai_conversation/test_conversation.py
@@ -12,9 +12,13 @@ from openai.types.responses import (
     ResponseContentPartAddedEvent,
     ResponseContentPartDoneEvent,
     ResponseCreatedEvent,
+    ResponseError,
+    ResponseErrorEvent,
+    ResponseFailedEvent,
     ResponseFunctionCallArgumentsDeltaEvent,
     ResponseFunctionCallArgumentsDoneEvent,
     ResponseFunctionToolCall,
+    ResponseIncompleteEvent,
     ResponseInProgressEvent,
     ResponseOutputItemAddedEvent,
     ResponseOutputItemDoneEvent,
@@ -26,6 +30,7 @@ from openai.types.responses import (
     ResponseTextDeltaEvent,
     ResponseTextDoneEvent,
 )
+from openai.types.responses.response import IncompleteDetails
 import pytest
 from syrupy.assertion import SnapshotAssertion
 
@@ -83,17 +88,40 @@ def mock_create_stream() -> Generator[AsyncMock]:
             response=response,
             type="response.in_progress",
         )
+        response.status = "completed"
 
         for value in events:
             if isinstance(value, ResponseOutputItemDoneEvent):
                 response.output.append(value.item)
+            elif isinstance(value, IncompleteDetails):
+                response.status = "incomplete"
+                response.incomplete_details = value
+                break
+            if isinstance(value, ResponseError):
+                response.status = "failed"
+                response.error = value
+                break
+
             yield value
 
-        response.status = "completed"
-        yield ResponseCompletedEvent(
-            response=response,
-            type="response.completed",
-        )
+            if isinstance(value, ResponseErrorEvent):
+                return
+
+        if response.status == "incomplete":
+            yield ResponseIncompleteEvent(
+                response=response,
+                type="response.incomplete",
+            )
+        elif response.status == "failed":
+            yield ResponseFailedEvent(
+                response=response,
+                type="response.failed",
+            )
+        else:
+            yield ResponseCompletedEvent(
+                response=response,
+                type="response.completed",
+            )
 
     with patch(
         "openai.resources.responses.AsyncResponses.create",
@@ -175,6 +203,123 @@ async def test_error_handling(
     assert result.response.speech["plain"]["speech"] == message, result.response.speech
 
 
+@pytest.mark.parametrize(
+    ("reason", "message"),
+    [
+        (
+            "max_output_tokens",
+            "max output tokens reached",
+        ),
+        (
+            "content_filter",
+            "content filter triggered",
+        ),
+        (
+            None,
+            "unknown reason",
+        ),
+    ],
+)
+async def test_incomplete_response(
+    hass: HomeAssistant,
+    mock_config_entry_with_assist: MockConfigEntry,
+    mock_init_component,
+    mock_create_stream: AsyncMock,
+    mock_chat_log: MockChatLog,  # noqa: F811
+    reason: str,
+    message: str,
+) -> None:
+    """Test handling early model stop."""
+    # Incomplete details received after some content is generated
+    mock_create_stream.return_value = [
+        (
+            # Start message
+            *create_message_item(
+                id="msg_A",
+                text=["Once upon", " a time, ", "there was "],
+                output_index=0,
+            ),
+            # Length limit or content filter
+            IncompleteDetails(reason=reason),
+        )
+    ]
+
+    result = await conversation.async_converse(
+        hass,
+        "Please tell me a big story",
+        "mock-conversation-id",
+        Context(),
+        agent_id="conversation.openai",
+    )
+
+    assert result.response.response_type == intent.IntentResponseType.ERROR, result
+    assert (
+        result.response.speech["plain"]["speech"]
+        == f"OpenAI response incomplete: {message}"
+    ), result.response.speech
+
+    # Incomplete details received before any content is generated
+    mock_create_stream.return_value = [
+        (
+            # Start generating response
+            *create_reasoning_item(id="rs_A", output_index=0),
+            # Length limit or content filter
+            IncompleteDetails(reason=reason),
+        )
+    ]
+
+    result = await conversation.async_converse(
+        hass,
+        "please tell me a big story",
+        "mock-conversation-id",
+        Context(),
+        agent_id="conversation.openai",
+    )
+
+    assert result.response.response_type == intent.IntentResponseType.ERROR, result
+    assert (
+        result.response.speech["plain"]["speech"]
+        == f"OpenAI response incomplete: {message}"
+    ), result.response.speech
+
+
+@pytest.mark.parametrize(
+    ("error", "message"),
+    [
+        (
+            ResponseError(code="rate_limit_exceeded", message="Rate limit exceeded"),
+            "OpenAI response failed: Rate limit exceeded",
+        ),
+        (
+            ResponseErrorEvent(type="error", message="Some error"),
+            "OpenAI response error: Some error",
+        ),
+    ],
+)
+async def test_failed_response(
+    hass: HomeAssistant,
+    mock_config_entry_with_assist: MockConfigEntry,
+    mock_init_component,
+    mock_create_stream: AsyncMock,
+    mock_chat_log: MockChatLog,  # noqa: F811
+    error: ResponseError | ResponseErrorEvent,
+    message: str,
+) -> None:
+    """Test handling failed and error responses."""
+    mock_create_stream.return_value = [(error,)]
+
+    result = await conversation.async_converse(
+        hass,
+        "next natural number please",
+        "mock-conversation-id",
+        Context(),
+        agent_id="conversation.openai",
+    )
+
+    assert result.response.response_type == intent.IntentResponseType.ERROR, result
+    assert result.response.speech["plain"]["speech"] == message, result.response.speech
+
+
 async def test_conversation_agent(
     hass: HomeAssistant,
     mock_config_entry: MockConfigEntry,
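
Note (not part of the patch): for anyone who wants to poke at the new reason mapping without a Home Assistant checkout, here is a minimal standalone sketch of the logic `_transform_stream` now applies to `ResponseIncompleteEvent`. `StreamError`, `IncompleteDetailsStub`, and `incomplete_message` are hypothetical stand-ins invented for this sketch; the real code raises `HomeAssistantError` and uses the `openai` response types. Only the mapping itself mirrors the diff above.

# Standalone sketch of the incomplete-reason mapping added in this patch.
# StreamError stands in for homeassistant.exceptions.HomeAssistantError;
# IncompleteDetailsStub stands in for openai's IncompleteDetails.
from dataclasses import dataclass


class StreamError(Exception):
    """Stand-in for HomeAssistantError."""


@dataclass
class IncompleteDetailsStub:
    reason: str | None = None


def incomplete_message(details: IncompleteDetailsStub | None) -> str:
    """Build the error text the same way the patched branch does."""
    if details and details.reason:
        reason = details.reason
    else:
        reason = "unknown reason"
    # Translate the two known API reasons into friendlier wording
    if reason == "max_output_tokens":
        reason = "max output tokens reached"
    elif reason == "content_filter":
        reason = "content filter triggered"
    return f"OpenAI response incomplete: {reason}"


if __name__ == "__main__":
    # The three cases covered by test_incomplete_response
    assert (
        incomplete_message(IncompleteDetailsStub("max_output_tokens"))
        == "OpenAI response incomplete: max output tokens reached"
    )
    assert (
        incomplete_message(IncompleteDetailsStub("content_filter"))
        == "OpenAI response incomplete: content filter triggered"
    )
    assert incomplete_message(None) == "OpenAI response incomplete: unknown reason"
    # The patched code raises rather than returning silently truncated text
    try:
        raise StreamError(incomplete_message(IncompleteDetailsStub()))
    except StreamError as err:
        print(f"caught: {err}")  # caught: OpenAI response incomplete: unknown reason

A missing or unrecognized reason deliberately falls back to "unknown reason" instead of crashing, which is the `None` case exercised by the parametrized test.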