log probs working

ParthSareen 2025-01-10 11:15:31 -08:00
parent f9928b677f
commit afa2e855d4
5 changed files with 119 additions and 66 deletions

File 1 of 5

@@ -129,7 +129,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 	return nil
 }
 
-const maxBufferSize = 512 * format.KiloByte
+const maxBufferSize = 1024 * format.KiloByte
 
 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 	var buf *bytes.Buffer
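
The doubled cap is presumably needed because streamed chunks can now carry a vocabulary-sized logits array, which easily exceeds 512 KB of JSON per line. As a sketch only (the body of stream and the response variable name are assumptions, not shown in this hunk), a cap like maxBufferSize is typically wired into the line scanner like so:

	scanner := bufio.NewScanner(response.Body)
	scanBuf := make([]byte, 0, maxBufferSize)
	// Raise the per-line limit: Scan fails with bufio.ErrTooLong once a
	// single JSON line exceeds maxBufferSize.
	scanner.Buffer(scanBuf, maxBufferSize)
	for scanner.Scan() {
		if err := fn(scanner.Bytes()); err != nil {
			return err
		}
	}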

File 2 of 5

@@ -189,11 +189,12 @@ func (t *ToolFunction) String() string {
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
 	Model      string    `json:"model"`
 	CreatedAt  time.Time `json:"created_at"`
 	Message    Message   `json:"message"`
 	DoneReason string    `json:"done_reason,omitempty"`
 	Logits     []float32 `json:"logits"`
+	TopLogprobs []TokenLogprob `json:"top_logprobs"`
 
 	Done bool `json:"done"`
@@ -210,14 +211,10 @@ type Metrics struct {
 }
 
 type TokenLogprob struct {
-	Token   string  `json:"token"`
+	Text    string  `json:"text"`
 	Logprob float32 `json:"logprob"`
 }
 
-type LogProbs struct {
-	TopLogprobs []TokenLogprob `json:"top_logprobs"`
-}
-
 // Options specified in [GenerateRequest]. If you add a new option here, also
 // add it to the API docs.
 type Options struct {
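
With these type changes, each streamed chat chunk carries the raw logits plus the top-k alternatives for the sampled position. A hypothetical chunk as a Go literal (the model name and all values are made up for illustration):

	res := api.ChatResponse{
		Model:   "llama3.2",
		Message: api.Message{Role: "assistant", Content: " the"},
		Logits:  []float32{ /* one entry per vocabulary token */ },
		TopLogprobs: []api.TokenLogprob{
			{Text: " the", Logprob: -0.21}, // most likely alternative
			{Text: " a", Logprob: -1.87},
			{Text: " an", Logprob: -3.02},
		},
		Done: false,
	}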

File 3 of 5

@@ -273,6 +273,18 @@ func (c *Context) GetLogits() []float32 {
 	return unsafe.Slice((*float32)(logits), vocabSize)
 }
 
+func (m *Model) Detokenize(tokens []int) (string, error) {
+	var text string
+	for _, token := range tokens {
+		piece := m.TokenToPiece(token)
+		if piece == "" {
+			return "", fmt.Errorf("failed to convert token %d to piece", token)
+		}
+		text += piece
+	}
+	return text, nil
+}
+
 type ModelParams struct {
 	NumGpuLayers int
 	MainGpu      int
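
The new Detokenize simply concatenates the piece for each token id, failing on any empty piece. A usage sketch, assuming a *Model named model is in scope (the token ids are arbitrary placeholders, not real vocabulary entries):

	text, err := model.Detokenize([]int{15339, 1917})
	if err != nil {
		return fmt.Errorf("detokenize: %w", err)
	}
	fmt.Println(text) // prints the concatenated pieces

Note that a token whose piece is legitimately empty would also take the error path here.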

File 4 of 5

@@ -15,7 +15,6 @@ import (
 	"path/filepath"
 	"regexp"
 	"runtime"
-	"sort"
 	"strconv"
 	"strings"
 	"sync"
@@ -503,9 +502,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		if seq.returnLogits { // New flag we need to add to Sequence struct
 			logits := s.lc.GetLogits()
 			seq.logits = make([]float32, len(logits))
-			slog.Info("copying logits")
 			copy(seq.logits, logits)
-			slog.Info("copying logits success")
 		}
 
 		// Then sample token
@@ -608,7 +605,7 @@ type CompletionRequest struct {
 	Images      []ImageData `json:"image_data"`
 	Grammar     string      `json:"grammar"`
 	CachePrompt bool        `json:"cache_prompt"`
-	ReturnLogits bool       `json:"return_logits"`
+	ReturnLogits bool       `json:"return_logits,omitempty"` // defaults to false
 
 	Options
 }
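
For illustration, a caller opting in to logits on the runner's completion endpoint might build the request like this (a sketch; the prompt is made up, and because of omitempty a false value is simply dropped from the JSON):

	req := CompletionRequest{
		Prompt:       "Why is the sky blue?",
		CachePrompt:  true,
		ReturnLogits: true, // ask for per-token logits in each response chunk
	}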
@@ -729,7 +726,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				slog.Info("content", "content", content.Content)
+				// slog.Info("content", "content", content.Content)
 				if err := json.NewEncoder(w).Encode(&content); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
@@ -1040,50 +1037,50 @@ func Execute(args []string) error {
 	return nil
 }
 
-// Helper function to get top K logits and convert to log probabilities
-func getTopLogits(logits []float32, k int, model *llama.Model) []api.LogProbs {
-	if k <= 0 {
-		return nil
-	}
-	// Convert logits to probabilities using softmax
-	probs := softmax(logits)
-	// Create slice of index/probability pairs
-	pairs := make([]struct {
-		token int
-		prob  float32
-	}, len(probs))
-	for i, p := range probs {
-		pairs[i] = struct {
-			token int
-			prob  float32
-		}{i, p}
-	}
-	// Sort by probability (descending)
-	sort.Slice(pairs, func(i, j int) bool {
-		return pairs[i].prob > pairs[j].prob
-	})
-	// Take top K
-	k = min(k, len(pairs))
-	result := make([]api.LogProbs, k)
-	for i := 0; i < k; i++ {
-		result[i] = api.LogProbs{
-			TopLogprobs: []api.TokenLogprob{
-				{
-					Token:   model.TokenToPiece(pairs[i].token),
-					Logprob: float32(math.Log(float64(pairs[i].prob))),
-				},
-			},
-		}
-	}
-	return result
-}
+// // Helper function to get top K logits and convert to log probabilities
+// func getTopLogits(logits []float32, k int, model *llama.Model) []api.LogProbs {
+// 	if k <= 0 {
+// 		return nil
+// 	}
+// 	// Convert logits to probabilities using softmax
+// 	probs := softmax(logits)
+// 	// Create slice of index/probability pairs
+// 	pairs := make([]struct {
+// 		token int
+// 		prob  float32
+// 	}, len(probs))
+// 	for i, p := range probs {
+// 		pairs[i] = struct {
+// 			token int
+// 			prob  float32
+// 		}{i, p}
+// 	}
+// 	// Sort by probability (descending)
+// 	sort.Slice(pairs, func(i, j int) bool {
+// 		return pairs[i].prob > pairs[j].prob
+// 	})
+// 	// Take top K
+// 	k = min(k, len(pairs))
+// 	result := make([]api.LogProbs, k)
+// 	for i := 0; i < k; i++ {
+// 		result[i] = api.LogProbs{
+// 			TopLogprobs: []api.TokenLogprob{
+// 				{
+// 					Token:   model.TokenToPiece(pairs[i].token),
+// 					Logprob: float32(math.Log(float64(pairs[i].prob))),
+// 				},
+// 			},
+// 		}
+// 	}
+// 	return result
+// }
 
 // Helper function to compute softmax
 func softmax(logits []float32) []float32 {
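
The softmax body is not shown in this hunk; for reference, a typical numerically stable implementation (an assumption, not necessarily the body in this file) shifts by the maximum logit before exponentiating so exp cannot overflow:

	func softmax(logits []float32) []float32 {
		// Find the maximum logit to shift by.
		maxLogit := float32(math.Inf(-1))
		for _, l := range logits {
			if l > maxLogit {
				maxLogit = l
			}
		}
		// Exponentiate the shifted logits and normalize.
		probs := make([]float32, len(logits))
		var sum float32
		for i, l := range logits {
			probs[i] = float32(math.Exp(float64(l - maxLogit)))
			sum += probs[i]
		}
		for i := range probs {
			probs[i] /= sum
		}
		return probs
	}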

File 5 of 5

@@ -19,6 +19,7 @@ import (
 	"os/signal"
 	"path/filepath"
 	"slices"
+	"sort"
 	"strings"
 	"syscall"
 	"time"
@@ -299,7 +300,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		Images:  images,
 		Format:  req.Format,
 		Options: opts,
-		ReturnLogits: req.ReturnLogits,
+		ReturnLogits: false,
 	}, func(cr llm.CompletionResponse) {
 		res := api.GenerateResponse{
 			Model: req.Model,
@@ -1554,23 +1555,27 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		Format:       req.Format,
 		Options:      opts,
 		ReturnLogits: true,
-	}, func(r llm.CompletionResponse) {
+	}, func(cr llm.CompletionResponse) {
 		res := api.ChatResponse{
 			Model:      req.Model,
 			CreatedAt:  time.Now().UTC(),
-			Message:    api.Message{Role: "assistant", Content: r.Content},
-			Done:       r.Done,
-			DoneReason: r.DoneReason,
-			Logits:     r.Logits,
+			Message:    api.Message{Role: "assistant", Content: cr.Content},
+			Done:       cr.Done,
+			DoneReason: cr.DoneReason,
+			Logits:     []float32{},
 			Metrics: api.Metrics{
-				PromptEvalCount:    r.PromptEvalCount,
-				PromptEvalDuration: r.PromptEvalDuration,
-				EvalCount:          r.EvalCount,
-				EvalDuration:       r.EvalDuration,
+				PromptEvalCount:    cr.PromptEvalCount,
+				PromptEvalDuration: cr.PromptEvalDuration,
+				EvalCount:          cr.EvalCount,
+				EvalDuration:       cr.EvalDuration,
 			},
 		}
 
-		if r.Done {
+		topK := int(3)
+		logits := make([]float32, len(cr.Logits))
+		copy(logits, cr.Logits)
+		res.TopLogprobs = getTopKLogProbs(c.Request.Context(), r, logits, topK)
+		if cr.Done {
 			res.TotalDuration = time.Since(checkpointStart)
 			res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 		}
@@ -1586,7 +1591,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		// Streaming tool calls:
 		// If tools are recognized, use a flag to track the sending of a tool downstream
 		// This ensures that content is cleared from the message on the last chunk sent
-		sb.WriteString(r.Content)
+		sb.WriteString(cr.Content)
 		if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
 			res.Message.ToolCalls = toolCalls
 			for i := range toolCalls {
@@ -1599,7 +1604,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			return
 		}
 
-		if r.Done {
+		if cr.Done {
 			// Send any remaining content if no tool calls were detected
 			if toolCallIndex == 0 {
 				res.Message.Content = sb.String()
@@ -1649,6 +1654,48 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
+func getTopKLogProbs(ctx context.Context, s llm.LlamaServer, logits []float32, topK int) []api.TokenLogprob {
+	// Calculate softmax denominator first (log sum exp trick for numerical stability)
+	maxLogit := float32(math.Inf(-1))
+	for _, logit := range logits {
+		if logit > maxLogit {
+			maxLogit = logit
+		}
+	}
+
+	var sumExp float32
+	for _, logit := range logits {
+		sumExp += float32(math.Exp(float64(logit - maxLogit)))
+	}
+	logSumExp := float32(math.Log(float64(sumExp))) + maxLogit
+
+	// Calculate log probs and track top K
+	logProbs := make([]api.TokenLogprob, len(logits))
+	for i, logit := range logits {
+		text, err := s.Detokenize(ctx, []int{i})
+		if err != nil {
+			slog.Error("detokenize error for logprob", "error", err)
+			continue
+		}
+		logProbs[i] = api.TokenLogprob{
+			Text:    text,
+			Logprob: logit - logSumExp,
+		}
+	}
+
+	// Sort by logprob descending and take top K
+	sort.Slice(logProbs, func(i, j int) bool {
+		return logProbs[i].Logprob > logProbs[j].Logprob
+	})
+	if len(logProbs) > topK {
+		logProbs = logProbs[:topK]
+	}
+	return logProbs
+}
+
 func handleScheduleError(c *gin.Context, name string, err error) {
 	switch {
 	case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
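
The helper computes log softmax via the log-sum-exp identity: logprob_i = logit_i - (maxLogit + log Σ_j exp(logit_j - maxLogit)), so exp is never applied to a large positive value. A tiny self-contained check of that identity (the logit values are made up):

	package main

	import (
		"fmt"
		"math"
	)

	func main() {
		logits := []float64{2.0, 1.0, 0.1} // arbitrary example values
		// Shift by the max logit before exponentiating.
		max := math.Inf(-1)
		for _, l := range logits {
			if l > max {
				max = l
			}
		}
		var sumExp float64
		for _, l := range logits {
			sumExp += math.Exp(l - max)
		}
		logSumExp := max + math.Log(sumExp)

		// Each logprob is logit - logSumExp; the exps must sum to 1.
		var total float64
		for _, l := range logits {
			p := math.Exp(l - logSumExp)
			total += p
			fmt.Printf("logprob=%.4f prob=%.4f\n", l-logSumExp, p)
		}
		fmt.Printf("probabilities sum to %.6f\n", total) // ~1.000000
	}

One design note: as written, the helper calls Detokenize once per vocabulary entry for every emitted token; selecting the top K indices first would cut that to K calls, and a failed detokenization currently leaves a zero-valued placeholder entry that can survive the sort.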