addressing comments + cleanup

This commit is contained in:
ParthSareen 2025-01-14 16:13:38 -08:00
parent 5e73f24e16
commit 5b19d4941a
2 changed files with 50 additions and 39 deletions

View File

@ -1,9 +1,10 @@
package sample package sample
import ( import (
"cmp"
"errors" "errors"
"math" "math"
"sort" "slices"
"gonum.org/v1/gonum/floats" "gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/stat/sampleuv" "gonum.org/v1/gonum/stat/sampleuv"
@ -20,13 +21,12 @@ func (s Temperature) Sample(logits []float64) ([]float64, error) {
return nil, errors.New("temperature must be between 0 and 1") return nil, errors.New("temperature must be between 0 and 1")
} }
copiedLogits := append([]float64(nil), logits...) // greedy sampling
// Greedy sampling
if s == 0 { if s == 0 {
return []float64{floats.Max(copiedLogits)}, nil return []float64{floats.Max(logits)}, nil
} }
floats.Scale(1.0/float64(s), copiedLogits) floats.Scale(1.0/float64(s), logits)
return copiedLogits, nil return logits, nil
} }
type softmax struct{} type softmax struct{}
@ -69,8 +69,9 @@ func (k TopK) Sample(logits []float64) ([]float64, error) {
indices[i] = i indices[i] = i
} }
sort.Slice(indices, func(i, j int) bool { // sort in descending order
return logits[indices[i]] > logits[indices[j]] slices.SortFunc(indices, func(i, j int) int {
return cmp.Compare(logits[j], logits[i])
}) })
for _, idx := range indices[k:] { for _, idx := range indices[k:] {
@ -96,8 +97,10 @@ func (p TopP) Sample(logits []float64) ([]float64, error) {
for i := range indices { for i := range indices {
indices[i] = i indices[i] = i
} }
sort.Slice(indices, func(i, j int) bool {
return probs[indices[i]] > probs[indices[j]] // sort in descending order
slices.SortFunc(indices, func(i, j int) int {
return cmp.Compare(probs[j], probs[i])
}) })
cumSum := 0.0 cumSum := 0.0
@ -127,9 +130,9 @@ func (p MinP) Sample(logits []float64) ([]float64, error) {
copiedProbs := make([]float64, len(probs)) copiedProbs := make([]float64, len(probs))
copy(copiedProbs, probs) copy(copiedProbs, probs)
sort.Slice(copiedProbs, func(i, j int) bool { return copiedProbs[i] > copiedProbs[j] }) slices.Sort(copiedProbs)
maxProb := floats.Max(probs) maxProb := copiedProbs[len(copiedProbs)-1]
probThreshold := float64(p) * maxProb probThreshold := float64(p) * maxProb
for i := range probs { for i := range probs {
@ -162,20 +165,23 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
return nil, errors.New("no valid tokens found") return nil, errors.New("no valid tokens found")
} }
// usually, a softmax is applied to sample from the logits
// in this case the uv sampler normalizes the logits so that the sum of the weights is 1
w := sampleuv.NewWeighted(logitsCopy, nil) w := sampleuv.NewWeighted(logitsCopy, nil)
if v, ok := w.Take(); ok { if v, ok := w.Take(); ok {
// returns the token ID
return []float64{float64(indices[v])}, nil return []float64{float64(indices[v])}, nil
} }
return nil, errors.New("weighed sampler failed") return nil, errors.New("weighed sampler failed")
} }
func Sample(tokenID []float64, samplers ...Sampler) ([]float64, error) { func Sample(logits []float64, samplers ...Sampler) ([]float64, error) {
var err error var err error
for _, sampler := range samplers { for _, sampler := range samplers {
tokenID, err = sampler.Sample(tokenID) logits, err = sampler.Sample(logits)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
return tokenID, nil return logits, nil
} }

View File

@ -3,9 +3,14 @@ package sample
import ( import (
"fmt" "fmt"
"math" "math"
"math/rand"
"os"
"runtime"
"slices" "slices"
"testing" "testing"
"runtime/trace"
"gonum.org/v1/gonum/floats" "gonum.org/v1/gonum/floats"
) )
@ -14,9 +19,9 @@ func TestTemperature(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
expectedlogits := []float64{-6, -4, -2, 0, 2, 4, 8} want := []float64{-6, -4, -2, 0, 2, 4, 8}
if !floats.Equal(logits, expectedlogits) { if !floats.Equal(logits, want) {
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits) t.Fatalf("got: %v, want: %v", logits, want)
} }
// Only expect the max value returned // Only expect the max value returned
@ -24,9 +29,9 @@ func TestTemperature(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
expectedlogits = []float64{4} want = []float64{4}
if !floats.Equal(logits, expectedlogits) { if !floats.Equal(logits, want) {
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits) t.Fatalf("got: %v, want: %v", logits, want)
} }
if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil { if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
@ -35,7 +40,7 @@ func TestTemperature(t *testing.T) {
} }
func TestSoftmax(t *testing.T) { func TestSoftmax(t *testing.T) {
probs, err := computeSoftmax([]float64{-3, -2, -1, 0, 1, 2, 4}) probs, err := Softmax().Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -66,9 +71,9 @@ func TestTopP(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
expectedLogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 2, 4} want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 2, 4}
if !floats.Same(logits, expectedLogits) { if !floats.Same(logits, want) {
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits) t.Fatalf("got: %v, want: %v", logits, want)
} }
logits, err = TopP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}) logits, err = TopP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
if err == nil { if err == nil {
@ -85,9 +90,9 @@ func TestMinP(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
expectedLogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 3, 4} want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 3, 4}
if !floats.Same(logits, expectedLogits) { if !floats.Same(logits, want) {
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits) t.Fatalf("got: %v, want: %v", logits, want)
} }
logits, err = MinP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) logits, err = MinP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
if err == nil { if err == nil {
@ -104,9 +109,9 @@ func TestWeighed(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
expectedLogits := []float64{1} want := []float64{1}
if !floats.Equal(logits, expectedLogits) { if !floats.Equal(logits, want) {
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits) t.Fatalf("got: %v, want: %v", logits, want)
} }
logits, err = Weighed().Sample([]float64{math.NaN(), math.NaN(), math.NaN()}) logits, err = Weighed().Sample([]float64{math.NaN(), math.NaN(), math.NaN()})
if err == nil { if err == nil {
@ -116,36 +121,36 @@ func TestWeighed(t *testing.T) {
func TestSample(t *testing.T) { func TestSample(t *testing.T) {
input := []float64{1, 2, 3, 4} input := []float64{1, 2, 3, 4}
expectedOutput := []float64{1, 2, 3, 4} want := []float64{1, 2, 3, 4}
var callOrder []int var callOrder []int
mock1 := &mockSampler{ mock1 := &mockSampler{
id: 1, id: 1,
callOrder: &callOrder, callOrder: &callOrder,
returnVals: expectedOutput, returnVals: want,
} }
mock2 := &mockSampler{ mock2 := &mockSampler{
id: 2, id: 2,
callOrder: &callOrder, callOrder: &callOrder,
returnVals: expectedOutput, returnVals: want,
} }
mock3 := &mockSampler{ mock3 := &mockSampler{
id: 3, id: 3,
callOrder: &callOrder, callOrder: &callOrder,
returnVals: expectedOutput, returnVals: want,
} }
result, err := Sample(input, mock1, mock2, mock3) got, err := Sample(input, mock1, mock2, mock3)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !slices.Equal(callOrder, []int{1, 2, 3}) { if !slices.Equal(callOrder, []int{1, 2, 3}) {
t.Errorf("Expected call order [1,2,3], got %v", callOrder) t.Errorf("got %v, want %v", callOrder, []int{1, 2, 3})
} }
if !floats.Equal(result, expectedOutput) { if !floats.Equal(got, want) {
t.Errorf("Expected output %v, got %v", expectedOutput, result) t.Errorf("got %v, want %v", got, want)
} }
errMock := &mockSampler{ errMock := &mockSampler{