improve temperature sampler

This commit is contained in:
ParthSareen 2025-01-20 13:42:23 -08:00
parent b91487f289
commit 7cd9fbbbb1
2 changed files with 56 additions and 32 deletions

View File

@ -16,16 +16,19 @@ type Sampler interface {
type Temperature float64 type Temperature float64
func (s Temperature) Sample(logits []float64) ([]float64, error) { func (t Temperature) Sample(logits []float64) ([]float64, error) {
if s < 0 || s > 1 { if t < 0 || t > 2 {
return nil, errors.New("temperature must be between 0 and 1") return nil, errors.New("temperature must be between 0 and 2")
} }
// greedy sampling // subtracting max logit to avoid under/overflow
if s == 0 { maxLogit := floats.Max(logits)
return []float64{floats.Max(logits)}, nil
temp := math.Max(float64(t), 1e-7)
for i := range logits {
logits[i] = (logits[i] - maxLogit) / temp
} }
floats.Scale(1.0/float64(s), logits)
return logits, nil return logits, nil
} }
@ -47,10 +50,8 @@ func computeSoftmax(logits []float64) ([]float64, error) {
} }
floatSum := floats.Sum(copiedLogits) floatSum := floats.Sum(copiedLogits)
if floatSum == 0 {
return nil, errors.New("no valid tokens found")
}
floats.Scale(1.0/floatSum, copiedLogits) floats.Scale(1.0/floatSum, copiedLogits)
return copiedLogits, nil return copiedLogits, nil
} }
@ -175,9 +176,28 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
return nil, errors.New("weighed sampler failed") return nil, errors.New("weighed sampler failed")
} }
// TODO: remove after next PR merge
type greedy struct{}
func Greedy() Sampler {
return greedy{}
}
func (greedy) Sample(logits []float64) ([]float64, error) {
return []float64{float64(floats.MaxIdx(logits))}, nil
}
func Sample(logits []float64, samplers ...Sampler) ([]float64, error) { func Sample(logits []float64, samplers ...Sampler) ([]float64, error) {
var err error var err error
for _, sampler := range samplers { for _, sampler := range samplers {
if sampler == Temperature(0) {
// early return with greedy if temperature is 0
logits, err = Greedy().Sample(logits)
if err != nil {
return nil, err
}
return logits, nil
}
logits, err = sampler.Sample(logits) logits, err = sampler.Sample(logits)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -19,17 +19,7 @@ func TestTemperature(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
want := []float64{-6, -4, -2, 0, 2, 4, 8} want := []float64{-14, -12, -10, -8, -6, -4, 0}
if !floats.Equal(logits, want) {
t.Fatalf("got: %v, want: %v", logits, want)
}
// Only expect the max value returned
logits, err = Temperature(0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil {
t.Fatal(err)
}
want = []float64{4}
if !floats.Equal(logits, want) { if !floats.Equal(logits, want) {
t.Fatalf("got: %v, want: %v", logits, want) t.Fatalf("got: %v, want: %v", logits, want)
} }
@ -37,6 +27,9 @@ func TestTemperature(t *testing.T) {
if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil { if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
t.Fatalf("expected error for temperature=-1, got %v", logits) t.Fatalf("expected error for temperature=-1, got %v", logits)
} }
if _, err := Temperature(2.1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
t.Fatalf("expected error for temperature=2.1, got %v", logits)
}
} }
func TestSoftmax(t *testing.T) { func TestSoftmax(t *testing.T) {
@ -124,17 +117,17 @@ func TestSample(t *testing.T) {
want := []float64{1, 2, 3, 4} want := []float64{1, 2, 3, 4}
var callOrder []int var callOrder []int
mock1 := &mockSampler{ mock1 := &testSampler{
id: 1, id: 1,
callOrder: &callOrder, callOrder: &callOrder,
returnVals: want, returnVals: want,
} }
mock2 := &mockSampler{ mock2 := &testSampler{
id: 2, id: 2,
callOrder: &callOrder, callOrder: &callOrder,
returnVals: want, returnVals: want,
} }
mock3 := &mockSampler{ mock3 := &testSampler{
id: 3, id: 3,
callOrder: &callOrder, callOrder: &callOrder,
returnVals: want, returnVals: want,
@ -153,7 +146,7 @@ func TestSample(t *testing.T) {
t.Errorf("got %v, want %v", got, want) t.Errorf("got %v, want %v", got, want)
} }
errMock := &mockSampler{ errMock := &testSampler{
returnErr: fmt.Errorf("mock error"), returnErr: fmt.Errorf("mock error"),
} }
_, err = Sample(input, mock1, errMock, mock2) _, err = Sample(input, mock1, errMock, mock2)
@ -162,19 +155,30 @@ func TestSample(t *testing.T) {
} }
} }
type mockSampler struct { type testSampler struct {
id int id int
callOrder *[]int callOrder *[]int
returnVals []float64 returnVals []float64
returnErr error returnErr error
} }
func (m *mockSampler) Sample(logits []float64) ([]float64, error) { func (ts *testSampler) Sample(logits []float64) ([]float64, error) {
if m.callOrder != nil { if ts.callOrder != nil {
*m.callOrder = append(*m.callOrder, m.id) *ts.callOrder = append(*ts.callOrder, ts.id)
} }
if m.returnErr != nil { if ts.returnErr != nil {
return nil, m.returnErr return nil, ts.returnErr
}
return ts.returnVals, nil
}
func TestSampleTemperatureZero(t *testing.T) {
logits, err := Sample([]float64{1, 2, 3, 4}, Temperature(0))
if err != nil {
t.Fatal(err)
}
want := []float64{3}
if !floats.Equal(logits, want) {
t.Fatalf("got: %v, want: %v", logits, want)
} }
return m.returnVals, nil
} }