From 5c0b6639692ba5e6d44a0c73a9b5c85dc670d4f2 Mon Sep 17 00:00:00 2001
From: Parth Sareen
Date: Thu, 13 Mar 2025 09:53:27 -0700
Subject: [PATCH] sample: separate softmax and temperature transforms (#9732)

---
 sample/samplers.go        |   2 +-
 sample/transforms.go      |  19 +++++--
 sample/transforms_test.go | 106 ++++++++++++++++++++++++++++++++-------
 3 files changed, 100 insertions(+), 27 deletions(-)

diff --git a/sample/samplers.go b/sample/samplers.go
index 8b0de3f54..e302f9147 100644
--- a/sample/samplers.go
+++ b/sample/samplers.go
@@ -87,8 +87,8 @@ func (s *Sampler) sample(tokens []token) (token, error) {
 	// topK also sorts the tokens in descending order of logits
 	tokens = topK(tokens, s.topK)
 
-	// token logit values are updated to probabilities
 	tokens = temperature(tokens, s.temperature)
+	tokens = softmax(tokens)
 
 	tokens = topP(tokens, s.topP)
 	tokens = minP(tokens, s.minP)
diff --git a/sample/transforms.go b/sample/transforms.go
index b65917afd..a5efa704e 100644
--- a/sample/transforms.go
+++ b/sample/transforms.go
@@ -25,8 +25,18 @@ func (h *tokenHeap) Pop() any {
 	return x
 }
 
-// temperature applies scaling and softmax to the logits
+// temperature applies scaling to the logits
 func temperature(ts []token, temp float32) []token {
+	// Clamp temperature to a small positive value to avoid numerical instability near 0
+	temp = max(temp, 1e-7)
+	for i := range ts {
+		ts[i].value = ts[i].value / temp
+	}
+	return ts
+}
+
+// softmax normalizes the logits into a probability distribution
+func softmax(ts []token) []token {
 	// Find max logit for numerical stability
 	maxLogit := float32(math.Inf(-1))
 	for _, t := range ts {
@@ -35,15 +45,14 @@
 		}
 	}
 
-	// Apply temperature and compute exp(x - max)
-	temp = max(temp, 1e-7)
+	// Compute exp(x - max)
 	var sum float32
 	for i, v := range ts {
-		ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
+		ts[i].value = float32(math.Exp(float64(v.value - maxLogit)))
 		sum += ts[i].value
 	}
 
-	// Normalize
+	// exp(x - max) / sum(exp(x - max))
 	for i := range ts {
 		ts[i].value /= sum
 	}
diff --git a/sample/transforms_test.go b/sample/transforms_test.go
index 8f0a58b60..4880dd8f4 100644
--- a/sample/transforms_test.go
+++ b/sample/transforms_test.go
@@ -32,27 +32,83 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {
 	}
 }
 
-func TestTemperatureAndSoftmax(t *testing.T) {
-	input := []float32{1, 4, -2, 0}
+func TestTemperature(t *testing.T) {
+	input := []float32{1.0, 4.0, -2.0, 0.0}
 	got := temperature(toTokens(input), 0.5)
+	want := []float32{2.0, 8.0, -4.0, 0.0}
+	compareLogits(t, "temperature(0.5)", want, got)
 
-	// Check probabilities sum to 1
-	var sum float32
-	for _, token := range got {
-		sum += token.value
-	}
-	if math.Abs(float64(sum-1.0)) > 1e-6 {
-		t.Errorf("probabilities don't sum to 1: got %f", sum)
+	got = temperature(toTokens(input), 1.0)
+	want = []float32{1.0, 4.0, -2.0, 0.0}
+	compareLogits(t, "temperature(1)", want, got)
+
+	got = temperature(toTokens(input), 0.0)
+	want = []float32{1e7, 4e7, -2e7, 0.0}
+	compareLogits(t, "temperature(0)", want, got)
+}
+
+func TestSoftmax(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    []float32
+		expected []float32
+	}{
+		{
+			name:     "correctness softmax",
+			input:    []float32{1, -2, 3, 0},
+			expected: []float32{0.113550, 0.005653, 0.839024, 0.041773},
+		},
+		{
+			name:  "normal distribution",
+			input: []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367},
+		},
+		{
+			name:  "single value",
+			input: []float32{1.0},
+		},
+		{
+			name:  "identical values",
+			input: []float32{0.9, 0.9, 0.9},
+		},
+		{
+			name:  "large values",
+			input: []float32{1000.0, 2000.0, 3000.0},
+		},
+		{
+			name:  "small values",
+			input: []float32{1e-6, 2e-6, 3e-6},
+		},
+		{
+			name:  "negative values",
+			input: []float32{-1.0, -2.0, -3.0},
+		},
+		{
+			name:  "mixed values",
+			input: []float32{-100.0, 0.0, 100.0},
+		},
 	}
 
-	got = temperature(toTokens(input), 1)
-	// Check probabilities sum to 1
-	sum = 0.0
-	for _, token := range got {
-		sum += token.value
-	}
-	if math.Abs(float64(sum-1.0)) > 1e-6 {
-		t.Errorf("probabilities don't sum to 1: got %f", sum)
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := softmax(toTokens(tt.input))
+
+			if tt.expected != nil {
+				compareLogits(t, tt.name, tt.expected, got)
+				return
+			}
+
+			// Check probabilities sum to 1
+			var sum float32
+			for _, token := range got {
+				sum += token.value
+				if token.value < 0 || token.value > 1 {
+					t.Errorf("probability out of range [0,1]: got %f", token.value)
+				}
+			}
+			if math.Abs(float64(sum-1.0)) > 1e-6 {
+				t.Errorf("probabilities don't sum to 1: got %f", sum)
+			}
+		})
 	}
 }
 
@@ -97,7 +153,7 @@ func TestTopP(t *testing.T) {
 	tokens := toTokens(input)
 
-	// First apply temperature and softmax to get probabilities
-	tokens = temperature(tokens, 1)
+	// First apply softmax to get probabilities
+	tokens = softmax(tokens)
 	tokens = topK(tokens, 20)
 
 	// Then apply topP
@@ -115,7 +171,7 @@ func TestMinP(t *testing.T) {
 	tokens := toTokens(input)
 
-	// First apply temperature and softmax
-	tokens = temperature(tokens, 1)
+	// First apply softmax
+	tokens = softmax(tokens)
 
 	// Then apply minP
 	got := minP(tokens, 0.2)
@@ -163,6 +219,14 @@ func BenchmarkTransforms(b *testing.B) {
 		}
 	})
 
+	b.Run("Softmax", func(b *testing.B) {
+		b.ResetTimer()
+		for b.Loop() {
+			copy(tokensCopy, tokens)
+			softmax(tokensCopy)
+		}
+	})
+
 	b.Run("TopK", func(b *testing.B) {
 		b.ResetTimer()
 		for b.Loop() {
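
The patch above splits the old combined transform into two independent stages: temperature now only rescales logits (dividing by a clamped temperature), and softmax alone turns logits into the probability distribution that topP and minP consume downstream. Below is a minimal, self-contained sketch of how the separated stages compose after this change; the token struct and transform bodies mirror the diff, but the main driver and its sample values are illustrative only, not part of the commit.

package main

import (
	"fmt"
	"math"
)

// token pairs a vocabulary id with a logit (later, a probability).
type token struct {
	id    int32
	value float32
}

// temperature only rescales logits (the max builtin requires Go 1.21+).
func temperature(ts []token, temp float32) []token {
	temp = max(temp, 1e-7) // clamp so temperature 0 doesn't divide by zero
	for i := range ts {
		ts[i].value = ts[i].value / temp
	}
	return ts
}

// softmax normalizes logits into probabilities: exp(x - max) / sum(exp(x - max)).
func softmax(ts []token) []token {
	maxLogit := float32(math.Inf(-1))
	for _, t := range ts {
		if t.value > maxLogit {
			maxLogit = t.value
		}
	}
	var sum float32
	for i, v := range ts {
		ts[i].value = float32(math.Exp(float64(v.value - maxLogit)))
		sum += ts[i].value
	}
	for i := range ts {
		ts[i].value /= sum
	}
	return ts
}

func main() {
	logits := []token{{0, 1}, {1, 4}, {2, -2}, {3, 0}}

	// The pipeline order from the patched sample(): scale first, then normalize.
	probs := softmax(temperature(logits, 0.5))
	for _, t := range probs {
		fmt.Printf("token %d: %.6f\n", t.id, t.value)
	}
}

One practical upside of the separation, visible in the diff itself, is that softmax can now be tested and benchmarked in isolation: the new TestSoftmax table feeds it raw inputs directly, and BenchmarkTransforms gains a dedicated Softmax case.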
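The "large values" ({1000, 2000, 3000}) and "mixed values" cases in TestSoftmax exercise the max-subtraction step that softmax retains: math.Exp overflows float64 to +Inf for arguments past roughly 709, so a naive softmax of these inputs yields NaN (Inf/Inf) for every entry, while exp(x - max) keeps each term in (0, 1]. The following standalone comparison of the two formulations is hypothetical demo code, not from the patch:

package main

import (
	"fmt"
	"math"
)

func main() {
	logits := []float64{1000, 2000, 3000} // the patch's "large values" case

	// Naive softmax: e^1000 already overflows to +Inf, so every
	// ratio below is Inf/Inf = NaN.
	var naiveSum float64
	naive := make([]float64, len(logits))
	for i, x := range logits {
		naive[i] = math.Exp(x)
		naiveSum += naive[i]
	}
	fmt.Println("naive: ", naive[0]/naiveSum, naive[1]/naiveSum, naive[2]/naiveSum)

	// Stable softmax: subtracting the max keeps each exp term in (0, 1];
	// here the small terms underflow gracefully and the result is {0, 0, 1}.
	maxLogit := math.Inf(-1)
	for _, x := range logits {
		maxLogit = math.Max(maxLogit, x)
	}
	var sum float64
	stable := make([]float64, len(logits))
	for i, x := range logits {
		stable[i] = math.Exp(x - maxLogit)
		sum += stable[i]
	}
	fmt.Println("stable:", stable[0]/sum, stable[1]/sum, stable[2]/sum)
}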