From afa2e855d4f5747c69487a25e94220e3b8edc0eb Mon Sep 17 00:00:00 2001
From: ParthSareen
Date: Fri, 10 Jan 2025 11:15:31 -0800
Subject: [PATCH] log probs working

---
 api/client.go          |  2 +-
 api/types.go           | 17 ++++-----
 llama/llama.go         | 12 +++++++
 llama/runner/runner.go | 81 ++++++++++++++++++++----------------------
 server/routes.go       | 73 ++++++++++++++++++++++++++++++-------
 5 files changed, 119 insertions(+), 66 deletions(-)

diff --git a/api/client.go b/api/client.go
index 4688d4d13..2bba62a33 100644
--- a/api/client.go
+++ b/api/client.go
@@ -129,7 +129,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 	return nil
 }
 
-const maxBufferSize = 512 * format.KiloByte
+const maxBufferSize = 1024 * format.KiloByte
 
 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 	var buf *bytes.Buffer
diff --git a/api/types.go b/api/types.go
index c9497a40e..a1c63fb68 100644
--- a/api/types.go
+++ b/api/types.go
@@ -189,11 +189,12 @@ func (t *ToolFunction) String() string {
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
-	Model      string    `json:"model"`
-	CreatedAt  time.Time `json:"created_at"`
-	Message    Message   `json:"message"`
-	DoneReason string    `json:"done_reason,omitempty"`
-	Logits     []float32 `json:"logits"`
+	Model       string         `json:"model"`
+	CreatedAt   time.Time      `json:"created_at"`
+	Message     Message        `json:"message"`
+	DoneReason  string         `json:"done_reason,omitempty"`
+	Logits      []float32      `json:"logits"`
+	TopLogprobs []TokenLogprob `json:"top_logprobs"`
 
 	Done bool `json:"done"`
 
@@ -210,14 +211,10 @@ type Metrics struct {
 }
 
 type TokenLogprob struct {
-	Token   string  `json:"token"`
+	Text    string  `json:"text"`
 	Logprob float32 `json:"logprob"`
 }
 
-type LogProbs struct {
-	TopLogprobs []TokenLogprob `json:"top_logprobs"`
-}
-
 // Options specified in [GenerateRequest]. If you add a new option here, also
 // add it to the API docs.
 type Options struct {
diff --git a/llama/llama.go b/llama/llama.go
index 25bc67086..dc83ed91d 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -273,6 +273,18 @@ func (c *Context) GetLogits() []float32 {
 	return unsafe.Slice((*float32)(logits), vocabSize)
 }
 
+func (m *Model) Detokenize(tokens []int) (string, error) {
+	var text string
+	for _, token := range tokens {
+		piece := m.TokenToPiece(token)
+		if piece == "" {
+			return "", fmt.Errorf("failed to convert token %d to piece", token)
+		}
+		text += piece
+	}
+	return text, nil
+}
+
 type ModelParams struct {
 	NumGpuLayers int
 	MainGpu      int
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index e806f1ba3..d40895fac 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -15,7 +15,6 @@ import (
 	"path/filepath"
 	"regexp"
 	"runtime"
-	"sort"
 	"strconv"
 	"strings"
 	"sync"
@@ -503,9 +502,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		if seq.returnLogits { // New flag we need to add to Sequence struct
 			logits := s.lc.GetLogits()
 			seq.logits = make([]float32, len(logits))
-			slog.Info("copying logits")
 			copy(seq.logits, logits)
-			slog.Info("copying logits success")
 		}
 
 		// Then sample token
@@ -608,7 +605,7 @@ type CompletionRequest struct {
 	Images      []ImageData `json:"image_data"`
 	Grammar     string      `json:"grammar"`
 	CachePrompt bool        `json:"cache_prompt"`
-	ReturnLogits bool       `json:"return_logits"`
+	ReturnLogits bool       `json:"return_logits,omitempty"` // defaults to false
 
 	Options
 }
@@ -729,7 +726,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	case content, ok := <-seq.responses:
 		if ok {
-			slog.Info("content", "content", content.Content)
+			// slog.Info("content", "content", content.Content)
 			if err := json.NewEncoder(w).Encode(&content); err != nil {
 				http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 				close(seq.quit)
@@ -1040,50 +1037,50 @@ func Execute(args []string) error {
 	return nil
 }
 
-// Helper function to get top K logits and convert to log probabilities
-func getTopLogits(logits []float32, k int, model *llama.Model) []api.LogProbs {
-	if k <= 0 {
-		return nil
-	}
+// // Helper function to get top K logits and convert to log probabilities
+// func getTopLogits(logits []float32, k int, model *llama.Model) []api.LogProbs {
+// 	if k <= 0 {
+// 		return nil
+// 	}
 
-	// Convert logits to probabilities using softmax
-	probs := softmax(logits)
+// 	// Convert logits to probabilities using softmax
+// 	probs := softmax(logits)
 
-	// Create slice of index/probability pairs
-	pairs := make([]struct {
-		token int
-		prob  float32
-	}, len(probs))
+// 	// Create slice of index/probability pairs
+// 	pairs := make([]struct {
+// 		token int
+// 		prob  float32
+// 	}, len(probs))
 
-	for i, p := range probs {
-		pairs[i] = struct {
-			token int
-			prob  float32
-		}{i, p}
-	}
+// 	for i, p := range probs {
+// 		pairs[i] = struct {
+// 			token int
+// 			prob  float32
+// 		}{i, p}
+// 	}
 
-	// Sort by probability (descending)
-	sort.Slice(pairs, func(i, j int) bool {
-		return pairs[i].prob > pairs[j].prob
-	})
+// 	// Sort by probability (descending)
+// 	sort.Slice(pairs, func(i, j int) bool {
+// 		return pairs[i].prob > pairs[j].prob
+// 	})
 
-	// Take top K
-	k = min(k, len(pairs))
-	result := make([]api.LogProbs, k)
+// 	// Take top K
+// 	k = min(k, len(pairs))
+// 	result := make([]api.LogProbs, k)
 
-	for i := 0; i < k; i++ {
-		result[i] = api.LogProbs{
-			TopLogprobs: []api.TokenLogprob{
-				{
-					Token:   model.TokenToPiece(pairs[i].token),
-					Logprob: float32(math.Log(float64(pairs[i].prob))),
-				},
-			},
-		}
-	}
+// 	for i := 0; i < k; i++ {
+// 		result[i] = api.LogProbs{
+// 			TopLogprobs: []api.TokenLogprob{
+// 				{
+// 					Token:   model.TokenToPiece(pairs[i].token),
+// 					Logprob: float32(math.Log(float64(pairs[i].prob))),
+// 				},
+// 			},
+// 		}
+// 	}
 
-	return result
-}
+// 	return result
+// }
 
 // Helper function to compute softmax
 func softmax(logits []float32) []float32 {
diff --git a/server/routes.go b/server/routes.go
index dc45b56a9..28762b370 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -19,6 +19,7 @@ import (
 	"os/signal"
 	"path/filepath"
 	"slices"
+	"sort"
 	"strings"
 	"syscall"
 	"time"
@@ -299,7 +300,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		Images:  images,
 		Format:  req.Format,
 		Options: opts,
-		ReturnLogits: req.ReturnLogits,
+		ReturnLogits: false,
 	}, func(cr llm.CompletionResponse) {
 		res := api.GenerateResponse{
 			Model:     req.Model,
@@ -1554,23 +1555,27 @@
 			Format:  req.Format,
 			Options: opts,
 			ReturnLogits: true,
-		}, func(r llm.CompletionResponse) {
+		}, func(cr llm.CompletionResponse) {
 			res := api.ChatResponse{
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
-				Message:    api.Message{Role: "assistant", Content: r.Content},
-				Done:       r.Done,
-				DoneReason: r.DoneReason,
-				Logits:     r.Logits,
+				Message:    api.Message{Role: "assistant", Content: cr.Content},
+				Done:       cr.Done,
+				DoneReason: cr.DoneReason,
+				Logits:     []float32{},
 				Metrics: api.Metrics{
-					PromptEvalCount:    r.PromptEvalCount,
-					PromptEvalDuration: r.PromptEvalDuration,
-					EvalCount:          r.EvalCount,
-					EvalDuration:       r.EvalDuration,
+					PromptEvalCount:    cr.PromptEvalCount,
+					PromptEvalDuration: cr.PromptEvalDuration,
+					EvalCount:          cr.EvalCount,
+					EvalDuration:       cr.EvalDuration,
 				},
 			}
 
-			if r.Done {
+			topK := 3
+			logits := make([]float32, len(cr.Logits))
+			copy(logits, cr.Logits)
+			res.TopLogprobs = getTopKLogProbs(c.Request.Context(), r, logits, topK)
+			if cr.Done {
 				res.TotalDuration = time.Since(checkpointStart)
 				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 			}
@@ -1586,7 +1591,7 @@
 			// Streaming tool calls:
 			// If tools are recognized, use a flag to track the sending of a tool downstream
 			// This ensures that content is cleared from the message on the last chunk sent
-			sb.WriteString(r.Content)
+			sb.WriteString(cr.Content)
 			if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
 				res.Message.ToolCalls = toolCalls
 				for i := range toolCalls {
@@ -1599,7 +1604,7 @@
 				return
 			}
 
-			if r.Done {
+			if cr.Done {
 				// Send any remaining content if no tool calls were detected
 				if toolCallIndex == 0 {
 					res.Message.Content = sb.String()
@@ -1649,6 +1654,48 @@
 	streamResponse(c, ch)
 }
 
+func getTopKLogProbs(ctx context.Context, s llm.LlamaServer, logits []float32, topK int) []api.TokenLogprob {
+	// Calculate softmax denominator first (log sum exp trick for numerical stability)
+	maxLogit := float32(math.Inf(-1))
+	for _, logit := range logits {
+		if logit > maxLogit {
+			maxLogit = logit
+		}
+	}
+
+	var sumExp float32
+	for _, logit := range logits {
+		sumExp += float32(math.Exp(float64(logit - maxLogit)))
+	}
+	logSumExp := float32(math.Log(float64(sumExp))) + maxLogit
+
+	// Calculate log probs and track top K
+	logProbs := make([]api.TokenLogprob, 0, len(logits))
+	for i, logit := range logits {
+		text, err := s.Detokenize(ctx, []int{i})
+		if err != nil {
+			slog.Error("detokenize error for logprob", "error", err)
+			continue
+		}
+
+		logProbs = append(logProbs, api.TokenLogprob{
+			Text:    text,
+			Logprob: logit - logSumExp,
+		})
+	}
+
+	// Sort by logprob descending and take top K
+	sort.Slice(logProbs, func(i, j int) bool {
+		return logProbs[i].Logprob > logProbs[j].Logprob
+	})
+
+	if len(logProbs) > topK {
+		logProbs = logProbs[:topK]
+	}
+
+	return logProbs
+}
+
 func handleScheduleError(c *gin.Context, name string, err error) {
 	switch {
 	case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
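
Below is a minimal usage sketch, not part of the patch itself, showing how a client might read the new top_logprobs field once this change is applied. The model name is illustrative; the client calls are the existing ones from the api package.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		Model:    "llama3.2", // illustrative model name
		Messages: []api.Message{{Role: "user", Content: "Say hi"}},
	}

	// Each streamed chunk carries the top-K logprobs computed for that position.
	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		for _, lp := range resp.TopLogprobs {
			fmt.Printf("%q: %.4f\n", lp.Text, lp.Logprob)
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}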