openai: support include_usage stream option to return final usage chunk

Anuraag Agrawal 2024-09-13 12:24:43 +09:00
parent fda0d3be52
commit 220108d3f4
2 changed files with 186 additions and 33 deletions
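With this change, a streamed request that sets `"stream_options": {"include_usage": true}` follows OpenAI's documented behavior: every content chunk reports `"usage": null`, and one extra chunk carrying the request's token counts is emitted just before `data: [DONE]`. A minimal client-side sketch of what this enables; the base URL and model name are assumptions for a local Ollama instance, not part of this commit:

```go
package main

import (
	"bufio"
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	// Streamed chat completion with the new stream_options flag.
	// Endpoint and model name are placeholders; adjust for your setup.
	body := []byte(`{
		"model": "llama3",
		"messages": [{"role": "user", "content": "Hello"}],
		"stream": true,
		"stream_options": {"include_usage": true}
	}`)

	resp, err := http.Post("http://localhost:11434/v1/chat/completions",
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Print the raw SSE lines. Content chunks report "usage": null; the
	// last data line before "data: [DONE]" carries the token counts.
	sc := bufio.NewScanner(resp.Body)
	for sc.Scan() {
		if line := sc.Text(); line != "" {
			fmt.Println(line)
		}
	}
}
```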

openai/openai.go

@@ -61,6 +61,21 @@ type Usage struct {
 	TotalTokens      int `json:"total_tokens"`
 }
+
+// ChunkUsage is an alias for Usage with the ability to marshal a marker
+// value as null. This is to allow omitting the field in chunks when usage
+// isn't requested, and otherwise return null on non-final chunks when it
+// is requested to follow OpenAI's behavior.
+type ChunkUsage = Usage
+
+var nullChunkUsage = ChunkUsage{}
+
+func (u *ChunkUsage) MarshalJSON() ([]byte, error) {
+	if u == &nullChunkUsage {
+		return []byte("null"), nil
+	}
+	return json.Marshal(*u)
+}
 
 type ResponseFormat struct {
 	Type string `json:"type"`
 }
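The alias-plus-sentinel trick above hinges on pointer identity: a `nil` `*ChunkUsage` is dropped by `omitempty`, a pointer to the shared `nullChunkUsage` marker serializes as `null`, and any other pointer marshals the usual object. A self-contained sketch of the pattern; the `chunk` wrapper type is invented here to stand in for the chunk structs further down:

```go
package main

import (
	"encoding/json"
	"fmt"
)

type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// Alias plus sentinel, mirroring the diff above.
type ChunkUsage = Usage

var nullChunkUsage = ChunkUsage{}

func (u *ChunkUsage) MarshalJSON() ([]byte, error) {
	if u == &nullChunkUsage {
		return []byte("null"), nil
	}
	// Marshaling the dereferenced value uses the default struct encoding:
	// the pointer-receiver MarshalJSON is not in the value's method set.
	return json.Marshal(*u)
}

// chunk is an illustrative stand-in for ChatCompletionChunk/CompletionChunk.
type chunk struct {
	Usage *ChunkUsage `json:"usage,omitempty"`
}

func main() {
	final := ChunkUsage{PromptTokens: 10, CompletionTokens: 5, TotalTokens: 15}

	for _, c := range []chunk{
		{},                       // usage not requested: field omitted
		{Usage: &nullChunkUsage}, // requested, non-final chunk: "usage":null
		{Usage: &final},          // final chunk: real counts
	} {
		d, _ := json.Marshal(c)
		fmt.Println(string(d))
	}
	// Output:
	// {}
	// {"usage":null}
	// {"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}
}
```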
@@ -70,10 +85,15 @@ type EmbedRequest struct {
 	Model string `json:"model"`
 }
 
+type StreamOptions struct {
+	IncludeUsage bool `json:"include_usage"`
+}
+
 type ChatCompletionRequest struct {
 	Model            string          `json:"model"`
 	Messages         []Message       `json:"messages"`
 	Stream           bool            `json:"stream"`
+	StreamOptions    *StreamOptions  `json:"stream_options"`
 	MaxTokens        *int            `json:"max_tokens"`
 	Seed             *int            `json:"seed"`
 	Stop             any             `json:"stop"`
@@ -102,21 +122,23 @@ type ChatCompletionChunk struct {
 	Model             string        `json:"model"`
 	SystemFingerprint string        `json:"system_fingerprint"`
 	Choices           []ChunkChoice `json:"choices"`
+	Usage             *ChunkUsage   `json:"usage,omitempty"`
 }
 
 // TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
 type CompletionRequest struct {
-	Model            string   `json:"model"`
-	Prompt           string   `json:"prompt"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	MaxTokens        *int     `json:"max_tokens"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	Seed             *int     `json:"seed"`
-	Stop             any      `json:"stop"`
-	Stream           bool     `json:"stream"`
-	Temperature      *float32 `json:"temperature"`
-	TopP             float32  `json:"top_p"`
-	Suffix           string   `json:"suffix"`
+	Model            string         `json:"model"`
+	Prompt           string         `json:"prompt"`
+	FrequencyPenalty float32        `json:"frequency_penalty"`
+	MaxTokens        *int           `json:"max_tokens"`
+	PresencePenalty  float32        `json:"presence_penalty"`
+	Seed             *int           `json:"seed"`
+	Stop             any            `json:"stop"`
+	Stream           bool           `json:"stream"`
+	StreamOptions    *StreamOptions `json:"stream_options"`
+	Temperature      *float32       `json:"temperature"`
+	TopP             float32        `json:"top_p"`
+	Suffix           string         `json:"suffix"`
 }
 
 type Completion struct {
@@ -136,6 +158,7 @@ type CompletionChunk struct {
 	Choices           []CompleteChunkChoice `json:"choices"`
 	Model             string                `json:"model"`
 	SystemFingerprint string                `json:"system_fingerprint"`
+	Usage             *ChunkUsage           `json:"usage,omitempty"`
 }
 
 type ToolCall struct {
@@ -200,6 +223,14 @@ func toolCallId() string {
 	return "call_" + strings.ToLower(string(b))
 }
 
+func toUsage(r api.ChatResponse) Usage {
+	return Usage{
+		PromptTokens:     r.PromptEvalCount,
+		CompletionTokens: r.EvalCount,
+		TotalTokens:      r.PromptEvalCount + r.EvalCount,
+	}
+}
+
 func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 	toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
 	for i, tc := range r.Message.ToolCalls {
@@ -235,11 +266,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 				return nil
 			}(r.DoneReason),
 		}},
-		Usage: Usage{
-			PromptTokens:     r.PromptEvalCount,
-			CompletionTokens: r.EvalCount,
-			TotalTokens:      r.PromptEvalCount + r.EvalCount,
-		},
+		Usage: toUsage(r),
 	}
 }
 
@@ -263,6 +290,14 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 	}
 }
 
+func toUsageGenerate(r api.GenerateResponse) Usage {
+	return Usage{
+		PromptTokens:     r.PromptEvalCount,
+		CompletionTokens: r.EvalCount,
+		TotalTokens:      r.PromptEvalCount + r.EvalCount,
+	}
+}
+
 func toCompletion(id string, r api.GenerateResponse) Completion {
 	return Completion{
 		Id:      id,
@@ -280,11 +315,7 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
 				return nil
 			}(r.DoneReason),
 		}},
-		Usage: Usage{
-			PromptTokens:     r.PromptEvalCount,
-			CompletionTokens: r.EvalCount,
-			TotalTokens:      r.PromptEvalCount + r.EvalCount,
-		},
+		Usage: toUsageGenerate(r),
 	}
 }
 
@@ -546,14 +577,16 @@ type BaseWriter struct {
 }
 
 type ChatWriter struct {
-	stream bool
-	id     string
+	stream      bool
+	streamUsage bool
+	id          string
 	BaseWriter
 }
 
 type CompleteWriter struct {
-	stream bool
-	id     string
+	stream      bool
+	streamUsage bool
+	id          string
 	BaseWriter
 }
@@ -596,7 +629,11 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 
 	// chat chunk
 	if w.stream {
-		d, err := json.Marshal(toChunk(w.id, chatResponse))
+		c := toChunk(w.id, chatResponse)
+		if w.streamUsage {
+			c.Usage = &nullChunkUsage
+		}
+		d, err := json.Marshal(c)
 		if err != nil {
 			return 0, err
 		}
@@ -608,6 +645,17 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 	}
 
 	if chatResponse.Done {
+		if w.streamUsage {
+			u := toUsage(chatResponse)
+			d, err := json.Marshal(ChatCompletionChunk{Usage: &u})
+			if err != nil {
+				return 0, err
+			}
+			_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+			if err != nil {
+				return 0, err
+			}
+		}
 		_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
 		if err != nil {
 			return 0, err
@@ -645,7 +693,11 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
 
 	// completion chunk
 	if w.stream {
-		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
+		c := toCompleteChunk(w.id, generateResponse)
+		if w.streamUsage {
+			c.Usage = &nullChunkUsage
+		}
+		d, err := json.Marshal(c)
 		if err != nil {
 			return 0, err
 		}
@@ -657,6 +709,17 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
 	}
 
 	if generateResponse.Done {
+		if w.streamUsage {
+			u := toUsageGenerate(generateResponse)
+			d, err := json.Marshal(CompletionChunk{Usage: &u})
+			if err != nil {
+				return 0, err
+			}
+			_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+			if err != nil {
+				return 0, err
+			}
+		}
 		_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
 		if err != nil {
 			return 0, err
@@ -819,9 +882,10 @@ func CompletionsMiddleware() gin.HandlerFunc {
 		c.Request.Body = io.NopCloser(&b)
 
 		w := &CompleteWriter{
-			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-			stream:     req.Stream,
-			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)),
+			BaseWriter:  BaseWriter{ResponseWriter: c.Writer},
+			stream:      req.Stream,
+			id:          fmt.Sprintf("cmpl-%d", rand.Intn(999)),
+			streamUsage: req.StreamOptions != nil && req.StreamOptions.IncludeUsage,
 		}
 
 		c.Writer = w
@@ -901,9 +965,10 @@ func ChatMiddleware() gin.HandlerFunc {
 		c.Request.Body = io.NopCloser(&b)
 
 		w := &ChatWriter{
-			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-			stream:     req.Stream,
-			id:         fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
+			BaseWriter:  BaseWriter{ResponseWriter: c.Writer},
+			stream:      req.Stream,
+			id:          fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
+			streamUsage: req.StreamOptions != nil && req.StreamOptions.IncludeUsage,
 		}
 
 		c.Writer = w
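Taken together, the tail of a chat stream requested with `include_usage` looks roughly like this (ids and counts are made-up values, and the final chunk's non-usage fields are zero values because the writer marshals a bare `ChatCompletionChunk{Usage: &u}`):

```
data: {"id":"chatcmpl-42","object":"chat.completion.chunk",...,"usage":null}

data: {"id":"","object":"","created":0,"model":"","system_fingerprint":"","choices":null,"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}

data: [DONE]
```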

openai/openai_test.go

@@ -111,6 +111,45 @@ func TestChatMiddleware(t *testing.T) {
 				Stream: &True,
 			},
 		},
+		{
+			name: "chat handler with streaming usage",
+			body: `{
+				"model": "test-model",
+				"messages": [
+					{"role": "user", "content": "Hello"}
+				],
+				"stream": true,
+				"stream_options": {"include_usage": true},
+				"max_tokens": 999,
+				"seed": 123,
+				"stop": ["\n", "stop"],
+				"temperature": 3.0,
+				"frequency_penalty": 4.0,
+				"presence_penalty": 5.0,
+				"top_p": 6.0,
+				"response_format": {"type": "json_object"}
+			}`,
+			req: api.ChatRequest{
+				Model: "test-model",
+				Messages: []api.Message{
+					{
+						Role:    "user",
+						Content: "Hello",
+					},
+				},
+				Options: map[string]any{
+					"num_predict": 999.0, // float because JSON doesn't distinguish between float and int
+					"seed": 123.0,
+					"stop": []any{"\n", "stop"},
+					"temperature": 3.0,
+					"frequency_penalty": 4.0,
+					"presence_penalty": 5.0,
+					"top_p": 6.0,
+				},
+				Format: "json",
+				Stream: &True,
+			},
+		},
 		{
 			name: "chat handler with image content",
 			body: `{
@@ -283,6 +322,55 @@ func TestCompletionsMiddleware(t *testing.T) {
 				Stream: &False,
 			},
 		},
+		{
+			name: "completions handler stream",
+			body: `{
+				"model": "test-model",
+				"prompt": "Hello",
+				"stream": true,
+				"temperature": 0.8,
+				"stop": ["\n", "stop"],
+				"suffix": "suffix"
+			}`,
+			req: api.GenerateRequest{
+				Model:  "test-model",
+				Prompt: "Hello",
+				Options: map[string]any{
+					"frequency_penalty": 0.0,
+					"presence_penalty": 0.0,
+					"temperature": 0.8,
+					"top_p": 1.0,
+					"stop": []any{"\n", "stop"},
+				},
+				Suffix: "suffix",
+				Stream: &True,
+			},
+		},
+		{
+			name: "completions handler stream with usage",
+			body: `{
+				"model": "test-model",
+				"prompt": "Hello",
+				"stream": true,
+				"stream_options": {"include_usage": true},
+				"temperature": 0.8,
+				"stop": ["\n", "stop"],
+				"suffix": "suffix"
+			}`,
+			req: api.GenerateRequest{
+				Model:  "test-model",
+				Prompt: "Hello",
+				Options: map[string]any{
+					"frequency_penalty": 0.0,
+					"presence_penalty": 0.0,
+					"temperature": 0.8,
+					"top_p": 1.0,
+					"stop": []any{"\n", "stop"},
+				},
+				Suffix: "suffix",
+				Stream: &True,
+			},
+		},
 		{
 			name: "completions handler error forwarding",
 			body: `{