From 32c48ddad66a6132c7ddcbebdd9548dcd87a0eea Mon Sep 17 00:00:00 2001 From: jmorganca Date: Mon, 18 Nov 2024 00:04:17 -0800 Subject: [PATCH] openai: fix follow-on messages having "role": "assistant" --- openai/openai.go | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/openai/openai.go b/openai/openai.go index 2bf9b9f9b..881b76c2b 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -32,7 +32,7 @@ type ErrorResponse struct { } type Message struct { - Role string `json:"role"` + Role string `json:"role,omitempty"` Content any `json:"content"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` } @@ -252,7 +252,7 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { SystemFingerprint: "fp_ollama", Choices: []ChunkChoice{{ Index: 0, - Delta: Message{Role: "assistant", Content: r.Message.Content}, + Delta: Message{Content: r.Message.Content}, FinishReason: func(reason string) *string { if len(reason) > 0 { return &reason @@ -546,8 +546,9 @@ type BaseWriter struct { } type ChatWriter struct { - stream bool - id string + stream bool + started bool + id string BaseWriter } @@ -594,8 +595,28 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) { return 0, err } - // chat chunk if w.stream { + // The first chunk of a stream always has empty content, + // so copy the first response, reset its message to an + // empty assistant message, and send that copy first. 
+ if !w.started { + first := chatResponse + first.Message = api.Message{Role: "assistant"} + + d, err := json.Marshal(toChunk(w.id, first)) + if err != nil { + return 0, err + } + + w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") + _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) + if err != nil { + return 0, err + } + + w.started = true + } + d, err := json.Marshal(toChunk(w.id, chatResponse)) if err != nil { return 0, err @@ -617,7 +638,6 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) { return len(data), nil } - // chat completion w.ResponseWriter.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse)) if err != nil {