diff --git a/openai/openai.go b/openai/openai.go index 3a35d9dda..caf1152df 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -266,6 +266,9 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { Index: 0, Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls}, FinishReason: func(reason string) *string { + if len(toolCalls) > 0 { + reason = "tool_calls" + } if len(reason) > 0 { return &reason } @@ -570,8 +573,9 @@ type BaseWriter struct { } type ChatWriter struct { - stream bool - id string + stream bool + finished bool + id string BaseWriter } @@ -620,15 +624,19 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) { // chat chunk if w.stream { - d, err := json.Marshal(toChunk(w.id, chatResponse)) - if err != nil { - return 0, err - } + // If we've already finished, don't send any more chunks with choices. + if !w.finished { + chunk := toChunk(w.id, chatResponse) + d, err := json.Marshal(chunk) + if err != nil { + return 0, err + } - w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") - _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) - if err != nil { - return 0, err + w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") + _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) + if err != nil { + return 0, err + } } if chatResponse.Done {