improve temperature sampler
This commit is contained in:
parent
b91487f289
commit
7cd9fbbbb1
@ -16,16 +16,19 @@ type Sampler interface {
|
|||||||
|
|
||||||
type Temperature float64
|
type Temperature float64
|
||||||
|
|
||||||
func (s Temperature) Sample(logits []float64) ([]float64, error) {
|
func (t Temperature) Sample(logits []float64) ([]float64, error) {
|
||||||
if s < 0 || s > 1 {
|
if t < 0 || t > 2 {
|
||||||
return nil, errors.New("temperature must be between 0 and 1")
|
return nil, errors.New("temperature must be between 0 and 2")
|
||||||
}
|
}
|
||||||
|
|
||||||
// greedy sampling
|
// subtracting max logit to avoid under/overflow
|
||||||
if s == 0 {
|
maxLogit := floats.Max(logits)
|
||||||
return []float64{floats.Max(logits)}, nil
|
|
||||||
|
temp := math.Max(float64(t), 1e-7)
|
||||||
|
for i := range logits {
|
||||||
|
logits[i] = (logits[i] - maxLogit) / temp
|
||||||
}
|
}
|
||||||
floats.Scale(1.0/float64(s), logits)
|
|
||||||
return logits, nil
|
return logits, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -47,10 +50,8 @@ func computeSoftmax(logits []float64) ([]float64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
floatSum := floats.Sum(copiedLogits)
|
floatSum := floats.Sum(copiedLogits)
|
||||||
if floatSum == 0 {
|
|
||||||
return nil, errors.New("no valid tokens found")
|
|
||||||
}
|
|
||||||
floats.Scale(1.0/floatSum, copiedLogits)
|
floats.Scale(1.0/floatSum, copiedLogits)
|
||||||
|
|
||||||
return copiedLogits, nil
|
return copiedLogits, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -175,9 +176,28 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
|
|||||||
return nil, errors.New("weighed sampler failed")
|
return nil, errors.New("weighed sampler failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: remove after next PR merge
|
||||||
|
type greedy struct{}
|
||||||
|
|
||||||
|
func Greedy() Sampler {
|
||||||
|
return greedy{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (greedy) Sample(logits []float64) ([]float64, error) {
|
||||||
|
return []float64{float64(floats.MaxIdx(logits))}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func Sample(logits []float64, samplers ...Sampler) ([]float64, error) {
|
func Sample(logits []float64, samplers ...Sampler) ([]float64, error) {
|
||||||
var err error
|
var err error
|
||||||
for _, sampler := range samplers {
|
for _, sampler := range samplers {
|
||||||
|
if sampler == Temperature(0) {
|
||||||
|
// early return with greedy if temperature is 0
|
||||||
|
logits, err = Greedy().Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
}
|
||||||
logits, err = sampler.Sample(logits)
|
logits, err = sampler.Sample(logits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -19,17 +19,7 @@ func TestTemperature(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
want := []float64{-6, -4, -2, 0, 2, 4, 8}
|
want := []float64{-14, -12, -10, -8, -6, -4, 0}
|
||||||
if !floats.Equal(logits, want) {
|
|
||||||
t.Fatalf("got: %v, want: %v", logits, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only expect the max value returned
|
|
||||||
logits, err = Temperature(0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
want = []float64{4}
|
|
||||||
if !floats.Equal(logits, want) {
|
if !floats.Equal(logits, want) {
|
||||||
t.Fatalf("got: %v, want: %v", logits, want)
|
t.Fatalf("got: %v, want: %v", logits, want)
|
||||||
}
|
}
|
||||||
@ -37,6 +27,9 @@ func TestTemperature(t *testing.T) {
|
|||||||
if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
|
if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
|
||||||
t.Fatalf("expected error for temperature=-1, got %v", logits)
|
t.Fatalf("expected error for temperature=-1, got %v", logits)
|
||||||
}
|
}
|
||||||
|
if _, err := Temperature(2.1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
|
||||||
|
t.Fatalf("expected error for temperature=2.1, got %v", logits)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSoftmax(t *testing.T) {
|
func TestSoftmax(t *testing.T) {
|
||||||
@ -124,17 +117,17 @@ func TestSample(t *testing.T) {
|
|||||||
want := []float64{1, 2, 3, 4}
|
want := []float64{1, 2, 3, 4}
|
||||||
|
|
||||||
var callOrder []int
|
var callOrder []int
|
||||||
mock1 := &mockSampler{
|
mock1 := &testSampler{
|
||||||
id: 1,
|
id: 1,
|
||||||
callOrder: &callOrder,
|
callOrder: &callOrder,
|
||||||
returnVals: want,
|
returnVals: want,
|
||||||
}
|
}
|
||||||
mock2 := &mockSampler{
|
mock2 := &testSampler{
|
||||||
id: 2,
|
id: 2,
|
||||||
callOrder: &callOrder,
|
callOrder: &callOrder,
|
||||||
returnVals: want,
|
returnVals: want,
|
||||||
}
|
}
|
||||||
mock3 := &mockSampler{
|
mock3 := &testSampler{
|
||||||
id: 3,
|
id: 3,
|
||||||
callOrder: &callOrder,
|
callOrder: &callOrder,
|
||||||
returnVals: want,
|
returnVals: want,
|
||||||
@ -153,7 +146,7 @@ func TestSample(t *testing.T) {
|
|||||||
t.Errorf("got %v, want %v", got, want)
|
t.Errorf("got %v, want %v", got, want)
|
||||||
}
|
}
|
||||||
|
|
||||||
errMock := &mockSampler{
|
errMock := &testSampler{
|
||||||
returnErr: fmt.Errorf("mock error"),
|
returnErr: fmt.Errorf("mock error"),
|
||||||
}
|
}
|
||||||
_, err = Sample(input, mock1, errMock, mock2)
|
_, err = Sample(input, mock1, errMock, mock2)
|
||||||
@ -162,19 +155,30 @@ func TestSample(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockSampler struct {
|
type testSampler struct {
|
||||||
id int
|
id int
|
||||||
callOrder *[]int
|
callOrder *[]int
|
||||||
returnVals []float64
|
returnVals []float64
|
||||||
returnErr error
|
returnErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockSampler) Sample(logits []float64) ([]float64, error) {
|
func (ts *testSampler) Sample(logits []float64) ([]float64, error) {
|
||||||
if m.callOrder != nil {
|
if ts.callOrder != nil {
|
||||||
*m.callOrder = append(*m.callOrder, m.id)
|
*ts.callOrder = append(*ts.callOrder, ts.id)
|
||||||
}
|
}
|
||||||
if m.returnErr != nil {
|
if ts.returnErr != nil {
|
||||||
return nil, m.returnErr
|
return nil, ts.returnErr
|
||||||
|
}
|
||||||
|
return ts.returnVals, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSampleTemperatureZero(t *testing.T) {
|
||||||
|
logits, err := Sample([]float64{1, 2, 3, 4}, Temperature(0))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
want := []float64{3}
|
||||||
|
if !floats.Equal(logits, want) {
|
||||||
|
t.Fatalf("got: %v, want: %v", logits, want)
|
||||||
}
|
}
|
||||||
return m.returnVals, nil
|
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user