This commit is contained in:
parent fdbb0b5cfe
commit b88489a87e
@@ -378,13 +378,7 @@ type TokenProbs struct {
 }
 
 // probs returns sorted token probabilities for a specific token index
-func (s *Server) probs(seq *Sequence) []TokenProbs {
-	// Get logits for the specific token index
-	logits := s.lc.GetLogits()
-	seq.logits = make([]float32, len(logits))
-	copy(seq.logits, logits)
-
-	vocabSize := s.model.NumVocab()
+func probs(logits []float32, vocabSize int) []TokenProbs {
 	probs := make([]TokenProbs, vocabSize)
 
 	// Initialize token data with logits
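The body of the extracted probs function falls between the two hunks and is not shown in this diff. As a rough sketch only (assuming the softmax-and-sort behavior the new test asserts, and the TokenProbs fields TokenID, Logit, Prob, and LogProb that it exercises), the computation looks roughly like the following; probsSketch is an invented name and the actual implementation may differ:

// probsSketch is a hypothetical reconstruction, not the code from this commit:
// softmax over the logits (with max-subtraction for numerical stability),
// then sort descending by probability. Assumes "math" and "sort" are imported.
func probsSketch(logits []float32, vocabSize int) []TokenProbs {
	out := make([]TokenProbs, vocabSize)

	// Find the max logit so exp() stays in a safe range.
	maxLogit := logits[0]
	for _, l := range logits[:vocabSize] {
		if l > maxLogit {
			maxLogit = l
		}
	}

	// Accumulate the softmax denominator and record each token's logit.
	var sum float64
	for i := 0; i < vocabSize; i++ {
		out[i] = TokenProbs{TokenID: i, Logit: logits[i]}
		sum += math.Exp(float64(logits[i] - maxLogit))
	}

	// Normalize to probabilities and take the log.
	for i := range out {
		p := math.Exp(float64(out[i].Logit-maxLogit)) / sum
		out[i].Prob = float32(p)
		out[i].LogProb = float32(math.Log(p))
	}

	// Highest probability first.
	sort.Slice(out, func(a, b int) bool { return out[a].Prob > out[b].Prob })
	return out
}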
@@ -420,6 +414,17 @@ func (s *Server) probs(seq *Sequence) []TokenProbs {
 	return probs
 }
 
+// probs returns sorted token probabilities for a specific token index
+func (s *Server) probs(seq *Sequence) []TokenProbs {
+	// Get logits for the specific token index
+	logits := s.lc.GetLogits()
+	seq.logits = make([]float32, len(logits))
+	copy(seq.logits, logits)
+
+	vocabSize := s.model.NumVocab()
+	return probs(logits, vocabSize)
+}
+
 // TODO (jmorganca): processBatch should be simplified, removing:
 // * sampling
 // * stop token checking
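With this split, the (s *Server).probs method only copies the logits out of the llama context via s.lc.GetLogits() and defers the math to the package-level probs helper, which can now be called without a Server. A hypothetical usage sketch (not part of the commit; printTopK is an invented name and "fmt" is assumed to be imported):

// printTopK is a hypothetical helper, not in the commit: it prints the k most
// likely tokens for a raw logit slice using the package-level probs function.
func printTopK(logits []float32, vocabSize, k int) {
	sorted := probs(logits, vocabSize)
	if k > len(sorted) {
		k = len(sorted)
	}
	for _, tp := range sorted[:k] {
		fmt.Printf("token %d: prob=%.4f logprob=%.4f\n", tp.TokenID, tp.Prob, tp.LogProb)
	}
}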
llama/runner/runner_test.go (new file, 58 lines)
@@ -0,0 +1,58 @@
+package runner
+
+import (
+	"math"
+	"testing"
+)
+
+func TestProbs(t *testing.T) {
+	// Input test data
+	logits := []float32{1.0, 2.0, 0.5, -1.0}
+	vocabSize := 4
+	want := []TokenProbs{
+		{TokenID: 1, Logit: 2.0},  // Highest logit
+		{TokenID: 0, Logit: 1.0},  // Second highest
+		{TokenID: 2, Logit: 0.5},  // Third
+		{TokenID: 3, Logit: -1.0}, // Lowest
+	}
+
+	got := probs(logits, vocabSize)
+
+	// Test 1: Check sorting order
+	for i := 0; i < len(got)-1; i++ {
+		if got[i].Logit < got[i+1].Logit {
+			t.Errorf("probs not properly sorted: logit at pos %d (%f) < logit at pos %d (%f)",
+				i, got[i].Logit, i+1, got[i+1].Logit)
+		}
+	}
+
+	// Test 2: Check probability normalization
+	var sum float32
+	for _, p := range got {
+		sum += p.Prob
+	}
+	if math.Abs(float64(sum-1.0)) > 1e-6 {
+		t.Errorf("probabilities do not sum to 1: got %v", sum)
+	}
+
+	// Test 3: Check token IDs match expected order
+	for i, want := range want {
+		if got[i].TokenID != want.TokenID {
+			t.Errorf("wrong token ID at position %d: got %d, want %d",
+				i, got[i].TokenID, want.TokenID)
+		}
+		if got[i].Logit != want.Logit {
+			t.Errorf("wrong logit at position %d: got %f, want %f",
+				i, got[i].Logit, want.Logit)
+		}
+	}
+
+	// Test 4: Check log probs are correctly calculated
+	for i, p := range got {
+		expectedLogProb := float32(math.Log(float64(p.Prob)))
+		if math.Abs(float64(p.LogProb-expectedLogProb)) > 1e-6 {
+			t.Errorf("wrong log prob at position %d: got %f, want %f",
+				i, p.LogProb, expectedLogProb)
+		}
+	}
+}
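For reference, the expected values in this test can be worked out by hand (approximately): exp(1.0) ≈ 2.718, exp(2.0) ≈ 7.389, exp(0.5) ≈ 1.649, and exp(-1.0) ≈ 0.368, which sum to about 12.124. The softmax probabilities are therefore roughly 0.609 for token 1, 0.224 for token 0, 0.136 for token 2, and 0.030 for token 3; they sum to 1 (Test 2) and give exactly the descending order asserted in Test 1 and Test 3.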