openai: support include_usage stream option to return final usage chunk
This commit is contained in:
parent
fda0d3be52
commit
220108d3f4
131
openai/openai.go
131
openai/openai.go
@ -61,6 +61,21 @@ type Usage struct {
|
|||||||
TotalTokens int `json:"total_tokens"`
|
TotalTokens int `json:"total_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ChunkUsage is an alias for Usage with the ability to marshal a marker
|
||||||
|
// value as null. This is to allow omitting the field in chunks when usage
|
||||||
|
// isn't requested, and otherwise return null on non-final chunks when it
|
||||||
|
// is requested to follow OpenAI's behavior.
|
||||||
|
type ChunkUsage = Usage
|
||||||
|
|
||||||
|
var nullChunkUsage = ChunkUsage{}
|
||||||
|
|
||||||
|
func (u *ChunkUsage) MarshalJSON() ([]byte, error) {
|
||||||
|
if u == &nullChunkUsage {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
return json.Marshal(*u)
|
||||||
|
}
|
||||||
|
|
||||||
type ResponseFormat struct {
|
type ResponseFormat struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
}
|
}
|
||||||
@ -70,10 +85,15 @@ type EmbedRequest struct {
|
|||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type StreamOptions struct {
|
||||||
|
IncludeUsage bool `json:"include_usage"`
|
||||||
|
}
|
||||||
|
|
||||||
type ChatCompletionRequest struct {
|
type ChatCompletionRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Messages []Message `json:"messages"`
|
Messages []Message `json:"messages"`
|
||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
|
StreamOptions *StreamOptions `json:"stream_options"`
|
||||||
MaxTokens *int `json:"max_tokens"`
|
MaxTokens *int `json:"max_tokens"`
|
||||||
Seed *int `json:"seed"`
|
Seed *int `json:"seed"`
|
||||||
Stop any `json:"stop"`
|
Stop any `json:"stop"`
|
||||||
@ -102,21 +122,23 @@ type ChatCompletionChunk struct {
|
|||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
SystemFingerprint string `json:"system_fingerprint"`
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
Choices []ChunkChoice `json:"choices"`
|
Choices []ChunkChoice `json:"choices"`
|
||||||
|
Usage *ChunkUsage `json:"usage,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
|
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
|
||||||
type CompletionRequest struct {
|
type CompletionRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
FrequencyPenalty float32 `json:"frequency_penalty"`
|
FrequencyPenalty float32 `json:"frequency_penalty"`
|
||||||
MaxTokens *int `json:"max_tokens"`
|
MaxTokens *int `json:"max_tokens"`
|
||||||
PresencePenalty float32 `json:"presence_penalty"`
|
PresencePenalty float32 `json:"presence_penalty"`
|
||||||
Seed *int `json:"seed"`
|
Seed *int `json:"seed"`
|
||||||
Stop any `json:"stop"`
|
Stop any `json:"stop"`
|
||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
Temperature *float32 `json:"temperature"`
|
StreamOptions *StreamOptions `json:"stream_options"`
|
||||||
TopP float32 `json:"top_p"`
|
Temperature *float32 `json:"temperature"`
|
||||||
Suffix string `json:"suffix"`
|
TopP float32 `json:"top_p"`
|
||||||
|
Suffix string `json:"suffix"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Completion struct {
|
type Completion struct {
|
||||||
@ -136,6 +158,7 @@ type CompletionChunk struct {
|
|||||||
Choices []CompleteChunkChoice `json:"choices"`
|
Choices []CompleteChunkChoice `json:"choices"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
SystemFingerprint string `json:"system_fingerprint"`
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
|
Usage *ChunkUsage `json:"usage,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolCall struct {
|
type ToolCall struct {
|
||||||
@ -200,6 +223,14 @@ func toolCallId() string {
|
|||||||
return "call_" + strings.ToLower(string(b))
|
return "call_" + strings.ToLower(string(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toUsage(r api.ChatResponse) Usage {
|
||||||
|
return Usage{
|
||||||
|
PromptTokens: r.PromptEvalCount,
|
||||||
|
CompletionTokens: r.EvalCount,
|
||||||
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||||
toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
|
toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
|
||||||
for i, tc := range r.Message.ToolCalls {
|
for i, tc := range r.Message.ToolCalls {
|
||||||
@ -235,11 +266,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
|||||||
return nil
|
return nil
|
||||||
}(r.DoneReason),
|
}(r.DoneReason),
|
||||||
}},
|
}},
|
||||||
Usage: Usage{
|
Usage: toUsage(r),
|
||||||
PromptTokens: r.PromptEvalCount,
|
|
||||||
CompletionTokens: r.EvalCount,
|
|
||||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -263,6 +290,14 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toUsageGenerate(r api.GenerateResponse) Usage {
|
||||||
|
return Usage{
|
||||||
|
PromptTokens: r.PromptEvalCount,
|
||||||
|
CompletionTokens: r.EvalCount,
|
||||||
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func toCompletion(id string, r api.GenerateResponse) Completion {
|
func toCompletion(id string, r api.GenerateResponse) Completion {
|
||||||
return Completion{
|
return Completion{
|
||||||
Id: id,
|
Id: id,
|
||||||
@ -280,11 +315,7 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
|
|||||||
return nil
|
return nil
|
||||||
}(r.DoneReason),
|
}(r.DoneReason),
|
||||||
}},
|
}},
|
||||||
Usage: Usage{
|
Usage: toUsageGenerate(r),
|
||||||
PromptTokens: r.PromptEvalCount,
|
|
||||||
CompletionTokens: r.EvalCount,
|
|
||||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -546,14 +577,16 @@ type BaseWriter struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ChatWriter struct {
|
type ChatWriter struct {
|
||||||
stream bool
|
stream bool
|
||||||
id string
|
streamUsage bool
|
||||||
|
id string
|
||||||
BaseWriter
|
BaseWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
type CompleteWriter struct {
|
type CompleteWriter struct {
|
||||||
stream bool
|
stream bool
|
||||||
id string
|
streamUsage bool
|
||||||
|
id string
|
||||||
BaseWriter
|
BaseWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -596,7 +629,11 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
|||||||
|
|
||||||
// chat chunk
|
// chat chunk
|
||||||
if w.stream {
|
if w.stream {
|
||||||
d, err := json.Marshal(toChunk(w.id, chatResponse))
|
c := toChunk(w.id, chatResponse)
|
||||||
|
if w.streamUsage {
|
||||||
|
c.Usage = &nullChunkUsage
|
||||||
|
}
|
||||||
|
d, err := json.Marshal(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -608,6 +645,17 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if chatResponse.Done {
|
if chatResponse.Done {
|
||||||
|
if w.streamUsage {
|
||||||
|
u := toUsage(chatResponse)
|
||||||
|
d, err := json.Marshal(ChatCompletionChunk{Usage: &u})
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@ -645,7 +693,11 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
|||||||
|
|
||||||
// completion chunk
|
// completion chunk
|
||||||
if w.stream {
|
if w.stream {
|
||||||
d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
|
c := toCompleteChunk(w.id, generateResponse)
|
||||||
|
if w.streamUsage {
|
||||||
|
c.Usage = &nullChunkUsage
|
||||||
|
}
|
||||||
|
d, err := json.Marshal(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -657,6 +709,17 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if generateResponse.Done {
|
if generateResponse.Done {
|
||||||
|
if w.streamUsage {
|
||||||
|
u := toUsageGenerate(generateResponse)
|
||||||
|
d, err := json.Marshal(CompletionChunk{Usage: &u})
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@ -819,9 +882,10 @@ func CompletionsMiddleware() gin.HandlerFunc {
|
|||||||
c.Request.Body = io.NopCloser(&b)
|
c.Request.Body = io.NopCloser(&b)
|
||||||
|
|
||||||
w := &CompleteWriter{
|
w := &CompleteWriter{
|
||||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||||
stream: req.Stream,
|
stream: req.Stream,
|
||||||
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
||||||
|
streamUsage: req.StreamOptions != nil && req.StreamOptions.IncludeUsage,
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Writer = w
|
c.Writer = w
|
||||||
@ -901,9 +965,10 @@ func ChatMiddleware() gin.HandlerFunc {
|
|||||||
c.Request.Body = io.NopCloser(&b)
|
c.Request.Body = io.NopCloser(&b)
|
||||||
|
|
||||||
w := &ChatWriter{
|
w := &ChatWriter{
|
||||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||||
stream: req.Stream,
|
stream: req.Stream,
|
||||||
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
||||||
|
streamUsage: req.StreamOptions != nil && req.StreamOptions.IncludeUsage,
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Writer = w
|
c.Writer = w
|
||||||
|
@ -111,6 +111,45 @@ func TestChatMiddleware(t *testing.T) {
|
|||||||
Stream: &True,
|
Stream: &True,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "chat handler with streaming usage",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
],
|
||||||
|
"stream": true,
|
||||||
|
"stream_options": {"include_usage": true},
|
||||||
|
"max_tokens": 999,
|
||||||
|
"seed": 123,
|
||||||
|
"stop": ["\n", "stop"],
|
||||||
|
"temperature": 3.0,
|
||||||
|
"frequency_penalty": 4.0,
|
||||||
|
"presence_penalty": 5.0,
|
||||||
|
"top_p": 6.0,
|
||||||
|
"response_format": {"type": "json_object"}
|
||||||
|
}`,
|
||||||
|
req: api.ChatRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Hello",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Options: map[string]any{
|
||||||
|
"num_predict": 999.0, // float because JSON doesn't distinguish between float and int
|
||||||
|
"seed": 123.0,
|
||||||
|
"stop": []any{"\n", "stop"},
|
||||||
|
"temperature": 3.0,
|
||||||
|
"frequency_penalty": 4.0,
|
||||||
|
"presence_penalty": 5.0,
|
||||||
|
"top_p": 6.0,
|
||||||
|
},
|
||||||
|
Format: "json",
|
||||||
|
Stream: &True,
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "chat handler with image content",
|
name: "chat handler with image content",
|
||||||
body: `{
|
body: `{
|
||||||
@ -283,6 +322,55 @@ func TestCompletionsMiddleware(t *testing.T) {
|
|||||||
Stream: &False,
|
Stream: &False,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "completions handler stream",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": "Hello",
|
||||||
|
"stream": true,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"stop": ["\n", "stop"],
|
||||||
|
"suffix": "suffix"
|
||||||
|
}`,
|
||||||
|
req: api.GenerateRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
Options: map[string]any{
|
||||||
|
"frequency_penalty": 0.0,
|
||||||
|
"presence_penalty": 0.0,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"top_p": 1.0,
|
||||||
|
"stop": []any{"\n", "stop"},
|
||||||
|
},
|
||||||
|
Suffix: "suffix",
|
||||||
|
Stream: &True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "completions handler stream with usage",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": "Hello",
|
||||||
|
"stream": true,
|
||||||
|
"stream_options": {"include_usage": true},
|
||||||
|
"temperature": 0.8,
|
||||||
|
"stop": ["\n", "stop"],
|
||||||
|
"suffix": "suffix"
|
||||||
|
}`,
|
||||||
|
req: api.GenerateRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
Options: map[string]any{
|
||||||
|
"frequency_penalty": 0.0,
|
||||||
|
"presence_penalty": 0.0,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"top_p": 1.0,
|
||||||
|
"stop": []any{"\n", "stop"},
|
||||||
|
},
|
||||||
|
Suffix: "suffix",
|
||||||
|
Stream: &True,
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "completions handler error forwarding",
|
name: "completions handler error forwarding",
|
||||||
body: `{
|
body: `{
|
||||||
|
Loading…
x
Reference in New Issue
Block a user