From 5e73f24e16a588addda4d98ee095c09383b94c74 Mon Sep 17 00:00:00 2001
From: ParthSareen
Date: Fri, 10 Jan 2025 17:49:39 -0800
Subject: [PATCH] sampling package

---
Reviewer note: this patch was edited by hand after export, so the blob ids on
the "index" lines still describe the originally exported contents.

 sample/sample.go      | 200 +++++++++++++++++++++++++++++++++----
 sample/sample_test.go | 223 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 402 insertions(+), 21 deletions(-)
 create mode 100644 sample/sample_test.go

diff --git a/sample/sample.go b/sample/sample.go
index 44c08caed..cc816c7da 100644
--- a/sample/sample.go
+++ b/sample/sample.go
@@ -1,7 +1,9 @@
 package sample
 
 import (
-	"slices"
+	"errors"
+	"math"
+	"sort"
 
 	"gonum.org/v1/gonum/floats"
 	"gonum.org/v1/gonum/stat/sampleuv"
@@ -13,9 +15,30 @@ type Sampler interface {
 
 type Temperature float64
 
-func (s Temperature) Sample(t []float64) ([]float64, error) {
-	floats.Div(t, slices.Repeat([]float64{float64(s)}, len(t)))
-	return t, nil
+// Sample divides each logit by the temperature and returns the result in a
+// new slice; the input is not modified. A temperature of 0 means greedy
+// selection, which returns a one-element slice holding the largest logit.
+// Temperatures outside [0, 1] are rejected with an error.
+//
+// NOTE(review): the greedy branch returns the largest logit value, whereas
+// Weighed returns a token index. Confirm which one greedy callers expect.
+func (s Temperature) Sample(logits []float64) ([]float64, error) {
+	if s < 0 || s > 1 {
+		return nil, errors.New("temperature must be between 0 and 1")
+	}
+
+	// Greedy sampling
+	if s == 0 {
+		// floats.Max panics on an empty slice, so report that case instead.
+		if len(logits) == 0 {
+			return nil, errors.New("no valid tokens found")
+		}
+		return []float64{floats.Max(logits)}, nil
+	}
+
+	copiedLogits := append([]float64(nil), logits...)
+	floats.Scale(1.0/float64(s), copiedLogits)
+	return copiedLogits, nil
 }
 
 type softmax struct{}
@@ -24,26 +47,141 @@ func Softmax() Sampler {
 	return softmax{}
 }
 
-func (softmax) Sample(t []float64) ([]float64, error) {
-	return t, nil
+// Sample turns logits into a probability distribution; see computeSoftmax.
+func (softmax) Sample(logits []float64) ([]float64, error) {
+	return computeSoftmax(logits)
+}
+
+// computeSoftmax returns the softmax of logits in a new slice; the input is
+// not modified. TopK, TopP and MinP mark a removed token with NaN, so NaN
+// entries get probability 0 and are left out of the normalisation instead of
+// turning every output into NaN. The largest logit is subtracted before
+// exponentiating so that large inputs cannot overflow to +Inf. An error is
+// returned when there is no finite maximum: empty input, every entry masked,
+// or an infinite logit.
+func computeSoftmax(logits []float64) ([]float64, error) {
+	maxLogit := math.Inf(-1)
+	for _, logit := range logits {
+		// The comparison is false for NaN, so masked entries are skipped.
+		if logit > maxLogit {
+			maxLogit = logit
+		}
+	}
+	if math.IsInf(maxLogit, 0) {
+		return nil, errors.New("no valid tokens found")
+	}
+
+	probs := make([]float64, len(logits))
+	for i, logit := range logits {
+		if !math.IsNaN(logit) {
+			probs[i] = math.Exp(logit - maxLogit)
+		}
+	}
+
+	// The maximum contributes exp(0) = 1, so the sum cannot be zero here.
+	floats.Scale(1.0/floats.Sum(probs), probs)
+	return probs, nil
 }
 
 type TopK int
 
-func (s TopK) Sample(t []float64) ([]float64, error) {
-	return t, nil
+// Sample keeps the k largest logits and overwrites every other entry with NaN,
+// the marker used by the samplers in this package for a removed token. logits
+// is modified in place and returned. k must be positive; if k is at least
+// len(logits) nothing is removed.
+func (k TopK) Sample(logits []float64) ([]float64, error) {
+	if k <= 0 {
+		return nil, errors.New("k must be positive")
+	}
+	if int(k) >= len(logits) {
+		return logits, nil
+	}
+
+	indices := make([]int, len(logits))
+	for i := range indices {
+		indices[i] = i
+	}
+
+	// Sort indices by descending logit. NaN compares false against everything,
+	// which is not a consistent ordering for sort.Slice, so entries that are
+	// already masked are explicitly ranked last.
+	sort.Slice(indices, func(i, j int) bool {
+		a, b := logits[indices[i]], logits[indices[j]]
+		if math.IsNaN(b) {
+			return !math.IsNaN(a)
+		}
+		return a > b
+	})
+
+	for _, idx := range indices[k:] {
+		logits[idx] = math.NaN()
+	}
+
+	return logits, nil
 }
 
 type TopP float32
 
-func (s TopP) Sample(t []float64) ([]float64, error) {
-	return t, nil
+// Sample implements nucleus (top-p) filtering: tokens are taken in order of
+// decreasing softmax probability until their cumulative probability exceeds
+// p, and every remaining entry of logits is overwritten with NaN. logits is
+// modified in place and returned. p must lie strictly between 0 and 1.
+func (p TopP) Sample(logits []float64) ([]float64, error) {
+	if p <= 0 || p >= 1 {
+		return nil, errors.New("p must be between 0 and 1")
+	}
+
+	probs, err := computeSoftmax(logits)
+	if err != nil {
+		return nil, err
+	}
+
+	indices := make([]int, len(probs))
+	for i := range indices {
+		indices[i] = i
+	}
+	sort.Slice(indices, func(i, j int) bool {
+		return probs[indices[i]] > probs[indices[j]]
+	})
+
+	cumSum := 0.0
+	for i, idx := range indices {
+		cumSum += probs[idx]
+		if cumSum > float64(p) {
+			// Everything after the token that crossed the threshold is dropped.
+			for _, idx := range indices[i+1:] {
+				logits[idx] = math.NaN()
+			}
+			break
+		}
+	}
+	return logits, nil
 }
 
 type MinP float32
 
-func (s MinP) Sample(t []float64) ([]float64, error) {
-	return t, nil
+// Sample overwrites with NaN every entry of logits whose softmax probability
+// is below p times the largest probability. logits is modified in place and
+// returned. p must lie strictly between 0 and 1.
+func (p MinP) Sample(logits []float64) ([]float64, error) {
+	if p <= 0 || p >= 1 {
+		return nil, errors.New("p must be between 0 and 1")
+	}
+
+	probs, err := computeSoftmax(logits)
+	if err != nil {
+		return nil, err
+	}
+
+	// computeSoftmax fails on empty input, so probs is non-empty here.
+	probThreshold := float64(p) * floats.Max(probs)
+	for i := range probs {
+		if probs[i] < probThreshold {
+			logits[i] = math.NaN()
+		}
+	}
+
+	return logits, nil
 }
 
 type weighed struct{}
@@ -52,23 +190,43 @@ func Weighed() Sampler {
 	return weighed{}
 }
 
-func (s weighed) Sample(t []float64) ([]float64, error) {
-	w := sampleuv.NewWeighted(t, nil)
-	if v, ok := w.Take(); ok {
-		return []float64{float64(v)}, nil
+// Sample draws one index at random, using the non-NaN entries of logits as
+// weights, and returns it as a one-element slice. The index refers to the
+// position in the original slice, NaN entries included.
+//
+// NOTE(review): the values are handed to sampleuv.NewWeighted unchanged, so
+// presumably callers run Softmax first to make them non-negative; confirm.
+func (s weighed) Sample(logits []float64) ([]float64, error) {
+	logitsCopy := make([]float64, 0, len(logits))
+	indices := make([]int, 0, len(logits))
+	// the uv sampler does not support NaN values
+	for i, logit := range logits {
+		if !math.IsNaN(logit) {
+			logitsCopy = append(logitsCopy, logit)
+			indices = append(indices, i)
+		}
 	}
 
-	return t, nil
+	if len(logitsCopy) == 0 {
+		return nil, errors.New("no valid tokens found")
+	}
+
+	w := sampleuv.NewWeighted(logitsCopy, nil)
+	if v, ok := w.Take(); ok {
+		return []float64{float64(indices[v])}, nil
+	}
+	return nil, errors.New("weighed sampler failed")
 }
 
-func Sample(floats []float64, samplers ...Sampler) ([]float64, error) {
+// Sample runs the samplers in order, feeding the output of each one into the
+// next, and returns the output of the last. It stops at the first error.
+func Sample(tokenID []float64, samplers ...Sampler) ([]float64, error) {
 	var err error
 	for _, sampler := range samplers {
-		floats, err = sampler.Sample(floats)
+		tokenID, err = sampler.Sample(tokenID)
 		if err != nil {
 			return nil, err
 		}
 	}
-
-	return floats, nil
+	return tokenID, nil
 }
diff --git a/sample/sample_test.go b/sample/sample_test.go
new file mode 100644
index 000000000..314e5dd6d
--- /dev/null
+++ b/sample/sample_test.go
@@ -0,0 +1,223 @@
+package sample
+
+import (
+	"fmt"
+	"math"
+	"slices"
+	"testing"
+
+	"gonum.org/v1/gonum/floats"
+)
+
+func TestTemperature(t *testing.T) {
+	logits, err := Temperature(0.5).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
+	if err != nil {
+		t.Fatal(err)
+	}
+	expectedlogits := []float64{-6, -4, -2, 0, 2, 4, 8}
+	if !floats.Equal(logits, expectedlogits) {
+		t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
+	}
+
+	// Only expect the max value returned
+	logits, err = Temperature(0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
+	if err != nil {
+		t.Fatal(err)
+	}
+	expectedlogits = []float64{4}
+	if !floats.Equal(logits, expectedlogits) {
+		t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
+	}
+
+	if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
+		t.Fatalf("expected error for temperature=-1, got %v", logits)
+	}
+
+	// Greedy selection over no logits must fail rather than panic.
+	if got, err := Temperature(0).Sample(nil); err == nil {
+		t.Fatalf("expected error for empty logits, got %v", got)
+	}
+}
+
+func TestSoftmax(t *testing.T) {
+	probs, err := computeSoftmax([]float64{-3, -2, -1, 0, 1, 2, 4})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	expectedProbs := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
+	// Compare with a tolerance: the result depends on floating-point rounding.
+	if !floats.EqualApprox(probs, expectedProbs, 1e-12) {
+		t.Fatalf("logits: %v, expectedlogits: %v", probs, expectedProbs)
+	}
+
+	// NaN marks a removed token: it gets probability 0 and is ignored otherwise.
+	probs, err = computeSoftmax([]float64{math.NaN(), 0, 0})
+	if err != nil {
+		t.Fatal(err)
+	}
+	expectedProbs = []float64{0, 0.5, 0.5}
+	if !floats.EqualApprox(probs, expectedProbs, 1e-12) {
+		t.Fatalf("logits: %v, expectedlogits: %v", probs, expectedProbs)
+	}
+
+	// Large logits must not overflow.
+	probs, err = computeSoftmax([]float64{1000, 1000})
+	if err != nil {
+		t.Fatal(err)
+	}
+	expectedProbs = []float64{0.5, 0.5}
+	if !floats.EqualApprox(probs, expectedProbs, 1e-12) {
+		t.Fatalf("logits: %v, expectedlogits: %v", probs, expectedProbs)
+	}
+
+	if probs, err = computeSoftmax([]float64{math.NaN(), math.NaN()}); err == nil {
+		t.Fatalf("expected error for no valid tokens, got %v", probs)
+	}
+}
+
+func TestTopK(t *testing.T) {
+	logits, err := TopK(3).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
+	if err != nil {
+		t.Fatal(err)
+	}
+	expectedlogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), 1, 2, 4}
+	if !floats.Same(logits, expectedlogits) {
+		t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
+	}
+	// Entries that are already masked must never be preferred over real logits.
+	logits, err = TopK(2).Sample([]float64{math.NaN(), 1, math.NaN(), 3, 2})
+	if err != nil {
+		t.Fatal(err)
+	}
+	expectedlogits = []float64{math.NaN(), math.NaN(), math.NaN(), 3, 2}
+	if !floats.Same(logits, expectedlogits) {
+		t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
+	}
+	logits, err = TopK(0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
+	if err == nil {
+		t.Fatalf("expected error for k=0, got %v", logits)
+	}
+}
+
+func TestTopP(t *testing.T) {
+	logits, err := TopP(0.9).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
+	if err != nil {
+		t.Fatal(err)
+	}
+	expectedLogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 2, 4}
+	if !floats.Same(logits, expectedLogits) {
+		t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
+	}
+	logits, err = TopP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
+	if err == nil {
+		t.Fatalf("expected error for p=1.0, got %v", logits)
+	}
+	logits, err = TopP(0.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
+	if err == nil {
+		t.Fatalf("expected error for p=0.0, got %v", logits)
+	}
+}
+
+func TestMinP(t *testing.T) {
+	logits, err := MinP(0.2).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
+	if err != nil {
+		t.Fatal(err)
+	}
+	expectedLogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 3, 4}
+	if !floats.Same(logits, expectedLogits) {
+		t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
+	}
+	// A token masked by an earlier sampler must not disable the filter.
+	logits, err = MinP(0.5).Sample([]float64{math.NaN(), 0, -5})
+	if err != nil {
+		t.Fatal(err)
+	}
+	expectedLogits = []float64{math.NaN(), 0, math.NaN()}
+	if !floats.Same(logits, expectedLogits) {
+		t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
+	}
+	logits, err = MinP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
+	if err == nil {
+		t.Fatalf("expected error for p=1.0, got %v", logits)
+	}
+	logits, err = MinP(0.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
+	if err == nil {
+		t.Fatalf("expected error for p=0.0, got %v", logits)
+	}
+}
+
+func TestWeighed(t *testing.T) {
+	logits, err := Weighed().Sample([]float64{math.NaN(), 2, math.NaN(), math.NaN()})
+	if err != nil {
+		t.Fatal(err)
+	}
+	expectedLogits := []float64{1}
+	if !floats.Equal(logits, expectedLogits) {
+		t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
+	}
+	logits, err = Weighed().Sample([]float64{math.NaN(), math.NaN(), math.NaN()})
+	if err == nil {
+		t.Fatalf("expected error for no valid tokens, got %v", logits)
+	}
+}
+
+func TestSample(t *testing.T) {
+	input := []float64{1, 2, 3, 4}
+	expectedOutput := []float64{1, 2, 3, 4}
+
+	var callOrder []int
+	mock1 := &mockSampler{
+		id:         1,
+		callOrder:  &callOrder,
+		returnVals: expectedOutput,
+	}
+	mock2 := &mockSampler{
+		id:         2,
+		callOrder:  &callOrder,
+		returnVals: expectedOutput,
+	}
+	mock3 := &mockSampler{
+		id:         3,
+		callOrder:  &callOrder,
+		returnVals: expectedOutput,
+	}
+
+	result, err := Sample(input, mock1, mock2, mock3)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !slices.Equal(callOrder, []int{1, 2, 3}) {
+		t.Errorf("Expected call order [1,2,3], got %v", callOrder)
+	}
+
+	if !floats.Equal(result, expectedOutput) {
+		t.Errorf("Expected output %v, got %v", expectedOutput, result)
+	}
+
+	errMock := &mockSampler{
+		returnErr: fmt.Errorf("mock error"),
+	}
+	_, err = Sample(input, mock1, errMock, mock2)
+	if err == nil {
+		t.Error("Expected error from sampler")
+	}
+}
+
+type mockSampler struct {
+	id         int
+	callOrder  *[]int
+	returnVals []float64
+	returnErr  error
+}
+
+func (m *mockSampler) Sample(logits []float64) ([]float64, error) {
+	if m.callOrder != nil {
+		*m.callOrder = append(*m.callOrder, m.id)
+	}
+	if m.returnErr != nil {
+		return nil, m.returnErr
+	}
+	return m.returnVals, nil
+}