Compare commits

...

1 Commit

Author       SHA1        Message  Date
ParthSareen  a5d638dfe7  extras   2025-03-12 16:12:29 -04:00
4 changed files with 191 additions and 63 deletions


@@ -1,11 +1,10 @@
 package sample
 
 import (
-	"errors"
 	"math"
-	"math/rand/v2"
-	"slices"
+	"math/rand"
 	"sync"
+	"time"
 
 	"github.com/ollama/ollama/llama"
 )
@@ -87,53 +86,53 @@ func (s *Sampler) sample(tokens []token) (token, error) {
 	// topK also sorts the tokens in descending order of logits
 	tokens = topK(tokens, s.topK)
-
-	// token logit values are updated to probabilities
-	tokens = temperature(tokens, s.temperature)
 	tokens = topP(tokens, s.topP)
 	tokens = minP(tokens, s.minP)
 
-	// TODO: this should fall back to greedy sampling
-	// or topP, topK values etc should be such that
-	// there are always tokens to sample from
-	if len(tokens) == 0 {
-		return token{}, errors.New("no tokens to sample from")
-	}
-
-	var r float32
-	if s.rng != nil {
-		r = s.rng.Float32()
-	} else {
-		r = rand.Float32()
-	}
-
-	// Calculate cumulative sum of probabilities
-	var sum float32
-	for i := range tokens {
-		sum += tokens[i].value
-		tokens[i].value = sum
-	}
-	r *= tokens[len(tokens)-1].value
-
-	idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
-		if token.value < target {
-			return -1
-		}
-		return 1
-	})
-
-	return tokens[idx], nil
+	// token logit values are updated to probabilities
+	temperature(tokens, s.temperature)
+	softmax(tokens)
+	return tokens[dist(tokens, s.rng.Int63())], nil
+
+	// // TODO: this should fall back to greedy sampling
+	// // or topP, topK values etc should be such that
+	// // there are always tokens to sample from
+	// if len(tokens) == 0 {
+	// 	return token{}, errors.New("no tokens to sample from")
+	// }
+
+	// var r float32
+	// if s.rng != nil {
+	// 	r = s.rng.Float32()
+	// } else {
+	// 	r = rand.Float32()
+	// }
+
+	// // Calculate cumulative sum of probabilities
+	// var sum float32
+	// for i := range tokens {
+	// 	sum += tokens[i].value
+	// 	tokens[i].value = sum
+	// }
+	// r *= tokens[len(tokens)-1].value
+
+	// idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
+	// 	if token.value < target {
+	// 		return -1
+	// 	}
+	// 	return 1
+	// })
+
+	// return tokens[idx], nil
 }
 
 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
 func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
 	var rng *rand.Rand
 	if seed != -1 {
-		// PCG requires two parameters: sequence and stream
-		// Use original seed for sequence
-		sequence := uint64(seed)
-		// Use golden ratio hash to generate statistically independent seeds
-		rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
+		rng = rand.New(rand.NewSource(int64(seed)))
+	} else {
+		rng = rand.New(rand.NewSource(time.Now().UnixNano()))
 	}
 	if temperature < 0.0 {
 		temperature = 0.0
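Note on the rewritten tail of sample() above: probabilities are accumulated into a cumulative distribution and a single uniform draw from a seeded math/rand source picks the index (what dist does in a later hunk), replacing the old binary-search-over-cumulative-sums path that is now left commented out. Below is a standalone sketch of that selection step with made-up probabilities and a hypothetical seed; it is illustrative only, not code from this change.

package main

import (
	"fmt"
	"math/rand"
)

func main() {
	// Probabilities as they would look after temperature scaling and softmax.
	probs := []float32{0.6, 0.3, 0.1}
	rng := rand.New(rand.NewSource(1234)) // seeded source, like s.rng in the diff

	// Build the cumulative distribution over the probabilities.
	cdf := make([]float32, len(probs))
	var cum float32
	for i, p := range probs {
		cum += p
		cdf[i] = cum
	}

	// One uniform draw in [0, cum) selects the first bucket it falls under.
	r := rng.Float32() * cum
	idx := len(probs) - 1 // fallback guards against rounding at the top end
	for i, c := range cdf {
		if r < c {
			idx = i
			break
		}
	}
	fmt.Println("sampled index:", idx)
}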

sample/testdata/logits.bin vendored Normal file

File diff suppressed because one or more lines are too long
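For reference, the loadTestLogits helper added to the test file later in this diff reads this vendored file as a raw array of little-endian float32 values. A minimal, hypothetical sketch of producing a file in that layout follows; the values are placeholders, not the vendored data.

package main

import (
	"encoding/binary"
	"log"
	"os"
)

func main() {
	// Placeholder values; the real file vendors logits captured from a model run.
	logits := []float32{1.25, -0.5, 3.75}

	f, err := os.Create("logits.bin")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// Each float32 is encoded as 4 little-endian bytes, matching the test reader.
	if err := binary.Write(f, binary.LittleEndian, logits); err != nil {
		log.Fatal(err)
	}
}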


@@ -3,6 +3,7 @@ package sample
 import (
 	"container/heap"
 	"math"
+	"math/rand"
 	"slices"
 )
@@ -25,32 +26,6 @@ func (h *tokenHeap) Pop() any {
 	return x
 }
 
-// temperature applies scaling and softmax to the logits
-func temperature(ts []token, temp float32) []token {
-	// Find max logit for numerical stability
-	maxLogit := float32(math.Inf(-1))
-	for _, t := range ts {
-		if t.value > maxLogit {
-			maxLogit = t.value
-		}
-	}
-
-	// Apply temperature and compute exp(x - max)
-	temp = max(temp, 1e-7)
-	var sum float32
-	for i, v := range ts {
-		ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
-		sum += ts[i].value
-	}
-
-	// Normalize
-	for i := range ts {
-		ts[i].value /= sum
-	}
-
-	return ts
-}
-
 // topK limits the number of tokens considered to the k highest logits
 func topK(ts []token, k int) []token {
 	if k >= len(ts) || k <= 0 {
@@ -134,3 +109,59 @@ func minP(ts []token, p float32) []token {
 	ts = validTokens
 	return ts
 }
+
+func temperature(ts []token, temp float32) {
+	for i := range ts {
+		ts[i].value /= temp
+	}
+}
+
+func softmax(ts []token) {
+	if len(ts) == 0 {
+		return
+	}
+
+	// Find max logit for numerical stability
+	maxLogit := ts[0].value
+	for _, t := range ts {
+		if t.value > maxLogit {
+			maxLogit = t.value
+		}
+	}
+
+	// Compute exp(logit - maxLogit) and sum them
+	var sumExp float32
+	for i, t := range ts {
+		expVal := float32(math.Exp(float64(t.value - maxLogit)))
+		ts[i].value = expVal
+		sumExp += expVal
+	}
+
+	// Normalize probabilities
+	for i := range ts {
+		ts[i].value /= sumExp
+	}
+}
+
+// applyDist selects a token based on probabilities and seed
+func dist(ts []token, seed int64) int {
+	rng := rand.New(rand.NewSource(seed))
+
+	cdf := make([]float32, len(ts))
+	var cumSum float32
+	for i, t := range ts {
+		cumSum += t.value
+		cdf[i] = cumSum
+	}
+
+	r := rng.Float32() * cumSum
+
+	// Select token based on CDF
+	for i, probSum := range cdf {
+		if r < probSum {
+			return i
+		}
+	}
+
+	return len(ts) - 1
+}
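The removed transforms.go temperature() above folded scaling and softmax into one pass; this commit splits it into the separate temperature() and softmax() shown here. Below is a hypothetical package-level test sketch (not part of this change; it assumes the unexported token type and the helpers from these hunks) checking that the two formulations produce the same probabilities.

package sample

import (
	"math"
	"testing"
)

// combinedTemperature mirrors the removed implementation: scale by temp, then softmax.
func combinedTemperature(ts []token, temp float32) {
	maxLogit := float32(math.Inf(-1))
	for _, t := range ts {
		if t.value > maxLogit {
			maxLogit = t.value
		}
	}
	temp = max(temp, 1e-7)
	var sum float32
	for i, v := range ts {
		ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
		sum += ts[i].value
	}
	for i := range ts {
		ts[i].value /= sum
	}
}

func TestTemperatureSoftmaxEquivalence(t *testing.T) {
	logits := []float32{2, 1, 0.5, -3}
	combined := make([]token, len(logits))
	split := make([]token, len(logits))
	for i, v := range logits {
		combined[i].value = v
		split[i].value = v
	}

	combinedTemperature(combined, 0.7)

	temperature(split, 0.7)
	softmax(split)

	for i := range combined {
		if diff := math.Abs(float64(combined[i].value - split[i].value)); diff > 1e-6 {
			t.Fatalf("probabilities differ at index %d: %v vs %v", i, combined[i].value, split[i].value)
		}
	}
}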


@@ -1,8 +1,13 @@
 package sample
 
 import (
+	"encoding/binary"
+	"errors"
 	"math"
 	"math/rand/v2"
+	"os"
+	"path/filepath"
+	"runtime"
 	"testing"
 )
@@ -143,6 +148,98 @@ func TestSortLogits(t *testing.T) {
 	compareLogits(t, "sortLogits", want, tokens)
 }
 
+// TestSortLogitsWithRealData tests sorting behavior using real model logit distributions
+func TestSortLogitsWithRealData(t *testing.T) {
+	// This will be populated from testdata/logits.bin
+	// Format: 32-bit float array in binary format
+	logits, err := loadTestLogits(t)
+	if err != nil {
+		t.Skipf("Skipping real logit test: %v", err)
+		return
+	}
+
+	tokens := toTokens(logits)
+	sortLogits(tokens)
+
+	// Calculate n for verification
+	n := int(math.Sqrt(float64(len(tokens)))) + 1
+	if n > 1000 {
+		n = 1000
+	} else if n < 100 {
+		n = 100
+	}
+
+	t.Logf("Testing with %d tokens, partial sorting top %d", len(tokens), n)
+
+	// Only verify the top n elements are sorted (which is what we guarantee)
+	// This is much faster than checking the entire array
+	topN := tokens[:n]
+	for i := 1; i < len(topN); i++ {
+		if topN[i].value > topN[i-1].value {
+			t.Fatalf("top %d tokens not properly sorted at index %d: %.15f > %.15f",
+				n, i, topN[i].value, topN[i-1].value)
+		}
+	}
+
+	// Verify we didn't lose any high value tokens by checking that
+	// all tokens after position n are <= the nth token
+	// Do this in chunks to avoid timeouts on large arrays
+	nthValue := tokens[n-1].value
+	const chunkSize = 1000
+	for start := n; start < len(tokens); start += chunkSize {
+		end := min(start+chunkSize, len(tokens))
+		for i := start; i < end; i++ {
+			if tokens[i].value > nthValue {
+				t.Fatalf("found higher value token after position %d: tokens[%d].value = %.15f > %.15f",
+					n, i, tokens[i].value, nthValue)
+			}
+		}
+	}
+}
+
+// loadTestLogits loads logit test data from testdata/logits.bin
+func loadTestLogits(t *testing.T) ([]float32, error) {
+	t.Helper()
+
+	_, currFile, _, ok := runtime.Caller(0)
+	if !ok {
+		return nil, errors.New("could not determine test file path")
+	}
+	testDataPath := filepath.Join(filepath.Dir(currFile), "testdata", "logits.bin")
+
+	file, err := os.Open(testDataPath)
+	if err != nil {
+		return nil, err
+	}
+	defer file.Close()
+
+	stat, err := file.Stat()
+	if err != nil {
+		return nil, err
+	}
+	numFloats := stat.Size() / 4 // each float32 is 4 bytes
+	if numFloats*4 != stat.Size() {
+		return nil, errors.New("logits.bin has invalid size: not a multiple of 4 bytes")
+	}
+
+	logits := make([]float32, numFloats)
+	for i := range logits {
+		var val uint32
+		if err := binary.Read(file, binary.LittleEndian, &val); err != nil {
+			return nil, err
+		}
+		logits[i] = math.Float32frombits(val)
+	}
+
+	if len(logits) == 0 {
+		return nil, errors.New("logits.bin is empty")
+	}
+
+	return logits, nil
+}
+
 func BenchmarkTransforms(b *testing.B) {
 	// Generate random logits
 	tokens := make([]token, 1<<16)