Improved sampler interface

This commit is contained in:
ParthSareen 2025-02-03 16:04:32 -08:00
parent d5f8670f0a
commit cc2e44b885
3 changed files with 120 additions and 97 deletions

View File

@ -8,6 +8,6 @@ func Greedy() Sampler {
return greedy{} return greedy{}
} }
func (s greedy) Sample(t []float64) ([]float64, error) { func (s greedy) Sample(t []float64) (int, error) {
return []float64{float64(floats.MaxIdx(t))}, nil return floats.MaxIdx(t), nil
} }

View File

@ -10,13 +10,30 @@ import (
"gonum.org/v1/gonum/stat/sampleuv" "gonum.org/v1/gonum/stat/sampleuv"
) )
type Transform interface {
Apply([]float64) ([]float64, error)
}
type Sampler interface { type Sampler interface {
Sample([]float64) ([]float64, error) Sample([]float64) (int, error)
}
type SamplerConfig struct {
transforms []Transform
sampler Sampler
}
// NewSampler creates a sampler with the given transforms and sampling method
func NewSampler(transforms []Transform, sampler Sampler) *SamplerConfig {
return &SamplerConfig{
transforms: transforms,
sampler: sampler,
}
} }
type Temperature float64 type Temperature float64
func (t Temperature) Sample(logits []float64) ([]float64, error) { func (t Temperature) Apply(logits []float64) ([]float64, error) {
if t < 0 || t > 2 { if t < 0 || t > 2 {
return nil, errors.New("temperature must be between 0 and 2") return nil, errors.New("temperature must be between 0 and 2")
} }
@ -34,15 +51,16 @@ func (t Temperature) Sample(logits []float64) ([]float64, error) {
type softmax struct{} type softmax struct{}
func Softmax() Sampler { func Softmax() Transform {
return softmax{} return softmax{}
} }
func (softmax) Sample(logits []float64) ([]float64, error) { func (softmax) Apply(logits []float64) ([]float64, error) {
return computeSoftmax(logits) return computeSoftmax(logits), nil
} }
func computeSoftmax(logits []float64) ([]float64, error) { // TODO: cache softmax values
func computeSoftmax(logits []float64) []float64 {
copiedLogits := make([]float64, len(logits)) copiedLogits := make([]float64, len(logits))
copy(copiedLogits, logits) copy(copiedLogits, logits)
for i := range copiedLogits { for i := range copiedLogits {
@ -52,12 +70,12 @@ func computeSoftmax(logits []float64) ([]float64, error) {
floatSum := floats.Sum(copiedLogits) floatSum := floats.Sum(copiedLogits)
floats.Scale(1.0/floatSum, copiedLogits) floats.Scale(1.0/floatSum, copiedLogits)
return copiedLogits, nil return copiedLogits
} }
type TopK int type TopK int
func (k TopK) Sample(logits []float64) ([]float64, error) { func (k TopK) Apply(logits []float64) ([]float64, error) {
if k <= 0 { if k <= 0 {
return nil, errors.New("k must be positive") return nil, errors.New("k must be positive")
} }
@ -76,23 +94,20 @@ func (k TopK) Sample(logits []float64) ([]float64, error) {
}) })
for _, idx := range indices[k:] { for _, idx := range indices[k:] {
logits[idx] = math.NaN() logits[idx] = math.Inf(-1)
} }
return logits, nil return logits, nil
} }
type TopP float32 type TopP float64
func (p TopP) Sample(logits []float64) ([]float64, error) { func (p TopP) Apply(logits []float64) ([]float64, error) {
if p <= 0 || p >= 1 { if p <= 0 || p >= 1 {
return nil, errors.New("p must be between 0 and 1") return nil, errors.New("p must be between 0 and 1")
} }
probs, err := computeSoftmax(logits) probs := computeSoftmax(logits)
if err != nil {
return nil, err
}
indices := make([]int, len(probs)) indices := make([]int, len(probs))
for i := range indices { for i := range indices {
@ -104,12 +119,12 @@ func (p TopP) Sample(logits []float64) ([]float64, error) {
return cmp.Compare(probs[j], probs[i]) return cmp.Compare(probs[j], probs[i])
}) })
cumSum := 0.0 var cumSum float64
for i, idx := range indices { for i, idx := range indices {
cumSum += probs[idx] cumSum += probs[idx]
if cumSum > float64(p) { if cumSum > float64(p) {
for _, idx := range indices[i+1:] { for _, idx := range indices[i+1:] {
logits[idx] = math.NaN() logits[idx] = math.Inf(-1)
} }
break break
} }
@ -117,17 +132,14 @@ func (p TopP) Sample(logits []float64) ([]float64, error) {
return logits, nil return logits, nil
} }
type MinP float32 type MinP float64
func (p MinP) Sample(logits []float64) ([]float64, error) { func (p MinP) Apply(logits []float64) ([]float64, error) {
if p <= 0 || p >= 1 { if p <= 0 || p >= 1 {
return nil, errors.New("p must be between 0 and 1") return nil, errors.New("p must be between 0 and 1")
} }
probs, err := computeSoftmax(logits) probs := computeSoftmax(logits)
if err != nil {
return nil, err
}
copiedProbs := make([]float64, len(probs)) copiedProbs := make([]float64, len(probs))
copy(copiedProbs, probs) copy(copiedProbs, probs)
@ -138,7 +150,7 @@ func (p MinP) Sample(logits []float64) ([]float64, error) {
for i := range probs { for i := range probs {
if probs[i] < probThreshold { if probs[i] < probThreshold {
logits[i] = math.NaN() logits[i] = math.Inf(-1)
} }
} }
@ -151,48 +163,51 @@ func Weighed() Sampler {
return weighed{} return weighed{}
} }
func (s weighed) Sample(logits []float64) ([]float64, error) { // should return single value
func (s weighed) Sample(logits []float64) (int, error) {
logitsCopy := make([]float64, 0, len(logits)) logitsCopy := make([]float64, 0, len(logits))
indices := make([]int, 0, len(logits)) indices := make([]int, 0, len(logits))
// the uv sampler does not support NaN values // the uv sampler does not support NaN values
for i, logit := range logits { for i, logit := range logits {
if !math.IsNaN(logit) { if !math.IsInf(logit, -1) {
logitsCopy = append(logitsCopy, logit) logitsCopy = append(logitsCopy, logit)
indices = append(indices, i) indices = append(indices, i)
} }
} }
if len(logitsCopy) == 0 { if len(logitsCopy) == 0 {
return nil, errors.New("no valid tokens found") return -1, errors.New("no valid tokens found")
} }
softmax, err := computeSoftmax(logitsCopy) softmax := computeSoftmax(logitsCopy)
if err != nil {
return nil, err
}
w := sampleuv.NewWeighted(softmax, nil) w := sampleuv.NewWeighted(softmax, nil)
if v, ok := w.Take(); ok { if idx, ok := w.Take(); ok {
// returns the token ID // returns the token ID
return []float64{float64(indices[v])}, nil return indices[idx], nil
} }
return nil, errors.New("weighed sampler failed") return -1, errors.New("weighed sampler failed")
}
// Sample applies transforms and samples a token ID
func (s *SamplerConfig) Sample(input []float32) (int, error) {
logits := make([]float64, len(input))
for i, v := range input {
logits[i] = float64(v)
} }
func Sample(logits []float64, samplers ...Sampler) ([]float64, error) {
var err error var err error
for _, sampler := range samplers { for _, t := range s.transforms {
if sampler == Temperature(0) { if t == Temperature(0) {
// early return with greedy if temperature is 0 // early return with greedy if temperature is 0
logits, err = Greedy().Sample(logits) s.sampler = Greedy()
break
}
logits, err = t.Apply(logits)
if err != nil { if err != nil {
return nil, err return -1, err
}
return logits, nil
}
logits, err = sampler.Sample(logits)
if err != nil {
return nil, err
} }
} }
return logits, nil
return s.sampler.Sample(logits)
} }

View File

@ -10,7 +10,7 @@ import (
) )
func TestTemperature(t *testing.T) { func TestTemperature(t *testing.T) {
logits, err := Temperature(0.5).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}) logits, err := Temperature(0.5).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -19,16 +19,16 @@ func TestTemperature(t *testing.T) {
t.Fatalf("got: %v, want: %v", logits, want) t.Fatalf("got: %v, want: %v", logits, want)
} }
if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil { if _, err := Temperature(-1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
t.Fatalf("expected error for temperature=-1, got %v", logits) t.Fatalf("expected error for temperature=-1, got %v", logits)
} }
if _, err := Temperature(2.1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil { if _, err := Temperature(2.1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
t.Fatalf("expected error for temperature=2.1, got %v", logits) t.Fatalf("expected error for temperature=2.1, got %v", logits)
} }
} }
func TestSoftmax(t *testing.T) { func TestSoftmax(t *testing.T) {
probs, err := Softmax().Sample([]float64{-3, -2, -1, 0, 1, 2, 4}) probs, err := Softmax().Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -40,95 +40,101 @@ func TestSoftmax(t *testing.T) {
} }
func TestTopK(t *testing.T) { func TestTopK(t *testing.T) {
logits, err := TopK(3).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}) logits, err := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
expectedlogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), 1, 2, 4} expectedlogits := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
if !floats.Same(logits, expectedlogits) { if !floats.Same(logits, expectedlogits) {
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits) t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
} }
logits, err = TopK(0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}) logits, err = TopK(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err == nil { if err == nil {
t.Fatalf("expected error for k=0, got %v", logits) t.Fatalf("expected error for k=0, got %v", logits)
} }
logits, err = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil {
t.Fatal(err)
}
expectedlogits = []float64{-3, -2, -1, 0, 1, 2, 4}
if !floats.Same(logits, expectedlogits) {
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
}
} }
func TestTopP(t *testing.T) { func TestTopP(t *testing.T) {
logits, err := TopP(0.9).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}) logits, err := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 2, 4} want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
if !floats.Same(logits, want) { if !floats.Same(logits, want) {
t.Fatalf("got: %v, want: %v", logits, want) t.Fatalf("got: %v, want: %v", logits, want)
} }
logits, err = TopP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}) logits, err = TopP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err == nil { if err == nil {
t.Fatalf("expected error for p=1.0, got %v", logits) t.Fatalf("expected error for p=1.0, got %v", logits)
} }
logits, err = TopP(0.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}) logits, err = TopP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err == nil { if err == nil {
t.Fatalf("expected error for p=0.0, got %v", logits) t.Fatalf("expected error for p=0.0, got %v", logits)
} }
} }
func TestMinP(t *testing.T) { func TestMinP(t *testing.T) {
logits, err := MinP(0.2).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) logits, err := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 3, 4} want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 3, 4}
if !floats.Same(logits, want) { if !floats.Same(logits, want) {
t.Fatalf("got: %v, want: %v", logits, want) t.Fatalf("got: %v, want: %v", logits, want)
} }
logits, err = MinP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) logits, err = MinP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
if err == nil { if err == nil {
t.Fatalf("expected error for p=1.0, got %v", logits) t.Fatalf("expected error for p=1.0, got %v", logits)
} }
logits, err = MinP(0.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) logits, err = MinP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
if err == nil { if err == nil {
t.Fatalf("expected error for p=0.0, got %v", logits) t.Fatalf("expected error for p=0.0, got %v", logits)
} }
} }
func TestWeighed(t *testing.T) { func TestWeighed(t *testing.T) {
logits, err := Weighed().Sample([]float64{math.NaN(), 2, math.NaN(), math.NaN()}) idx, err := Weighed().Sample([]float64{math.Inf(-1), 2, math.Inf(-1), math.Inf(-1)})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
want := []float64{1} want := 1
if !floats.Equal(logits, want) { if idx != want {
t.Fatalf("got: %v, want: %v", logits, want) t.Fatalf("got: %v, want: %v", idx, want)
} }
logits, err = Weighed().Sample([]float64{math.NaN(), math.NaN(), math.NaN()}) idx, err = Weighed().Sample([]float64{math.Inf(-1), math.Inf(-1), math.Inf(-1)})
if err == nil { if err == nil {
t.Fatalf("expected error for no valid tokens, got %v", logits) t.Fatalf("expected error for no valid tokens, got %v", idx)
} }
} }
func TestSample(t *testing.T) { func TestSample(t *testing.T) {
input := []float64{1, 2, 3, 4} input := []float32{1, 2, 3, 4}
want := []float64{1, 2, 3, 4}
var callOrder []int var callOrder []int
mock1 := &testSampler{ mock1 := &testTransform{
id: 1, id: 1,
callOrder: &callOrder, callOrder: &callOrder,
returnVals: want,
} }
mock2 := &testSampler{ mock2 := &testTransform{
id: 2, id: 2,
callOrder: &callOrder, callOrder: &callOrder,
returnVals: want,
} }
mock3 := &testSampler{ mock3 := &testTransform{
id: 3, id: 3,
callOrder: &callOrder, callOrder: &callOrder,
returnVals: want,
} }
sampler := NewSampler([]Transform{mock1, mock2, mock3}, Greedy())
got, err := Sample(input, mock1, mock2, mock3) got, err := sampler.Sample(input)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -137,43 +143,45 @@ func TestSample(t *testing.T) {
t.Errorf("got %v, want %v", callOrder, []int{1, 2, 3}) t.Errorf("got %v, want %v", callOrder, []int{1, 2, 3})
} }
if !floats.Equal(got, want) { want := 3 // Greedy sampler should pick highest logit
if got != want {
t.Errorf("got %v, want %v", got, want) t.Errorf("got %v, want %v", got, want)
} }
errMock := &testSampler{ errMock := &testTransform{
returnErr: fmt.Errorf("mock error"), returnErr: fmt.Errorf("mock error"),
} }
_, err = Sample(input, mock1, errMock, mock2) sampler = NewSampler([]Transform{mock1, errMock, mock2}, Greedy())
_, err = sampler.Sample(input)
if err == nil { if err == nil {
t.Error("Expected error from sampler") t.Error("Expected error from sampler")
} }
} }
type testSampler struct { type testTransform struct {
id int id int
callOrder *[]int callOrder *[]int
returnVals []float64
returnErr error returnErr error
} }
func (ts *testSampler) Sample(logits []float64) ([]float64, error) { func (ts *testTransform) Apply(logits []float64) ([]float64, error) {
if ts.callOrder != nil { if ts.callOrder != nil {
*ts.callOrder = append(*ts.callOrder, ts.id) *ts.callOrder = append(*ts.callOrder, ts.id)
} }
if ts.returnErr != nil { if ts.returnErr != nil {
return nil, ts.returnErr return nil, ts.returnErr
} }
return ts.returnVals, nil return logits, nil
} }
func TestSampleTemperatureZero(t *testing.T) { func TestSampleTemperatureZero(t *testing.T) {
logits, err := Sample([]float64{1, 2, 3, 4}, Temperature(0)) sampler := NewSampler([]Transform{Temperature(0)}, Greedy())
got, err := sampler.Sample([]float32{1, 2, 3, 4})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
want := []float64{3} want := 3 // Greedy sampler should pick highest logit index
if !floats.Equal(logits, want) { if got != want {
t.Fatalf("got: %v, want: %v", logits, want) t.Fatalf("got: %v, want: %v", got, want)
} }
} }