sampling package

This commit is contained in:
ParthSareen 2025-01-10 17:49:39 -08:00
parent 4aac178cac
commit 5e73f24e16
2 changed files with 303 additions and 21 deletions

View File

@ -1,7 +1,9 @@
package sample
import (
"slices"
"errors"
"math"
"sort"
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/stat/sampleuv"
@ -13,9 +15,18 @@ type Sampler interface {
type Temperature float64
func (s Temperature) Sample(t []float64) ([]float64, error) {
floats.Div(t, slices.Repeat([]float64{float64(s)}, len(t)))
return t, nil
func (s Temperature) Sample(logits []float64) ([]float64, error) {
if s < 0 || s > 1 {
return nil, errors.New("temperature must be between 0 and 1")
}
copiedLogits := append([]float64(nil), logits...)
// Greedy sampling
if s == 0 {
return []float64{floats.Max(copiedLogits)}, nil
}
floats.Scale(1.0/float64(s), copiedLogits)
return copiedLogits, nil
}
type softmax struct{}
@ -24,26 +35,110 @@ func Softmax() Sampler {
return softmax{}
}
func (softmax) Sample(t []float64) ([]float64, error) {
return t, nil
func (softmax) Sample(logits []float64) ([]float64, error) {
return computeSoftmax(logits)
}
func computeSoftmax(logits []float64) ([]float64, error) {
copiedLogits := make([]float64, len(logits))
copy(copiedLogits, logits)
for i := range copiedLogits {
copiedLogits[i] = math.Exp(copiedLogits[i])
}
floatSum := floats.Sum(copiedLogits)
if floatSum == 0 {
return nil, errors.New("no valid tokens found")
}
floats.Scale(1.0/floatSum, copiedLogits)
return copiedLogits, nil
}
type TopK int
func (s TopK) Sample(t []float64) ([]float64, error) {
return t, nil
func (k TopK) Sample(logits []float64) ([]float64, error) {
if k <= 0 {
return nil, errors.New("k must be positive")
}
if int(k) >= len(logits) {
return logits, nil
}
indices := make([]int, len(logits))
for i := range indices {
indices[i] = i
}
sort.Slice(indices, func(i, j int) bool {
return logits[indices[i]] > logits[indices[j]]
})
for _, idx := range indices[k:] {
logits[idx] = math.NaN()
}
return logits, nil
}
type TopP float32
func (s TopP) Sample(t []float64) ([]float64, error) {
return t, nil
func (p TopP) Sample(logits []float64) ([]float64, error) {
if p <= 0 || p >= 1 {
return nil, errors.New("p must be between 0 and 1")
}
probs, err := computeSoftmax(logits)
if err != nil {
return nil, err
}
indices := make([]int, len(probs))
for i := range indices {
indices[i] = i
}
sort.Slice(indices, func(i, j int) bool {
return probs[indices[i]] > probs[indices[j]]
})
cumSum := 0.0
for i, idx := range indices {
cumSum += probs[idx]
if cumSum > float64(p) {
for _, idx := range indices[i+1:] {
logits[idx] = math.NaN()
}
break
}
}
return logits, nil
}
type MinP float32
func (s MinP) Sample(t []float64) ([]float64, error) {
return t, nil
func (p MinP) Sample(logits []float64) ([]float64, error) {
if p <= 0 || p >= 1 {
return nil, errors.New("p must be between 0 and 1")
}
probs, err := computeSoftmax(logits)
if err != nil {
return nil, err
}
copiedProbs := make([]float64, len(probs))
copy(copiedProbs, probs)
sort.Slice(copiedProbs, func(i, j int) bool { return copiedProbs[i] > copiedProbs[j] })
maxProb := floats.Max(probs)
probThreshold := float64(p) * maxProb
for i := range probs {
if probs[i] < probThreshold {
logits[i] = math.NaN()
}
}
return logits, nil
}
type weighed struct{}
@ -52,23 +147,35 @@ func Weighed() Sampler {
return weighed{}
}
func (s weighed) Sample(t []float64) ([]float64, error) {
w := sampleuv.NewWeighted(t, nil)
if v, ok := w.Take(); ok {
return []float64{float64(v)}, nil
func (s weighed) Sample(logits []float64) ([]float64, error) {
logitsCopy := make([]float64, 0, len(logits))
indices := make([]int, 0, len(logits))
// the uv sampler does not support NaN values
for i, logit := range logits {
if !math.IsNaN(logit) {
logitsCopy = append(logitsCopy, logit)
indices = append(indices, i)
}
}
return t, nil
if len(logitsCopy) == 0 {
return nil, errors.New("no valid tokens found")
}
w := sampleuv.NewWeighted(logitsCopy, nil)
if v, ok := w.Take(); ok {
return []float64{float64(indices[v])}, nil
}
return nil, errors.New("weighed sampler failed")
}
func Sample(floats []float64, samplers ...Sampler) ([]float64, error) {
func Sample(tokenID []float64, samplers ...Sampler) ([]float64, error) {
var err error
for _, sampler := range samplers {
floats, err = sampler.Sample(floats)
tokenID, err = sampler.Sample(tokenID)
if err != nil {
return nil, err
}
}
return floats, nil
return tokenID, nil
}

175
sample/sample_test.go Normal file
View File

@ -0,0 +1,175 @@
package sample
import (
"fmt"
"math"
"slices"
"testing"
"gonum.org/v1/gonum/floats"
)
func TestTemperature(t *testing.T) {
logits, err := Temperature(0.5).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil {
t.Fatal(err)
}
expectedlogits := []float64{-6, -4, -2, 0, 2, 4, 8}
if !floats.Equal(logits, expectedlogits) {
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
}
// Only expect the max value returned
logits, err = Temperature(0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil {
t.Fatal(err)
}
expectedlogits = []float64{4}
if !floats.Equal(logits, expectedlogits) {
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
}
if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
t.Fatalf("expected error for temperature=-1, got %v", logits)
}
}
func TestSoftmax(t *testing.T) {
probs, err := computeSoftmax([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil {
t.Fatal(err)
}
expectedProbs := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
if !floats.Equal(probs, expectedProbs) {
t.Fatalf("logits: %v, expectedlogits: %v", probs, expectedProbs)
}
}
func TestTopK(t *testing.T) {
logits, err := TopK(3).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil {
t.Fatal(err)
}
expectedlogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), 1, 2, 4}
if !floats.Same(logits, expectedlogits) {
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
}
logits, err = TopK(0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
if err == nil {
t.Fatalf("expected error for k=0, got %v", logits)
}
}
func TestTopP(t *testing.T) {
logits, err := TopP(0.9).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil {
t.Fatal(err)
}
expectedLogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 2, 4}
if !floats.Same(logits, expectedLogits) {
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
}
logits, err = TopP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
if err == nil {
t.Fatalf("expected error for p=1.0, got %v", logits)
}
logits, err = TopP(0.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
if err == nil {
t.Fatalf("expected error for p=0.0, got %v", logits)
}
}
func TestMinP(t *testing.T) {
logits, err := MinP(0.2).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
if err != nil {
t.Fatal(err)
}
expectedLogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 3, 4}
if !floats.Same(logits, expectedLogits) {
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
}
logits, err = MinP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
if err == nil {
t.Fatalf("expected error for p=1.0, got %v", logits)
}
logits, err = MinP(0.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
if err == nil {
t.Fatalf("expected error for p=0.0, got %v", logits)
}
}
func TestWeighed(t *testing.T) {
logits, err := Weighed().Sample([]float64{math.NaN(), 2, math.NaN(), math.NaN()})
if err != nil {
t.Fatal(err)
}
expectedLogits := []float64{1}
if !floats.Equal(logits, expectedLogits) {
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
}
logits, err = Weighed().Sample([]float64{math.NaN(), math.NaN(), math.NaN()})
if err == nil {
t.Fatalf("expected error for no valid tokens, got %v", logits)
}
}
func TestSample(t *testing.T) {
input := []float64{1, 2, 3, 4}
expectedOutput := []float64{1, 2, 3, 4}
var callOrder []int
mock1 := &mockSampler{
id: 1,
callOrder: &callOrder,
returnVals: expectedOutput,
}
mock2 := &mockSampler{
id: 2,
callOrder: &callOrder,
returnVals: expectedOutput,
}
mock3 := &mockSampler{
id: 3,
callOrder: &callOrder,
returnVals: expectedOutput,
}
result, err := Sample(input, mock1, mock2, mock3)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(callOrder, []int{1, 2, 3}) {
t.Errorf("Expected call order [1,2,3], got %v", callOrder)
}
if !floats.Equal(result, expectedOutput) {
t.Errorf("Expected output %v, got %v", expectedOutput, result)
}
errMock := &mockSampler{
returnErr: fmt.Errorf("mock error"),
}
_, err = Sample(input, mock1, errMock, mock2)
if err == nil {
t.Error("Expected error from sampler")
}
}
type mockSampler struct {
id int
callOrder *[]int
returnVals []float64
returnErr error
}
func (m *mockSampler) Sample(logits []float64) ([]float64, error) {
if m.callOrder != nil {
*m.callOrder = append(*m.callOrder, m.id)
}
if m.returnErr != nil {
return nil, m.returnErr
}
return m.returnVals, nil
}