diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index 1e35fc19d..815ce6191 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -378,13 +378,7 @@ type TokenProbs struct {
 }
 
 // probs returns sorted token probabilities for a specific token index
-func (s *Server) probs(seq *Sequence) []TokenProbs {
-	// Get logits for the specific token index
-	logits := s.lc.GetLogits()
-	seq.logits = make([]float32, len(logits))
-	copy(seq.logits, logits)
-
-	vocabSize := s.model.NumVocab()
+func probs(logits []float32, vocabSize int) []TokenProbs {
 	probs := make([]TokenProbs, vocabSize)
 
 	// Initialize token data with logits
@@ -420,6 +414,18 @@ func (s *Server) probs(seq *Sequence) []TokenProbs {
 	return probs
 }
 
+// probs copies the logits returned by s.lc.GetLogits into seq.logits and
+// returns the sorted token probabilities computed from them.
+func (s *Server) probs(seq *Sequence) []TokenProbs {
+	// Get logits for the specific token index
+	logits := s.lc.GetLogits()
+	seq.logits = make([]float32, len(logits))
+	copy(seq.logits, logits)
+
+	vocabSize := s.model.NumVocab()
+	return probs(logits, vocabSize)
+}
+
 // TODO (jmorganca): processBatch should be simplified, removing:
 // * sampling
 // * stop token checking
diff --git a/llama/runner/runner_test.go b/llama/runner/runner_test.go
new file mode 100644
index 000000000..bb4a6da9e
--- /dev/null
+++ b/llama/runner/runner_test.go
@@ -0,0 +1,67 @@
+package runner
+
+import (
+	"math"
+	"testing"
+)
+
+// TestProbs checks that probs orders its results by descending logit, keeps
+// each logit paired with its token ID, yields probabilities that sum to 1,
+// and reports LogProb as the natural log of Prob.
+func TestProbs(t *testing.T) {
+	// Input test data
+	logits := []float32{1.0, 2.0, 0.5, -1.0}
+	vocabSize := 4
+	want := []TokenProbs{
+		{TokenID: 1, Logit: 2.0},  // Highest logit
+		{TokenID: 0, Logit: 1.0},  // Second highest
+		{TokenID: 2, Logit: 0.5},  // Third
+		{TokenID: 3, Logit: -1.0}, // Lowest
+	}
+
+	got := probs(logits, vocabSize)
+
+	// The checks below index got alongside want, so stop on a length
+	// mismatch instead of panicking with an index out of range.
+	if len(got) != len(want) {
+		t.Fatalf("wrong number of results: got %d, want %d", len(got), len(want))
+	}
+
+	// Test 1: Check sorting order
+	for i := 0; i < len(got)-1; i++ {
+		if got[i].Logit < got[i+1].Logit {
+			t.Errorf("probs not properly sorted: logit at pos %d (%f) < logit at pos %d (%f)",
+				i, got[i].Logit, i+1, got[i+1].Logit)
+		}
+	}
+
+	// Test 2: Check probability normalization
+	var sum float32
+	for _, p := range got {
+		sum += p.Prob
+	}
+	if math.Abs(float64(sum-1.0)) > 1e-6 {
+		t.Errorf("probabilities do not sum to 1: got %v", sum)
+	}
+
+	// Test 3: Check token IDs and logits match expected order
+	for i, w := range want {
+		if got[i].TokenID != w.TokenID {
+			t.Errorf("wrong token ID at position %d: got %d, want %d",
+				i, got[i].TokenID, w.TokenID)
+		}
+		if got[i].Logit != w.Logit {
+			t.Errorf("wrong logit at position %d: got %f, want %f",
+				i, got[i].Logit, w.Logit)
+		}
+	}
+
+	// Test 4: Check log probs are correctly calculated
+	for i, p := range got {
+		expectedLogProb := float32(math.Log(float64(p.Prob)))
+		if math.Abs(float64(p.LogProb-expectedLogProb)) > 1e-6 {
+			t.Errorf("wrong log prob at position %d: got %f, want %f",
+				i, p.LogProb, expectedLogProb)
+		}
+	}
+}