From 7d16ec8fe8cc8726b524be952eec84578501dcac Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Wed, 12 Feb 2025 16:36:03 -0800
Subject: [PATCH] print logprobs

---
 llama/llama.go         | 27 +++++++++++++++-
 llama/runner/runner.go | 71 ++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 95 insertions(+), 3 deletions(-)

diff --git a/llama/llama.go b/llama/llama.go
index a20f23578..b71aa51ac 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -50,7 +50,7 @@ import (
 	_ "github.com/ollama/ollama/llama/llama.cpp/common"
 	_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
 	_ "github.com/ollama/ollama/llama/llama.cpp/src"
-	"github.com/ollama/ollama/ml/backend/ggml/ggml/src"
+	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
 )
 
 func BackendInit() {
@@ -220,6 +220,31 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
 	return embeddings
 }
 
+// GetLogits returns the logits from the last decode operation.
+// The returned slice has length equal to the vocabulary size.
+func (c *Context) GetLogits() []float32 {
+	logits := unsafe.Pointer(C.llama_get_logits(c.c))
+	if logits == nil {
+		return nil
+	}
+
+	// Get the number of vocabulary tokens to determine slice length
+	vocabSize := c.Model().NumVocab()
+	return unsafe.Slice((*float32)(logits), vocabSize)
+}
+
+func (m *Model) Detokenize(tokens []int) (string, error) {
+	var text string
+	for _, token := range tokens {
+		piece := m.TokenToPiece(token)
+		if piece == "" {
+			return "", fmt.Errorf("failed to convert token %d to piece", token)
+		}
+		text += piece
+	}
+	return text, nil
+}
+
 type ModelParams struct {
 	NumGpuLayers int
 	MainGpu      int
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index 60ae88dac..8e9698439 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -8,12 +8,14 @@ import (
 	"fmt"
 	"log"
 	"log/slog"
+	"math"
 	"net"
 	"net/http"
 	"os"
 	"path/filepath"
 	"regexp"
 	"runtime"
+	"sort"
 	"strconv"
 	"strings"
 	"sync"
@@ -83,6 +85,8 @@ type Sequence struct {
 
 	doneReason string
 
+	logits []float32
+
 	// Metrics
 	startProcessingTime time.Time
 	startGenerationTime time.Time
@@ -274,6 +278,9 @@ func (s *Server) allNil() bool {
 }
 
 func flushPending(seq *Sequence) bool {
+	if len(seq.pendingResponses) == 0 {
+		return true
+	}
 	joined := strings.Join(seq.pendingResponses, "")
 	seq.pendingResponses = []string{}
 
@@ -287,8 +294,11 @@
 		joined = joined[:len(joined)-1]
 	}
 
-	if len(joined) == 0 {
-		return true
+	// Add logits if requested and available
+	wantLogits := true
+	if wantLogits && seq.logits != nil {
+		// resp.Logits = seq.logits
+		seq.logits = nil
 	}
 
 	select {
@@ -350,6 +360,57 @@ func (s *Server) run(ctx context.Context) {
 	}
 }
 
+// TokenData represents probability information for a token
+type TokenData struct {
+	TokenID int
+	Logit   float32
+	Prob    float32
+	LogProb float32
+}
+
+// getTokenProbabilities returns token probabilities for the last decode step, sorted by logit in descending order
+func (s *Server) getTokenProbabilities(seq *Sequence) []TokenData {
+	// Snapshot the logits from the last decode onto the sequence
+	logits := s.lc.GetLogits()
+	seq.logits = make([]float32, len(logits))
+	copy(seq.logits, logits)
+
+	vocabSize := s.model.NumVocab()
+	probs := make([]TokenData, vocabSize)
+
+	// Initialize token data with logits
+	for i := 0; i < vocabSize; i++ {
+		probs[i] = TokenData{
+			TokenID: i,
+			Logit:   logits[i],
+		}
+	}
+
+	// Sort tokens by logits in descending order
+	sort.Slice(probs, func(i, j int) bool {
+		return probs[i].Logit > probs[j].Logit
+	})
+
+	// Apply softmax (subtract the max logit before exponentiating for numerical stability)
+	maxLogit := probs[0].Logit
+	var sum float32 = 0.0
+
+	for i := range probs {
+		p := float32(math.Exp(float64(probs[i].Logit - maxLogit)))
+		probs[i].Prob = p
+		sum += p
+	}
+
+	// Normalize probabilities and calculate log probs
+	for i := range probs {
+		prob := probs[i].Prob / sum
+		probs[i].Prob = prob
+		probs[i].LogProb = float32(math.Log(float64(prob)))
+	}
+
+	return probs
+}
+
 // TODO (jmorganca): processBatch should be simplified, removing:
 // * sampling
 // * stop token checking
@@ -483,6 +544,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 	seq.numPredicted++
 
+	// TODO: only do this when flag specified
+	probs := s.getTokenProbabilities(seq)
+	for i := range 10 {
+		slog.Debug("top 10 tokens", "token", probs[i].TokenID, "prob", probs[i].Prob, "logit", probs[i].Logit, "piece", s.model.TokenToPiece(probs[i].TokenID))
+	}
+
 	// if it's an end of sequence token, break
 	if s.model.TokenIsEog(token) {
 		// TODO (jmorganca): we should send this back
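
For reference, below is a minimal standalone sketch of the softmax/log-prob math
that getTokenProbabilities performs, runnable without the llama bindings. The
names softmaxLogProbs and tokenData are illustrative only and not part of this
patch.

package main

import (
	"fmt"
	"math"
	"sort"
)

// tokenData mirrors the TokenData struct added by the patch.
type tokenData struct {
	TokenID int
	Logit   float32
	Prob    float32
	LogProb float32
}

// softmaxLogProbs sorts tokens by logit (descending), applies a
// max-subtracted softmax so the exponentials cannot overflow, then
// takes the log of each normalized probability.
func softmaxLogProbs(logits []float32) []tokenData {
	if len(logits) == 0 {
		return nil
	}

	probs := make([]tokenData, len(logits))
	for i, l := range logits {
		probs[i] = tokenData{TokenID: i, Logit: l}
	}

	sort.Slice(probs, func(i, j int) bool {
		return probs[i].Logit > probs[j].Logit
	})

	maxLogit := probs[0].Logit
	var sum float32
	for i := range probs {
		p := float32(math.Exp(float64(probs[i].Logit - maxLogit)))
		probs[i].Prob = p
		sum += p
	}

	for i := range probs {
		p := probs[i].Prob / sum
		probs[i].Prob = p
		probs[i].LogProb = float32(math.Log(float64(p)))
	}
	return probs
}

func main() {
	// Hypothetical logits for a four-token vocabulary.
	for _, td := range softmaxLogProbs([]float32{2.0, 1.0, 0.5, -1.0}) {
		fmt.Printf("token=%d prob=%.4f logprob=%.4f\n", td.TokenID, td.Prob, td.LogProb)
	}
}

One design note: computing LogProb as log(Prob) yields -Inf once a
probability underflows float32; the mathematically equivalent log-softmax
form, (logit - maxLogit) - log(sum), stays finite for very unlikely tokens
and may be preferable if these values are surfaced to clients.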