addressing comments + cleanup
This commit is contained in:
parent
5e73f24e16
commit
5b19d4941a
@ -1,9 +1,10 @@
|
|||||||
package sample
|
package sample
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"errors"
|
"errors"
|
||||||
"math"
|
"math"
|
||||||
"sort"
|
"slices"
|
||||||
|
|
||||||
"gonum.org/v1/gonum/floats"
|
"gonum.org/v1/gonum/floats"
|
||||||
"gonum.org/v1/gonum/stat/sampleuv"
|
"gonum.org/v1/gonum/stat/sampleuv"
|
||||||
@ -20,13 +21,12 @@ func (s Temperature) Sample(logits []float64) ([]float64, error) {
|
|||||||
return nil, errors.New("temperature must be between 0 and 1")
|
return nil, errors.New("temperature must be between 0 and 1")
|
||||||
}
|
}
|
||||||
|
|
||||||
copiedLogits := append([]float64(nil), logits...)
|
// greedy sampling
|
||||||
// Greedy sampling
|
|
||||||
if s == 0 {
|
if s == 0 {
|
||||||
return []float64{floats.Max(copiedLogits)}, nil
|
return []float64{floats.Max(logits)}, nil
|
||||||
}
|
}
|
||||||
floats.Scale(1.0/float64(s), copiedLogits)
|
floats.Scale(1.0/float64(s), logits)
|
||||||
return copiedLogits, nil
|
return logits, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type softmax struct{}
|
type softmax struct{}
|
||||||
@ -69,8 +69,9 @@ func (k TopK) Sample(logits []float64) ([]float64, error) {
|
|||||||
indices[i] = i
|
indices[i] = i
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Slice(indices, func(i, j int) bool {
|
// sort in descending order
|
||||||
return logits[indices[i]] > logits[indices[j]]
|
slices.SortFunc(indices, func(i, j int) int {
|
||||||
|
return cmp.Compare(logits[j], logits[i])
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, idx := range indices[k:] {
|
for _, idx := range indices[k:] {
|
||||||
@ -96,8 +97,10 @@ func (p TopP) Sample(logits []float64) ([]float64, error) {
|
|||||||
for i := range indices {
|
for i := range indices {
|
||||||
indices[i] = i
|
indices[i] = i
|
||||||
}
|
}
|
||||||
sort.Slice(indices, func(i, j int) bool {
|
|
||||||
return probs[indices[i]] > probs[indices[j]]
|
// sort in descending order
|
||||||
|
slices.SortFunc(indices, func(i, j int) int {
|
||||||
|
return cmp.Compare(probs[j], probs[i])
|
||||||
})
|
})
|
||||||
|
|
||||||
cumSum := 0.0
|
cumSum := 0.0
|
||||||
@ -127,9 +130,9 @@ func (p MinP) Sample(logits []float64) ([]float64, error) {
|
|||||||
copiedProbs := make([]float64, len(probs))
|
copiedProbs := make([]float64, len(probs))
|
||||||
copy(copiedProbs, probs)
|
copy(copiedProbs, probs)
|
||||||
|
|
||||||
sort.Slice(copiedProbs, func(i, j int) bool { return copiedProbs[i] > copiedProbs[j] })
|
slices.Sort(copiedProbs)
|
||||||
|
|
||||||
maxProb := floats.Max(probs)
|
maxProb := copiedProbs[len(copiedProbs)-1]
|
||||||
probThreshold := float64(p) * maxProb
|
probThreshold := float64(p) * maxProb
|
||||||
|
|
||||||
for i := range probs {
|
for i := range probs {
|
||||||
@ -162,20 +165,23 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
|
|||||||
return nil, errors.New("no valid tokens found")
|
return nil, errors.New("no valid tokens found")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// usually, a softmax is applied to sample from the logits
|
||||||
|
// in this case the uv sampler normalizes the logits so that the sum of the weights is 1
|
||||||
w := sampleuv.NewWeighted(logitsCopy, nil)
|
w := sampleuv.NewWeighted(logitsCopy, nil)
|
||||||
if v, ok := w.Take(); ok {
|
if v, ok := w.Take(); ok {
|
||||||
|
// returns the token ID
|
||||||
return []float64{float64(indices[v])}, nil
|
return []float64{float64(indices[v])}, nil
|
||||||
}
|
}
|
||||||
return nil, errors.New("weighed sampler failed")
|
return nil, errors.New("weighed sampler failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
func Sample(tokenID []float64, samplers ...Sampler) ([]float64, error) {
|
func Sample(logits []float64, samplers ...Sampler) ([]float64, error) {
|
||||||
var err error
|
var err error
|
||||||
for _, sampler := range samplers {
|
for _, sampler := range samplers {
|
||||||
tokenID, err = sampler.Sample(tokenID)
|
logits, err = sampler.Sample(logits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return tokenID, nil
|
return logits, nil
|
||||||
}
|
}
|
||||||
|
@ -3,9 +3,14 @@ package sample
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
"slices"
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"runtime/trace"
|
||||||
|
|
||||||
"gonum.org/v1/gonum/floats"
|
"gonum.org/v1/gonum/floats"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -14,9 +19,9 @@ func TestTemperature(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
expectedlogits := []float64{-6, -4, -2, 0, 2, 4, 8}
|
want := []float64{-6, -4, -2, 0, 2, 4, 8}
|
||||||
if !floats.Equal(logits, expectedlogits) {
|
if !floats.Equal(logits, want) {
|
||||||
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
|
t.Fatalf("got: %v, want: %v", logits, want)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only expect the max value returned
|
// Only expect the max value returned
|
||||||
@ -24,9 +29,9 @@ func TestTemperature(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
expectedlogits = []float64{4}
|
want = []float64{4}
|
||||||
if !floats.Equal(logits, expectedlogits) {
|
if !floats.Equal(logits, want) {
|
||||||
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
|
t.Fatalf("got: %v, want: %v", logits, want)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
|
if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
|
||||||
@ -35,7 +40,7 @@ func TestTemperature(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSoftmax(t *testing.T) {
|
func TestSoftmax(t *testing.T) {
|
||||||
probs, err := computeSoftmax([]float64{-3, -2, -1, 0, 1, 2, 4})
|
probs, err := Softmax().Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -66,9 +71,9 @@ func TestTopP(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
expectedLogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 2, 4}
|
want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 2, 4}
|
||||||
if !floats.Same(logits, expectedLogits) {
|
if !floats.Same(logits, want) {
|
||||||
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
|
t.Fatalf("got: %v, want: %v", logits, want)
|
||||||
}
|
}
|
||||||
logits, err = TopP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
|
logits, err = TopP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -85,9 +90,9 @@ func TestMinP(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
expectedLogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 3, 4}
|
want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 3, 4}
|
||||||
if !floats.Same(logits, expectedLogits) {
|
if !floats.Same(logits, want) {
|
||||||
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
|
t.Fatalf("got: %v, want: %v", logits, want)
|
||||||
}
|
}
|
||||||
logits, err = MinP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
|
logits, err = MinP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -104,9 +109,9 @@ func TestWeighed(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
expectedLogits := []float64{1}
|
want := []float64{1}
|
||||||
if !floats.Equal(logits, expectedLogits) {
|
if !floats.Equal(logits, want) {
|
||||||
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
|
t.Fatalf("got: %v, want: %v", logits, want)
|
||||||
}
|
}
|
||||||
logits, err = Weighed().Sample([]float64{math.NaN(), math.NaN(), math.NaN()})
|
logits, err = Weighed().Sample([]float64{math.NaN(), math.NaN(), math.NaN()})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -116,36 +121,36 @@ func TestWeighed(t *testing.T) {
|
|||||||
|
|
||||||
func TestSample(t *testing.T) {
|
func TestSample(t *testing.T) {
|
||||||
input := []float64{1, 2, 3, 4}
|
input := []float64{1, 2, 3, 4}
|
||||||
expectedOutput := []float64{1, 2, 3, 4}
|
want := []float64{1, 2, 3, 4}
|
||||||
|
|
||||||
var callOrder []int
|
var callOrder []int
|
||||||
mock1 := &mockSampler{
|
mock1 := &mockSampler{
|
||||||
id: 1,
|
id: 1,
|
||||||
callOrder: &callOrder,
|
callOrder: &callOrder,
|
||||||
returnVals: expectedOutput,
|
returnVals: want,
|
||||||
}
|
}
|
||||||
mock2 := &mockSampler{
|
mock2 := &mockSampler{
|
||||||
id: 2,
|
id: 2,
|
||||||
callOrder: &callOrder,
|
callOrder: &callOrder,
|
||||||
returnVals: expectedOutput,
|
returnVals: want,
|
||||||
}
|
}
|
||||||
mock3 := &mockSampler{
|
mock3 := &mockSampler{
|
||||||
id: 3,
|
id: 3,
|
||||||
callOrder: &callOrder,
|
callOrder: &callOrder,
|
||||||
returnVals: expectedOutput,
|
returnVals: want,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := Sample(input, mock1, mock2, mock3)
|
got, err := Sample(input, mock1, mock2, mock3)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !slices.Equal(callOrder, []int{1, 2, 3}) {
|
if !slices.Equal(callOrder, []int{1, 2, 3}) {
|
||||||
t.Errorf("Expected call order [1,2,3], got %v", callOrder)
|
t.Errorf("got %v, want %v", callOrder, []int{1, 2, 3})
|
||||||
}
|
}
|
||||||
|
|
||||||
if !floats.Equal(result, expectedOutput) {
|
if !floats.Equal(got, want) {
|
||||||
t.Errorf("Expected output %v, got %v", expectedOutput, result)
|
t.Errorf("got %v, want %v", got, want)
|
||||||
}
|
}
|
||||||
|
|
||||||
errMock := &mockSampler{
|
errMock := &mockSampler{
|
||||||
|
Loading…
x
Reference in New Issue
Block a user