openai: finish streaming tool calls as tool_calls

This commit is contained in:
Anuraag Agrawal 2024-12-06 13:50:20 +09:00
parent aed1419c64
commit 120785dbb6

View File

@ -266,6 +266,9 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
Index: 0, Index: 0,
Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls}, Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
FinishReason: func(reason string) *string { FinishReason: func(reason string) *string {
if len(toolCalls) > 0 {
reason = "tool_calls"
}
if len(reason) > 0 { if len(reason) > 0 {
return &reason return &reason
} }
@ -571,6 +574,7 @@ type BaseWriter struct {
type ChatWriter struct { type ChatWriter struct {
stream bool stream bool
finished bool
id string id string
BaseWriter BaseWriter
} }
@ -620,7 +624,10 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
// chat chunk // chat chunk
if w.stream { if w.stream {
d, err := json.Marshal(toChunk(w.id, chatResponse)) // If we've already finished, don't send any more chunks with choices.
if !w.finished {
chunk := toChunk(w.id, chatResponse)
d, err := json.Marshal(chunk)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -630,6 +637,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
}
if chatResponse.Done { if chatResponse.Done {
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))