From c245b0406fd669bc8e3aea4e20148fa303fe2fd4 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Thu, 27 Feb 2025 15:44:53 -0800 Subject: [PATCH] sample: remove transforms from greedy sampling (#9377) --- sample/samplers.go | 53 ++++++++---------------- sample/samplers_test.go | 89 ++++++++++++++++++----------------------- 2 files changed, 55 insertions(+), 87 deletions(-) diff --git a/sample/samplers.go b/sample/samplers.go index 836c6e4d9..1b8a5edd9 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -54,53 +54,42 @@ func (s weighted) Sample(logits []float32) (int32, error) { if idx, ok := w.Take(); ok { return int32(indices[idx]), nil } - return -1, errors.New("weighed sampler failed, no valid token found") + return -1, errors.New("weighted sampler failed, no valid token found") } -type greedy struct { - transforms []Transform -} - -func Greedy(transforms ...Transform) Sampler { - return greedy{transforms: transforms} +type greedy struct{} + +func Greedy() Sampler { + return greedy{} } +// Sample returns the index of the maximum value in logits. func (s greedy) Sample(logits []float32) (int32, error) { - logits64 := make([]float64, len(logits)) - for i, v := range logits { - logits64[i] = float64(v) + if len(logits) == 0 { + return -1, errors.New("no logits provided for greedy sampling") } - for _, t := range s.transforms { - logits64 = t.Apply(logits64) - } - - var maxIdx int - var maxLogit float64 - for i, logit := range logits64 { - if logit > maxLogit { - maxLogit = logit + maxIdx := 0 + for i := range logits { + if logits[i] > logits[maxIdx] { maxIdx = i } } - if maxLogit == math.Inf(-1) { - return -1, errors.New("no valid logits found for greedy sampling") - } - return int32(maxIdx), nil } // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278 func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) { - transforms := []Transform{} + if temperature == 0 { + return Greedy(), nil + } + if temperature < 0 || temperature > 2 { return nil, errors.New("temperature must be between 0 and 2") } - if temperature != 0 { - transforms = append(transforms, Temperature(temperature)) - } + transforms := []Transform{Temperature(temperature)} if topK != 0 { if topK <= 0 { @@ -123,15 +112,7 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed transforms = append(transforms, MinP(minP)) } - if len(transforms) == 0 { - return nil, errors.New("at least one transform is required") - } - - if temperature == 0 { - return Greedy(transforms...), nil - } - - if seed != 0 { + if seed >= 0 { seed64 := uint64(seed) return Weighted(&seed64, transforms...), nil } diff --git a/sample/samplers_test.go b/sample/samplers_test.go index aaa8d99c4..32364a3b7 100644 --- a/sample/samplers_test.go +++ b/sample/samplers_test.go @@ -66,32 +66,15 @@ func TestSample(t *testing.T) { callOrder: &callOrder, } - got, err := Greedy(mock1, mock2, mock3).Sample(input) + _, err := Weighted(nil, mock1, mock2, mock3).Sample(input) if err != nil { t.Error(err) return } - - want := int32(3) // Greedy sampler should pick highest logit - if want != got { - t.Errorf("index mismatch: want %d, got %d", want, got) - } wantOrder := []int{1, 2, 3} if diff := cmp.Diff(wantOrder, callOrder); diff != "" { t.Errorf("call order mismatch (-want +got):\n%s", diff) } - - callOrder = nil - - _, err = Weighted(nil, mock1, mock2, mock3).Sample(input) - if err != nil { - t.Error(err) - return - } - wantOrder = []int{1, 2, 3} - if diff := cmp.Diff(wantOrder, callOrder); diff != "" { - t.Errorf("call order mismatch (-want +got):\n%s", diff) - } } func TestNewSampler(t *testing.T) { @@ -105,8 +88,9 @@ func TestNewSampler(t *testing.T) { wantErr bool }{ { - name: "no transforms", - wantErr: true, + name: "no transforms", + // temperature is 0, so greedy should be used + wantErr: false, }, { name: "temperature", @@ -124,49 +108,52 @@ func TestNewSampler(t *testing.T) { wantErr: true, }, { - name: "top k", - topK: 10, - wantErr: false, + name: "top k", + topK: 10, + temperature: 0.8, + wantErr: false, }, { - name: "invalid top k negative", - topK: -1, - wantErr: true, + name: "invalid top k negative", + topK: -1, + temperature: 0.8, + wantErr: true, }, { - name: "top p", - topP: 0.9, - wantErr: false, + name: "top p", + topP: 0.9, + temperature: 0.8, + wantErr: false, }, { - name: "invalid top p negative", - topP: -0.1, - wantErr: true, + name: "invalid top p negative", + topP: -0.1, + temperature: 0.8, + wantErr: true, }, { - name: "invalid top p one", - topP: 1.0, - wantErr: true, + name: "invalid top p one", + topP: 1.0, + temperature: 0.8, + wantErr: true, }, { - name: "min p", - minP: 0.2, - wantErr: false, + name: "min p", + minP: 0.2, + temperature: 0.8, + wantErr: false, }, { - name: "invalid min p negative", - minP: -0.1, - wantErr: true, + name: "invalid min p negative", + minP: -0.1, + temperature: 0.8, + wantErr: true, }, { - name: "invalid min p one", - minP: 1.0, - wantErr: true, - }, - { - name: "seed", - seed: 42, - wantErr: true, // seed alone is not valid without other transforms + name: "invalid min p one", + minP: 1.0, + temperature: 0.8, + wantErr: true, }, { name: "default values", @@ -184,7 +171,7 @@ func TestNewSampler(t *testing.T) { topP: 0.0, minP: 0.0, seed: 0, - wantErr: true, // all zeroes means no transforms + wantErr: false, // all zeroes means no transforms }, { name: "all transforms", @@ -216,7 +203,7 @@ func BenchmarkSample(b *testing.B) { } samplers := map[string]Sampler{ - "Greedy": Greedy(transforms...), + "Greedy": Greedy(), "Weighted": Weighted(nil, transforms...), }