extras
This commit is contained in:
parent
7fa6ea0da7
commit
9622b928b4
@ -1,11 +1,10 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"slices"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
)
|
||||
@ -90,53 +89,53 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
||||
sortLogits(tokens)
|
||||
}
|
||||
|
||||
// token logit values are updated to probabilities
|
||||
tokens = temperature(tokens, s.temperature)
|
||||
|
||||
tokens = topP(tokens, s.topP)
|
||||
tokens = minP(tokens, s.minP)
|
||||
|
||||
// TODO: this should fall back to greedy sampling
|
||||
// or topP, topK values etc should be such that
|
||||
// there are always tokens to sample from
|
||||
if len(tokens) == 0 {
|
||||
return token{}, errors.New("no tokens to sample from")
|
||||
}
|
||||
// token logit values are updated to probabilities
|
||||
temperature(tokens, s.temperature)
|
||||
softmax(tokens)
|
||||
return tokens[dist(tokens, s.rng.Int63())], nil
|
||||
|
||||
var r float32
|
||||
if s.rng != nil {
|
||||
r = s.rng.Float32()
|
||||
} else {
|
||||
r = rand.Float32()
|
||||
}
|
||||
// // TODO: this should fall back to greedy sampling
|
||||
// // or topP, topK values etc should be such that
|
||||
// // there are always tokens to sample from
|
||||
// if len(tokens) == 0 {
|
||||
// return token{}, errors.New("no tokens to sample from")
|
||||
// }
|
||||
|
||||
// Calculate cumulative sum of probabilities
|
||||
var sum float32
|
||||
for i := range tokens {
|
||||
sum += tokens[i].value
|
||||
tokens[i].value = sum
|
||||
}
|
||||
r *= tokens[len(tokens)-1].value
|
||||
// var r float32
|
||||
// if s.rng != nil {
|
||||
// r = s.rng.Float32()
|
||||
// } else {
|
||||
// r = rand.Float32()
|
||||
// }
|
||||
|
||||
idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
|
||||
if token.value < target {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
})
|
||||
// // Calculate cumulative sum of probabilities
|
||||
// var sum float32
|
||||
// for i := range tokens {
|
||||
// sum += tokens[i].value
|
||||
// tokens[i].value = sum
|
||||
// }
|
||||
// r *= tokens[len(tokens)-1].value
|
||||
|
||||
return tokens[idx], nil
|
||||
// idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
|
||||
// if token.value < target {
|
||||
// return -1
|
||||
// }
|
||||
// return 1
|
||||
// })
|
||||
|
||||
// return tokens[idx], nil
|
||||
}
|
||||
|
||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
|
||||
var rng *rand.Rand
|
||||
if seed != -1 {
|
||||
// PCG requires two parameters: sequence and stream
|
||||
// Use original seed for sequence
|
||||
sequence := uint64(seed)
|
||||
// Use golden ratio hash to generate statistically independent seeds
|
||||
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
|
||||
rng = rand.New(rand.NewSource(int64(seed)))
|
||||
} else {
|
||||
rng = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
}
|
||||
if temperature < 0.0 {
|
||||
temperature = 0.0
|
||||
|
@ -3,6 +3,7 @@ package sample
|
||||
import (
|
||||
"container/heap"
|
||||
"math"
|
||||
"math/rand"
|
||||
"slices"
|
||||
)
|
||||
|
||||
@ -25,32 +26,6 @@ func (h *tokenHeap) Pop() any {
|
||||
return x
|
||||
}
|
||||
|
||||
// temperature applies scaling and softmax to the logits
|
||||
func temperature(ts []token, temp float32) []token {
|
||||
// Find max logit for numerical stability
|
||||
maxLogit := float32(math.Inf(-1))
|
||||
for _, t := range ts {
|
||||
if t.value > maxLogit {
|
||||
maxLogit = t.value
|
||||
}
|
||||
}
|
||||
|
||||
// Apply temperature and compute exp(x - max)
|
||||
temp = max(temp, 1e-7)
|
||||
var sum float32
|
||||
for i, v := range ts {
|
||||
ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
|
||||
sum += ts[i].value
|
||||
}
|
||||
|
||||
// Normalize
|
||||
for i := range ts {
|
||||
ts[i].value /= sum
|
||||
}
|
||||
|
||||
return ts
|
||||
}
|
||||
|
||||
// topK limits the number of tokens considered to the k highest logits
|
||||
func topK(ts []token, k int) []token {
|
||||
if k >= len(ts) {
|
||||
@ -200,3 +175,59 @@ func sortLogits(ts []token) {
|
||||
|
||||
partialSortLogits(ts, n)
|
||||
}
|
||||
|
||||
func temperature(ts []token, temp float32) {
|
||||
for i := range ts {
|
||||
ts[i].value /= temp
|
||||
}
|
||||
}
|
||||
|
||||
func softmax(ts []token) {
|
||||
if len(ts) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Find max logit for numerical stability
|
||||
maxLogit := ts[0].value
|
||||
for _, t := range ts {
|
||||
if t.value > maxLogit {
|
||||
maxLogit = t.value
|
||||
}
|
||||
}
|
||||
|
||||
// Compute exp(logit - maxLogit) and sum them
|
||||
var sumExp float32
|
||||
for i, t := range ts {
|
||||
expVal := float32(math.Exp(float64(t.value - maxLogit)))
|
||||
ts[i].value = expVal
|
||||
sumExp += expVal
|
||||
}
|
||||
|
||||
// Normalize probabilities
|
||||
for i := range ts {
|
||||
ts[i].value /= sumExp
|
||||
}
|
||||
}
|
||||
|
||||
// applyDist selects a token based on probabilities and seed
|
||||
func dist(ts []token, seed int64) int {
|
||||
rng := rand.New(rand.NewSource(seed))
|
||||
|
||||
cdf := make([]float32, len(ts))
|
||||
var cumSum float32
|
||||
for i, t := range ts {
|
||||
cumSum += t.value
|
||||
cdf[i] = cumSum
|
||||
}
|
||||
|
||||
r := rng.Float32() * cumSum
|
||||
|
||||
// Select token based on CDF
|
||||
for i, probSum := range cdf {
|
||||
if r < probSum {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return len(ts) - 1
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user