From 9622b928b425246a18949363bc5698290384598d Mon Sep 17 00:00:00 2001 From: jmorganca Date: Wed, 12 Mar 2025 18:28:59 +0100 Subject: [PATCH] extras --- sample/samplers.go | 73 +++++++++++++++++++------------------- sample/transforms.go | 83 ++++++++++++++++++++++++++++++-------------- 2 files changed, 93 insertions(+), 63 deletions(-) diff --git a/sample/samplers.go b/sample/samplers.go index aea99b3f2..d794228cb 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -1,11 +1,10 @@ package sample import ( - "errors" "math" - "math/rand/v2" - "slices" + "math/rand" "sync" + "time" "github.com/ollama/ollama/llama" ) @@ -90,53 +89,53 @@ func (s *Sampler) sample(tokens []token) (token, error) { sortLogits(tokens) } - // token logit values are updated to probabilities - tokens = temperature(tokens, s.temperature) - tokens = topP(tokens, s.topP) tokens = minP(tokens, s.minP) - // TODO: this should fall back to greedy sampling - // or topP, topK values etc should be such that - // there are always tokens to sample from - if len(tokens) == 0 { - return token{}, errors.New("no tokens to sample from") - } + // token logit values are updated to probabilities + temperature(tokens, s.temperature) + softmax(tokens) + return tokens[dist(tokens, s.rng.Int63())], nil - var r float32 - if s.rng != nil { - r = s.rng.Float32() - } else { - r = rand.Float32() - } + // // TODO: this should fall back to greedy sampling + // // or topP, topK values etc should be such that + // // there are always tokens to sample from + // if len(tokens) == 0 { + // return token{}, errors.New("no tokens to sample from") + // } - // Calculate cumulative sum of probabilities - var sum float32 - for i := range tokens { - sum += tokens[i].value - tokens[i].value = sum - } - r *= tokens[len(tokens)-1].value + // var r float32 + // if s.rng != nil { + // r = s.rng.Float32() + // } else { + // r = rand.Float32() + // } - idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int { - if token.value < target { - return -1 - } - return 1 - }) + // // Calculate cumulative sum of probabilities + // var sum float32 + // for i := range tokens { + // sum += tokens[i].value + // tokens[i].value = sum + // } + // r *= tokens[len(tokens)-1].value - return tokens[idx], nil + // idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int { + // if token.value < target { + // return -1 + // } + // return 1 + // }) + + // return tokens[idx], nil } // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278 func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler { var rng *rand.Rand if seed != -1 { - // PCG requires two parameters: sequence and stream - // Use original seed for sequence - sequence := uint64(seed) - // Use golden ratio hash to generate statistically independent seeds - rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9)) + rng = rand.New(rand.NewSource(int64(seed))) + } else { + rng = rand.New(rand.NewSource(time.Now().UnixNano())) } if temperature < 0.0 { temperature = 0.0 diff --git a/sample/transforms.go b/sample/transforms.go index 0d7797dae..f0a6a62c5 100644 --- a/sample/transforms.go +++ b/sample/transforms.go @@ -3,6 +3,7 @@ package sample import ( "container/heap" "math" + "math/rand" "slices" ) @@ -25,32 +26,6 @@ func (h *tokenHeap) Pop() any { return x } -// temperature applies scaling and softmax to the logits -func temperature(ts []token, temp float32) []token { - // Find max logit for numerical stability - maxLogit := float32(math.Inf(-1)) - for _, t := range ts { - if t.value > maxLogit { - maxLogit = t.value - } - } - - // Apply temperature and compute exp(x - max) - temp = max(temp, 1e-7) - var sum float32 - for i, v := range ts { - ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp))) - sum += ts[i].value - } - - // Normalize - for i := range ts { - ts[i].value /= sum - } - - return ts -} - // topK limits the number of tokens considered to the k highest logits func topK(ts []token, k int) []token { if k >= len(ts) { @@ -200,3 +175,59 @@ func sortLogits(ts []token) { partialSortLogits(ts, n) } + +func temperature(ts []token, temp float32) { + for i := range ts { + ts[i].value /= temp + } +} + +func softmax(ts []token) { + if len(ts) == 0 { + return + } + + // Find max logit for numerical stability + maxLogit := ts[0].value + for _, t := range ts { + if t.value > maxLogit { + maxLogit = t.value + } + } + + // Compute exp(logit - maxLogit) and sum them + var sumExp float32 + for i, t := range ts { + expVal := float32(math.Exp(float64(t.value - maxLogit))) + ts[i].value = expVal + sumExp += expVal + } + + // Normalize probabilities + for i := range ts { + ts[i].value /= sumExp + } +} + +// applyDist selects a token based on probabilities and seed +func dist(ts []token, seed int64) int { + rng := rand.New(rand.NewSource(seed)) + + cdf := make([]float32, len(ts)) + var cumSum float32 + for i, t := range ts { + cumSum += t.value + cdf[i] = cumSum + } + + r := rng.Float32() * cumSum + + // Select token based on CDF + for i, probSum := range cdf { + if r < probSum { + return i + } + } + + return len(ts) - 1 +}