diff --git a/sample/sample.go b/sample/sample.go
index 41a8bd696..a735785f0 100644
--- a/sample/sample.go
+++ b/sample/sample.go
@@ -16,16 +16,28 @@ type Sampler interface {
 
 type Temperature float64
 
-func (s Temperature) Sample(logits []float64) ([]float64, error) {
-	if s < 0 || s > 1 {
-		return nil, errors.New("temperature must be between 0 and 1")
+// Sample rescales logits in place by the temperature and returns the same
+// slice. The maximum logit is subtracted first, so the largest entry of the
+// result is 0. Temperatures outside [0, 2] are rejected; a temperature of 0
+// is clamped to 1e-7 instead of dividing by zero.
+func (t Temperature) Sample(logits []float64) ([]float64, error) {
+	if t < 0 || t > 2 {
+		return nil, errors.New("temperature must be between 0 and 2")
 	}
 
-	// greedy sampling
-	if s == 0 {
-		return []float64{floats.Max(logits)}, nil
+	// floats.Max panics on an empty slice; there is nothing to scale anyway
+	if len(logits) == 0 {
+		return logits, nil
+	}
+
+	// subtracting max logit to avoid under/overflow
+	maxLogit := floats.Max(logits)
+
+	temp := math.Max(float64(t), 1e-7)
+	for i := range logits {
+		logits[i] = (logits[i] - maxLogit) / temp
 	}
-	floats.Scale(1.0/float64(s), logits)
+
 	return logits, nil
 }
 
@@ -47,10 +59,12 @@ func computeSoftmax(logits []float64) ([]float64, error) {
 	}
 
 	floatSum := floats.Sum(copiedLogits)
+	// a zero sum would turn the scale factor below into +Inf
 	if floatSum == 0 {
 		return nil, errors.New("no valid tokens found")
 	}
 	floats.Scale(1.0/floatSum, copiedLogits)
+
 	return copiedLogits, nil
 }
 
@@ -175,9 +189,31 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
 	return nil, errors.New("weighed sampler failed")
 }
 
+// TODO: remove after next PR merge
+type greedy struct{}
+
+// Greedy returns a Sampler that picks the single highest-scoring entry.
+func Greedy() Sampler {
+	return greedy{}
+}
+
+// Sample returns a one-element slice holding the index (not the value) of
+// the largest logit, converted to float64.
+func (greedy) Sample(logits []float64) ([]float64, error) {
+	// floats.MaxIdx panics on a zero-length slice
+	if len(logits) == 0 {
+		return nil, errors.New("no logits provided")
+	}
+	return []float64{float64(floats.MaxIdx(logits))}, nil
+}
+
 func Sample(logits []float64, samplers ...Sampler) ([]float64, error) {
 	var err error
 	for _, sampler := range samplers {
+		// early return with greedy if temperature is 0
+		if t, ok := sampler.(Temperature); ok && t == 0 {
+			return Greedy().Sample(logits)
+		}
 		logits, err = sampler.Sample(logits)
 		if err != nil {
 			return nil, err
diff --git a/sample/sample_test.go b/sample/sample_test.go
index 3536b2934..8900e824f 100644
--- a/sample/sample_test.go
+++ b/sample/sample_test.go
@@ -19,17 +19,7 @@ func TestTemperature(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	want := []float64{-6, -4, -2, 0, 2, 4, 8}
-	if !floats.Equal(logits, want) {
-		t.Fatalf("got: %v, want: %v", logits, want)
-	}
-
-	// Only expect the max value returned
-	logits, err = Temperature(0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
-	if err != nil {
-		t.Fatal(err)
-	}
-	want = []float64{4}
+	want := []float64{-14, -12, -10, -8, -6, -4, 0}
 	if !floats.Equal(logits, want) {
 		t.Fatalf("got: %v, want: %v", logits, want)
 	}
@@ -37,6 +27,9 @@ func TestTemperature(t *testing.T) {
 	if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
 		t.Fatalf("expected error for temperature=-1, got %v", logits)
 	}
+	if got, err := Temperature(2.1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
+		t.Fatalf("expected error for temperature=2.1, got %v", got)
+	}
 }
 
 func TestSoftmax(t *testing.T) {
@@ -124,17 +117,17 @@ func TestSample(t *testing.T) {
 	want := []float64{1, 2, 3, 4}
 
 	var callOrder []int
-	mock1 := &mockSampler{
+	mock1 := &testSampler{
 		id:         1,
 		callOrder:  &callOrder,
 		returnVals: want,
 	}
-	mock2 := &mockSampler{
+	mock2 := &testSampler{
 		id:         2,
 		callOrder:  &callOrder,
 		returnVals: want,
 	}
-	mock3 := &mockSampler{
+	mock3 := &testSampler{
 		id:         3,
 		callOrder:  &callOrder,
 		returnVals: want,
@@ -153,7 +146,7 @@ func TestSample(t *testing.T) {
 		t.Errorf("got %v, want %v", got, want)
 	}
 
-	errMock := &mockSampler{
+	errMock := &testSampler{
 		returnErr: fmt.Errorf("mock error"),
 	}
 	_, err = Sample(input, mock1, errMock, mock2)
@@ -162,19 +155,43 @@ func TestSample(t *testing.T) {
 	}
 }
 
-type mockSampler struct {
+type testSampler struct {
 	id         int
 	callOrder  *[]int
 	returnVals []float64
 	returnErr  error
 }
 
-func (m *mockSampler) Sample(logits []float64) ([]float64, error) {
-	if m.callOrder != nil {
-		*m.callOrder = append(*m.callOrder, m.id)
+func (ts *testSampler) Sample(logits []float64) ([]float64, error) {
+	if ts.callOrder != nil {
+		*ts.callOrder = append(*ts.callOrder, ts.id)
 	}
-	if m.returnErr != nil {
-		return nil, m.returnErr
+	if ts.returnErr != nil {
+		return nil, ts.returnErr
+	}
+	return ts.returnVals, nil
+}
+
+func TestSampleTemperatureZero(t *testing.T) {
+	logits, err := Sample([]float64{1, 2, 3, 4}, Temperature(0))
+	if err != nil {
+		t.Fatal(err)
+	}
+	want := []float64{3}
+	if !floats.Equal(logits, want) {
+		t.Fatalf("got: %v, want: %v", logits, want)
 	}
-	return m.returnVals, nil
 }
+
+func TestEmptyLogits(t *testing.T) {
+	if _, err := Greedy().Sample(nil); err == nil {
+		t.Fatal("expected error for greedy sampling of empty logits")
+	}
+	got, err := Temperature(0.5).Sample(nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(got) != 0 {
+		t.Fatalf("got: %v, want empty", got)
+	}
+}