openai: fix follow-on messages having "role": "assistant"

This commit is contained in:
jmorganca 2024-11-18 00:04:17 -08:00
parent a14f76491d
commit 32c48ddad6

View File

@ -32,7 +32,7 @@ type ErrorResponse struct {
} }
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role,omitempty"`
Content any `json:"content"` Content any `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
} }
@ -252,7 +252,7 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
SystemFingerprint: "fp_ollama", SystemFingerprint: "fp_ollama",
Choices: []ChunkChoice{{ Choices: []ChunkChoice{{
Index: 0, Index: 0,
Delta: Message{Role: "assistant", Content: r.Message.Content}, Delta: Message{Content: r.Message.Content},
FinishReason: func(reason string) *string { FinishReason: func(reason string) *string {
if len(reason) > 0 { if len(reason) > 0 {
return &reason return &reason
@ -546,8 +546,9 @@ type BaseWriter struct {
} }
type ChatWriter struct { type ChatWriter struct {
stream bool stream bool
id string started bool
id string
BaseWriter BaseWriter
} }
@ -594,8 +595,28 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
return 0, err return 0, err
} }
// chat chunk
if w.stream { if w.stream {
// The first chunk always has empty content so we
// copy the first chunk and set the content to
// empty, and send it first.
if !w.started {
first := chatResponse
first.Message = api.Message{Role: "assistant"}
d, err := json.Marshal(toChunk(w.id, first))
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
w.started = true
}
d, err := json.Marshal(toChunk(w.id, chatResponse)) d, err := json.Marshal(toChunk(w.id, chatResponse))
if err != nil { if err != nil {
return 0, err return 0, err
@ -617,7 +638,6 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
return len(data), nil return len(data), nil
} }
// chat completion
w.ResponseWriter.Header().Set("Content-Type", "application/json") w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse)) err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
if err != nil { if err != nil {