Compare commits
1 Commits
main
...
parth/samp
Author | SHA1 | Date | |
---|---|---|---|
![]() |
a5d638dfe7 |
@ -1,11 +1,10 @@
|
|||||||
package sample
|
package sample
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"math"
|
"math"
|
||||||
"math/rand/v2"
|
"math/rand"
|
||||||
"slices"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/llama"
|
"github.com/ollama/ollama/llama"
|
||||||
)
|
)
|
||||||
@ -87,53 +86,53 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
|||||||
// topK also sorts the tokens in descending order of logits
|
// topK also sorts the tokens in descending order of logits
|
||||||
tokens = topK(tokens, s.topK)
|
tokens = topK(tokens, s.topK)
|
||||||
|
|
||||||
// token logit values are updated to probabilities
|
|
||||||
tokens = temperature(tokens, s.temperature)
|
|
||||||
|
|
||||||
tokens = topP(tokens, s.topP)
|
tokens = topP(tokens, s.topP)
|
||||||
tokens = minP(tokens, s.minP)
|
tokens = minP(tokens, s.minP)
|
||||||
|
|
||||||
// TODO: this should fall back to greedy sampling
|
// token logit values are updated to probabilities
|
||||||
// or topP, topK values etc should be such that
|
temperature(tokens, s.temperature)
|
||||||
// there are always tokens to sample from
|
softmax(tokens)
|
||||||
if len(tokens) == 0 {
|
return tokens[dist(tokens, s.rng.Int63())], nil
|
||||||
return token{}, errors.New("no tokens to sample from")
|
|
||||||
}
|
|
||||||
|
|
||||||
var r float32
|
// // TODO: this should fall back to greedy sampling
|
||||||
if s.rng != nil {
|
// // or topP, topK values etc should be such that
|
||||||
r = s.rng.Float32()
|
// // there are always tokens to sample from
|
||||||
} else {
|
// if len(tokens) == 0 {
|
||||||
r = rand.Float32()
|
// return token{}, errors.New("no tokens to sample from")
|
||||||
}
|
// }
|
||||||
|
|
||||||
// Calculate cumulative sum of probabilities
|
// var r float32
|
||||||
var sum float32
|
// if s.rng != nil {
|
||||||
for i := range tokens {
|
// r = s.rng.Float32()
|
||||||
sum += tokens[i].value
|
// } else {
|
||||||
tokens[i].value = sum
|
// r = rand.Float32()
|
||||||
}
|
// }
|
||||||
r *= tokens[len(tokens)-1].value
|
|
||||||
|
|
||||||
idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
|
// // Calculate cumulative sum of probabilities
|
||||||
if token.value < target {
|
// var sum float32
|
||||||
return -1
|
// for i := range tokens {
|
||||||
}
|
// sum += tokens[i].value
|
||||||
return 1
|
// tokens[i].value = sum
|
||||||
})
|
// }
|
||||||
|
// r *= tokens[len(tokens)-1].value
|
||||||
|
|
||||||
return tokens[idx], nil
|
// idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
|
||||||
|
// if token.value < target {
|
||||||
|
// return -1
|
||||||
|
// }
|
||||||
|
// return 1
|
||||||
|
// })
|
||||||
|
|
||||||
|
// return tokens[idx], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
|
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
|
||||||
var rng *rand.Rand
|
var rng *rand.Rand
|
||||||
if seed != -1 {
|
if seed != -1 {
|
||||||
// PCG requires two parameters: sequence and stream
|
rng = rand.New(rand.NewSource(int64(seed)))
|
||||||
// Use original seed for sequence
|
} else {
|
||||||
sequence := uint64(seed)
|
rng = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
// Use golden ratio hash to generate statistically independent seeds
|
|
||||||
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
|
|
||||||
}
|
}
|
||||||
if temperature < 0.0 {
|
if temperature < 0.0 {
|
||||||
temperature = 0.0
|
temperature = 0.0
|
||||||
|
1
sample/testdata/logits.bin
vendored
Normal file
1
sample/testdata/logits.bin
vendored
Normal file
File diff suppressed because one or more lines are too long
@ -3,6 +3,7 @@ package sample
|
|||||||
import (
|
import (
|
||||||
"container/heap"
|
"container/heap"
|
||||||
"math"
|
"math"
|
||||||
|
"math/rand"
|
||||||
"slices"
|
"slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -25,32 +26,6 @@ func (h *tokenHeap) Pop() any {
|
|||||||
return x
|
return x
|
||||||
}
|
}
|
||||||
|
|
||||||
// temperature applies scaling and softmax to the logits
|
|
||||||
func temperature(ts []token, temp float32) []token {
|
|
||||||
// Find max logit for numerical stability
|
|
||||||
maxLogit := float32(math.Inf(-1))
|
|
||||||
for _, t := range ts {
|
|
||||||
if t.value > maxLogit {
|
|
||||||
maxLogit = t.value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply temperature and compute exp(x - max)
|
|
||||||
temp = max(temp, 1e-7)
|
|
||||||
var sum float32
|
|
||||||
for i, v := range ts {
|
|
||||||
ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
|
|
||||||
sum += ts[i].value
|
|
||||||
}
|
|
||||||
|
|
||||||
// Normalize
|
|
||||||
for i := range ts {
|
|
||||||
ts[i].value /= sum
|
|
||||||
}
|
|
||||||
|
|
||||||
return ts
|
|
||||||
}
|
|
||||||
|
|
||||||
// topK limits the number of tokens considered to the k highest logits
|
// topK limits the number of tokens considered to the k highest logits
|
||||||
func topK(ts []token, k int) []token {
|
func topK(ts []token, k int) []token {
|
||||||
if k >= len(ts) || k <= 0 {
|
if k >= len(ts) || k <= 0 {
|
||||||
@ -134,3 +109,59 @@ func minP(ts []token, p float32) []token {
|
|||||||
ts = validTokens
|
ts = validTokens
|
||||||
return ts
|
return ts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func temperature(ts []token, temp float32) {
|
||||||
|
for i := range ts {
|
||||||
|
ts[i].value /= temp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func softmax(ts []token) {
|
||||||
|
if len(ts) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find max logit for numerical stability
|
||||||
|
maxLogit := ts[0].value
|
||||||
|
for _, t := range ts {
|
||||||
|
if t.value > maxLogit {
|
||||||
|
maxLogit = t.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute exp(logit - maxLogit) and sum them
|
||||||
|
var sumExp float32
|
||||||
|
for i, t := range ts {
|
||||||
|
expVal := float32(math.Exp(float64(t.value - maxLogit)))
|
||||||
|
ts[i].value = expVal
|
||||||
|
sumExp += expVal
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize probabilities
|
||||||
|
for i := range ts {
|
||||||
|
ts[i].value /= sumExp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyDist selects a token based on probabilities and seed
|
||||||
|
func dist(ts []token, seed int64) int {
|
||||||
|
rng := rand.New(rand.NewSource(seed))
|
||||||
|
|
||||||
|
cdf := make([]float32, len(ts))
|
||||||
|
var cumSum float32
|
||||||
|
for i, t := range ts {
|
||||||
|
cumSum += t.value
|
||||||
|
cdf[i] = cumSum
|
||||||
|
}
|
||||||
|
|
||||||
|
r := rng.Float32() * cumSum
|
||||||
|
|
||||||
|
// Select token based on CDF
|
||||||
|
for i, probSum := range cdf {
|
||||||
|
if r < probSum {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(ts) - 1
|
||||||
|
}
|
||||||
|
@ -1,8 +1,13 @@
|
|||||||
package sample
|
package sample
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"math"
|
"math"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -143,6 +148,98 @@ func TestSortLogits(t *testing.T) {
|
|||||||
compareLogits(t, "sortLogits", want, tokens)
|
compareLogits(t, "sortLogits", want, tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestSortLogitsWithRealData tests sorting behavior using real model logit distributions
|
||||||
|
func TestSortLogitsWithRealData(t *testing.T) {
|
||||||
|
// This will be populated from testdata/logits.bin
|
||||||
|
// Format: 32-bit float array in binary format
|
||||||
|
logits, err := loadTestLogits(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("Skipping real logit test: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := toTokens(logits)
|
||||||
|
sortLogits(tokens)
|
||||||
|
|
||||||
|
// Calculate n for verification
|
||||||
|
n := int(math.Sqrt(float64(len(tokens)))) + 1
|
||||||
|
if n > 1000 {
|
||||||
|
n = 1000
|
||||||
|
} else if n < 100 {
|
||||||
|
n = 100
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Testing with %d tokens, partial sorting top %d", len(tokens), n)
|
||||||
|
|
||||||
|
// Only verify the top n elements are sorted (which is what we guarantee)
|
||||||
|
// This is much faster than checking the entire array
|
||||||
|
topN := tokens[:n]
|
||||||
|
for i := 1; i < len(topN); i++ {
|
||||||
|
if topN[i].value > topN[i-1].value {
|
||||||
|
t.Fatalf("top %d tokens not properly sorted at index %d: %.15f > %.15f",
|
||||||
|
n, i, topN[i].value, topN[i-1].value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify we didn't lose any high value tokens by checking that
|
||||||
|
// all tokens after position n are <= the nth token
|
||||||
|
// Do this in chunks to avoid timeouts on large arrays
|
||||||
|
nthValue := tokens[n-1].value
|
||||||
|
const chunkSize = 1000
|
||||||
|
|
||||||
|
for start := n; start < len(tokens); start += chunkSize {
|
||||||
|
end := min(start+chunkSize, len(tokens))
|
||||||
|
for i := start; i < end; i++ {
|
||||||
|
if tokens[i].value > nthValue {
|
||||||
|
t.Fatalf("found higher value token after position %d: tokens[%d].value = %.15f > %.15f",
|
||||||
|
n, i, tokens[i].value, nthValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadTestLogits loads logit test data from testdata/logits.bin
|
||||||
|
func loadTestLogits(t *testing.T) ([]float32, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
_, currFile, _, ok := runtime.Caller(0)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("could not determine test file path")
|
||||||
|
}
|
||||||
|
testDataPath := filepath.Join(filepath.Dir(currFile), "testdata", "logits.bin")
|
||||||
|
|
||||||
|
file, err := os.Open(testDataPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
stat, err := file.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
numFloats := stat.Size() / 4 // each float32 is 4 bytes
|
||||||
|
if numFloats*4 != stat.Size() {
|
||||||
|
return nil, errors.New("logits.bin has invalid size: not a multiple of 4 bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
logits := make([]float32, numFloats)
|
||||||
|
for i := range logits {
|
||||||
|
var val uint32
|
||||||
|
if err := binary.Read(file, binary.LittleEndian, &val); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
logits[i] = math.Float32frombits(val)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(logits) == 0 {
|
||||||
|
return nil, errors.New("logits.bin is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
return logits, nil
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkTransforms(b *testing.B) {
|
func BenchmarkTransforms(b *testing.B) {
|
||||||
// Generate random logits
|
// Generate random logits
|
||||||
tokens := make([]token, 1<<16)
|
tokens := make([]token, 1<<16)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user