Code cleanup

This commit is contained in:
ParthSareen 2024-12-11 18:04:16 -08:00
parent c6509bf76e
commit 97abd7bfea

View File

@ -67,14 +67,14 @@ type Usage struct {
// is requested to follow OpenAI's behavior. // is requested to follow OpenAI's behavior.
type ChunkUsage = Usage type ChunkUsage = Usage
var nullChunkUsage = ChunkUsage{} // var nullChunkUsage = ChunkUsage{}
func (u *ChunkUsage) MarshalJSON() ([]byte, error) { // func (u *ChunkUsage) MarshalJSON() ([]byte, error) {
if u == &nullChunkUsage { // if u == &nullChunkUsage {
return []byte("null"), nil // return []byte("null"), nil
} // }
return json.Marshal(*u) // return json.Marshal(*u)
} // }
type ResponseFormat struct { type ResponseFormat struct {
Type string `json:"type"` Type string `json:"type"`
@ -602,14 +602,14 @@ type BaseWriter struct {
type ChatWriter struct { type ChatWriter struct {
stream bool stream bool
streamUsage bool streamOptions *StreamOptions
id string id string
BaseWriter BaseWriter
} }
type CompleteWriter struct { type CompleteWriter struct {
stream bool stream bool
streamUsage bool streamOptions *StreamOptions
id string id string
BaseWriter BaseWriter
} }
@ -654,8 +654,8 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
// chat chunk // chat chunk
if w.stream { if w.stream {
c := toChunk(w.id, chatResponse) c := toChunk(w.id, chatResponse)
if w.streamUsage { if w.streamOptions != nil && w.streamOptions.IncludeUsage {
c.Usage = &nullChunkUsage c.Usage = &ChunkUsage{}
} }
d, err := json.Marshal(c) d, err := json.Marshal(c)
if err != nil { if err != nil {
@ -669,7 +669,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
} }
if chatResponse.Done { if chatResponse.Done {
if w.streamUsage { if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := toUsage(chatResponse) u := toUsage(chatResponse)
d, err := json.Marshal(ChatCompletionChunk{Choices: []ChunkChoice{}, Usage: &u}) d, err := json.Marshal(ChatCompletionChunk{Choices: []ChunkChoice{}, Usage: &u})
if err != nil { if err != nil {
@ -718,8 +718,8 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
// completion chunk // completion chunk
if w.stream { if w.stream {
c := toCompleteChunk(w.id, generateResponse) c := toCompleteChunk(w.id, generateResponse)
if w.streamUsage { if w.streamOptions != nil && w.streamOptions.IncludeUsage {
c.Usage = &nullChunkUsage c.Usage = &ChunkUsage{}
} }
d, err := json.Marshal(c) d, err := json.Marshal(c)
if err != nil { if err != nil {
@ -733,7 +733,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
} }
if generateResponse.Done { if generateResponse.Done {
if w.streamUsage { if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := toUsageGenerate(generateResponse) u := toUsageGenerate(generateResponse)
d, err := json.Marshal(CompletionChunk{Choices: []CompleteChunkChoice{}, Usage: &u}) d, err := json.Marshal(CompletionChunk{Choices: []CompleteChunkChoice{}, Usage: &u})
if err != nil { if err != nil {
@ -909,7 +909,7 @@ func CompletionsMiddleware() gin.HandlerFunc {
BaseWriter: BaseWriter{ResponseWriter: c.Writer}, BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream, stream: req.Stream,
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)), id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
streamUsage: req.StreamOptions != nil && req.StreamOptions.IncludeUsage, streamOptions: req.StreamOptions,
} }
c.Writer = w c.Writer = w
@ -992,7 +992,7 @@ func ChatMiddleware() gin.HandlerFunc {
BaseWriter: BaseWriter{ResponseWriter: c.Writer}, BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream, stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)), id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
streamUsage: req.StreamOptions != nil && req.StreamOptions.IncludeUsage, streamOptions: req.StreamOptions,
} }
c.Writer = w c.Writer = w