Compare commits

...

1 Commits

Author SHA1 Message Date
Anuraag Agrawal
120785dbb6 openai: finish streaming tool calls as tool_calls 2024-12-06 15:00:48 +09:00

View File

@ -266,6 +266,9 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
Index: 0,
Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
FinishReason: func(reason string) *string {
if len(toolCalls) > 0 {
reason = "tool_calls"
}
if len(reason) > 0 {
return &reason
}
@ -570,8 +573,9 @@ type BaseWriter struct {
}
type ChatWriter struct {
stream bool
id string
stream bool
finished bool
id string
BaseWriter
}
@ -620,15 +624,19 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
// chat chunk
if w.stream {
d, err := json.Marshal(toChunk(w.id, chatResponse))
if err != nil {
return 0, err
}
// If we've already finished, don't send any more chunks with choices.
if !w.finished {
chunk := toChunk(w.id, chatResponse)
d, err := json.Marshal(chunk)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
if chatResponse.Done {