diff --git a/sample/sample.go b/sample/sample.go index cc816c7da..41a8bd696 100644 --- a/sample/sample.go +++ b/sample/sample.go @@ -1,9 +1,10 @@ package sample import ( + "cmp" "errors" "math" - "sort" + "slices" "gonum.org/v1/gonum/floats" "gonum.org/v1/gonum/stat/sampleuv" @@ -20,13 +21,12 @@ func (s Temperature) Sample(logits []float64) ([]float64, error) { return nil, errors.New("temperature must be between 0 and 1") } - copiedLogits := append([]float64(nil), logits...) - // Greedy sampling + // greedy sampling if s == 0 { - return []float64{floats.Max(copiedLogits)}, nil + return []float64{floats.Max(logits)}, nil } - floats.Scale(1.0/float64(s), copiedLogits) - return copiedLogits, nil + floats.Scale(1.0/float64(s), logits) + return logits, nil } type softmax struct{} @@ -69,8 +69,9 @@ func (k TopK) Sample(logits []float64) ([]float64, error) { indices[i] = i } - sort.Slice(indices, func(i, j int) bool { - return logits[indices[i]] > logits[indices[j]] + // sort in descending order + slices.SortFunc(indices, func(i, j int) int { + return cmp.Compare(logits[j], logits[i]) }) for _, idx := range indices[k:] { @@ -96,8 +97,10 @@ func (p TopP) Sample(logits []float64) ([]float64, error) { for i := range indices { indices[i] = i } - sort.Slice(indices, func(i, j int) bool { - return probs[indices[i]] > probs[indices[j]] + + // sort in descending order + slices.SortFunc(indices, func(i, j int) int { + return cmp.Compare(probs[j], probs[i]) }) cumSum := 0.0 @@ -127,9 +130,9 @@ func (p MinP) Sample(logits []float64) ([]float64, error) { copiedProbs := make([]float64, len(probs)) copy(copiedProbs, probs) - sort.Slice(copiedProbs, func(i, j int) bool { return copiedProbs[i] > copiedProbs[j] }) + slices.Sort(copiedProbs) - maxProb := floats.Max(probs) + maxProb := copiedProbs[len(copiedProbs)-1] probThreshold := float64(p) * maxProb for i := range probs { @@ -162,20 +165,23 @@ func (s weighed) Sample(logits []float64) ([]float64, error) { return nil, errors.New("no valid tokens found") } + // usually, a softmax is applied to sample from the logits + // in this case the uv sampler normalizes the logits so that the sum of the weights is 1 w := sampleuv.NewWeighted(logitsCopy, nil) if v, ok := w.Take(); ok { + // returns the token ID return []float64{float64(indices[v])}, nil } return nil, errors.New("weighed sampler failed") } -func Sample(tokenID []float64, samplers ...Sampler) ([]float64, error) { +func Sample(logits []float64, samplers ...Sampler) ([]float64, error) { var err error for _, sampler := range samplers { - tokenID, err = sampler.Sample(tokenID) + logits, err = sampler.Sample(logits) if err != nil { return nil, err } } - return tokenID, nil + return logits, nil } diff --git a/sample/sample_test.go b/sample/sample_test.go index 314e5dd6d..3536b2934 100644 --- a/sample/sample_test.go +++ b/sample/sample_test.go @@ -3,9 +3,14 @@ package sample import ( "fmt" "math" + "math/rand" + "os" + "runtime" "slices" "testing" + "runtime/trace" + "gonum.org/v1/gonum/floats" ) @@ -14,9 +19,9 @@ func TestTemperature(t *testing.T) { if err != nil { t.Fatal(err) } - expectedlogits := []float64{-6, -4, -2, 0, 2, 4, 8} - if !floats.Equal(logits, expectedlogits) { - t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits) + want := []float64{-6, -4, -2, 0, 2, 4, 8} + if !floats.Equal(logits, want) { + t.Fatalf("got: %v, want: %v", logits, want) } // Only expect the max value returned @@ -24,9 +29,9 @@ func TestTemperature(t *testing.T) { if err != nil { t.Fatal(err) } - expectedlogits = []float64{4} - if !floats.Equal(logits, expectedlogits) { - t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits) + want = []float64{4} + if !floats.Equal(logits, want) { + t.Fatalf("got: %v, want: %v", logits, want) } if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil { @@ -35,7 +40,7 @@ func TestTemperature(t *testing.T) { } func TestSoftmax(t *testing.T) { - probs, err := computeSoftmax([]float64{-3, -2, -1, 0, 1, 2, 4}) + probs, err := Softmax().Sample([]float64{-3, -2, -1, 0, 1, 2, 4}) if err != nil { t.Fatal(err) } @@ -66,9 +71,9 @@ func TestTopP(t *testing.T) { if err != nil { t.Fatal(err) } - expectedLogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 2, 4} - if !floats.Same(logits, expectedLogits) { - t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits) + want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 2, 4} + if !floats.Same(logits, want) { + t.Fatalf("got: %v, want: %v", logits, want) } logits, err = TopP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}) if err == nil { @@ -85,9 +90,9 @@ func TestMinP(t *testing.T) { if err != nil { t.Fatal(err) } - expectedLogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 3, 4} - if !floats.Same(logits, expectedLogits) { - t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits) + want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 3, 4} + if !floats.Same(logits, want) { + t.Fatalf("got: %v, want: %v", logits, want) } logits, err = MinP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) if err == nil { @@ -104,9 +109,9 @@ func TestWeighed(t *testing.T) { if err != nil { t.Fatal(err) } - expectedLogits := []float64{1} - if !floats.Equal(logits, expectedLogits) { - t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits) + want := []float64{1} + if !floats.Equal(logits, want) { + t.Fatalf("got: %v, want: %v", logits, want) } logits, err = Weighed().Sample([]float64{math.NaN(), math.NaN(), math.NaN()}) if err == nil { @@ -116,36 +121,36 @@ func TestWeighed(t *testing.T) { func TestSample(t *testing.T) { input := []float64{1, 2, 3, 4} - expectedOutput := []float64{1, 2, 3, 4} + want := []float64{1, 2, 3, 4} var callOrder []int mock1 := &mockSampler{ id: 1, callOrder: &callOrder, - returnVals: expectedOutput, + returnVals: want, } mock2 := &mockSampler{ id: 2, callOrder: &callOrder, - returnVals: expectedOutput, + returnVals: want, } mock3 := &mockSampler{ id: 3, callOrder: &callOrder, - returnVals: expectedOutput, + returnVals: want, } - result, err := Sample(input, mock1, mock2, mock3) + got, err := Sample(input, mock1, mock2, mock3) if err != nil { t.Fatal(err) } if !slices.Equal(callOrder, []int{1, 2, 3}) { - t.Errorf("Expected call order [1,2,3], got %v", callOrder) + t.Errorf("got %v, want %v", callOrder, []int{1, 2, 3}) } - if !floats.Equal(result, expectedOutput) { - t.Errorf("Expected output %v, got %v", expectedOutput, result) + if !floats.Equal(got, want) { + t.Errorf("got %v, want %v", got, want) } errMock := &mockSampler{