sample: add sampling package for new engine (#8410)
This commit is contained in:
parent
314573bfe8
commit
0b7e1676eb
@ -65,8 +65,8 @@ type Sequence struct {
|
|||||||
// number of tokens to predict
|
// number of tokens to predict
|
||||||
numPredict int
|
numPredict int
|
||||||
|
|
||||||
// set of samplers to run on generated logits
|
// sampler with transforms to run on generated logits
|
||||||
samplers []sample.Sampler
|
sampler sample.Sampler
|
||||||
|
|
||||||
// channel to send back the embedding if embedding only
|
// channel to send back the embedding if embedding only
|
||||||
embedding chan []float32
|
embedding chan []float32
|
||||||
@ -93,7 +93,7 @@ type NewSequenceParams struct {
|
|||||||
numPredict int
|
numPredict int
|
||||||
stop []string
|
stop []string
|
||||||
numKeep int32
|
numKeep int32
|
||||||
samplers []sample.Sampler
|
sampler sample.Sampler
|
||||||
embedding bool
|
embedding bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -136,7 +136,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
|||||||
responses: make(chan string, 100),
|
responses: make(chan string, 100),
|
||||||
quit: make(chan bool, 1),
|
quit: make(chan bool, 1),
|
||||||
embedding: make(chan []float32, 1),
|
embedding: make(chan []float32, 1),
|
||||||
samplers: params.samplers,
|
sampler: params.sampler,
|
||||||
embeddingOnly: params.embedding,
|
embeddingOnly: params.embedding,
|
||||||
stop: params.stop,
|
stop: params.stop,
|
||||||
numKeep: params.numKeep,
|
numKeep: params.numKeep,
|
||||||
@ -393,13 +393,7 @@ func (s *Server) processBatch() error {
|
|||||||
return fmt.Errorf("failed to decode batch: %w", err)
|
return fmt.Errorf("failed to decode batch: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
f32s := modelOutput.Floats()
|
logits := modelOutput.Floats()
|
||||||
|
|
||||||
// TODO(jessegross): This will no longer be necessary once the sampling interface takes f32s
|
|
||||||
logits := make([]float64, len(f32s))
|
|
||||||
for i, f32 := range f32s {
|
|
||||||
logits[i] = float64(f32)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, seq := range s.seqs {
|
for i, seq := range s.seqs {
|
||||||
if seq == nil {
|
if seq == nil {
|
||||||
@ -433,14 +427,12 @@ func (s *Server) processBatch() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sample a token
|
// sample a token
|
||||||
vocabSize := len(f32s) / len(options.Outputs)
|
vocabSize := len(logits) / len(options.Outputs)
|
||||||
tokens, err := sample.Sample(logits[seq.iBatch*vocabSize:(seq.iBatch+1)*vocabSize], seq.samplers...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(jessegross): Sampler will output a single int32 in the future
|
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
||||||
token := int32(tokens[0])
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to sample token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// if it's an end of sequence token, break
|
// if it's an end of sequence token, break
|
||||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||||
@ -565,27 +557,6 @@ type CompletionResponse struct {
|
|||||||
Timings Timings `json:"timings"`
|
Timings Timings `json:"timings"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSamplers(_ CompletionRequest) []sample.Sampler {
|
|
||||||
// TODO(jessegross): Waiting for sampling code
|
|
||||||
|
|
||||||
/*samplingParams.TopK = req.TopK
|
|
||||||
samplingParams.TopP = req.TopP
|
|
||||||
samplingParams.MinP = req.MinP
|
|
||||||
samplingParams.TypicalP = req.TypicalP
|
|
||||||
samplingParams.Temp = req.Temperature
|
|
||||||
samplingParams.RepeatLastN = req.RepeatLastN
|
|
||||||
samplingParams.PenaltyRepeat = req.RepeatPenalty
|
|
||||||
samplingParams.PenaltyFreq = req.FrequencyPenalty
|
|
||||||
samplingParams.PenaltyPresent = req.PresencePenalty
|
|
||||||
samplingParams.Mirostat = req.Mirostat
|
|
||||||
samplingParams.MirostatTau = req.MirostatTau
|
|
||||||
samplingParams.MirostatEta = req.MirostatEta
|
|
||||||
samplingParams.Seed = uint32(req.Seed)
|
|
||||||
samplingParams.Grammar = req.Grammar*/
|
|
||||||
|
|
||||||
return []sample.Sampler{sample.Greedy()}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
var req CompletionRequest
|
var req CompletionRequest
|
||||||
req.Options = Options(api.DefaultOptions())
|
req.Options = Options(api.DefaultOptions())
|
||||||
@ -604,11 +575,23 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sampler, err := sample.NewSampler(
|
||||||
|
req.Temperature,
|
||||||
|
req.TopK,
|
||||||
|
req.TopP,
|
||||||
|
req.MinP,
|
||||||
|
req.Seed,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to create sampler: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
numPredict: req.NumPredict,
|
numPredict: req.NumPredict,
|
||||||
stop: req.Stop,
|
stop: req.Stop,
|
||||||
numKeep: int32(req.NumKeep),
|
numKeep: int32(req.NumKeep),
|
||||||
samplers: getSamplers(req),
|
sampler: sampler,
|
||||||
embedding: false,
|
embedding: false,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1,13 +0,0 @@
|
|||||||
package sample
|
|
||||||
|
|
||||||
import "gonum.org/v1/gonum/floats"
|
|
||||||
|
|
||||||
type greedy struct{}
|
|
||||||
|
|
||||||
func Greedy() Sampler {
|
|
||||||
return greedy{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s greedy) Sample(t []float64) ([]float64, error) {
|
|
||||||
return []float64{float64(floats.MaxIdx(t))}, nil
|
|
||||||
}
|
|
@ -1,74 +0,0 @@
|
|||||||
package sample
|
|
||||||
|
|
||||||
import (
|
|
||||||
"slices"
|
|
||||||
|
|
||||||
"gonum.org/v1/gonum/floats"
|
|
||||||
"gonum.org/v1/gonum/stat/sampleuv"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Sampler interface {
|
|
||||||
Sample([]float64) ([]float64, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Temperature float64
|
|
||||||
|
|
||||||
func (s Temperature) Sample(t []float64) ([]float64, error) {
|
|
||||||
floats.Div(t, slices.Repeat([]float64{float64(s)}, len(t)))
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type softmax struct{}
|
|
||||||
|
|
||||||
func Softmax() Sampler {
|
|
||||||
return softmax{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (softmax) Sample(t []float64) ([]float64, error) {
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type TopK int
|
|
||||||
|
|
||||||
func (s TopK) Sample(t []float64) ([]float64, error) {
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type TopP float32
|
|
||||||
|
|
||||||
func (s TopP) Sample(t []float64) ([]float64, error) {
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type MinP float32
|
|
||||||
|
|
||||||
func (s MinP) Sample(t []float64) ([]float64, error) {
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type weighed struct{}
|
|
||||||
|
|
||||||
func Weighed() Sampler {
|
|
||||||
return weighed{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s weighed) Sample(t []float64) ([]float64, error) {
|
|
||||||
w := sampleuv.NewWeighted(t, nil)
|
|
||||||
if v, ok := w.Take(); ok {
|
|
||||||
return []float64{float64(v)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func Sample(floats []float64, samplers ...Sampler) ([]float64, error) {
|
|
||||||
var err error
|
|
||||||
for _, sampler := range samplers {
|
|
||||||
floats, err = sampler.Sample(floats)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return floats, nil
|
|
||||||
}
|
|
139
sample/samplers.go
Normal file
139
sample/samplers.go
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"golang.org/x/exp/rand"
|
||||||
|
"gonum.org/v1/gonum/stat/sampleuv"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Sampler interface {
|
||||||
|
Sample([]float32) (int32, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type weighted struct {
|
||||||
|
src rand.Source
|
||||||
|
transforms []Transform
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(parthsareen): remove uv sample dependency https://github.com/ollama/ollama/issues/9279
|
||||||
|
func Weighted(seed *uint64, transforms ...Transform) Sampler {
|
||||||
|
var src rand.Source
|
||||||
|
if seed != nil {
|
||||||
|
src = rand.NewSource(*seed)
|
||||||
|
}
|
||||||
|
return weighted{src: src, transforms: transforms}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s weighted) Sample(logits []float32) (int32, error) {
|
||||||
|
logits64 := make([]float64, len(logits))
|
||||||
|
for i, v := range logits {
|
||||||
|
logits64[i] = float64(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range s.transforms {
|
||||||
|
logits64 = t.Apply(logits64)
|
||||||
|
}
|
||||||
|
|
||||||
|
logitsCopy := make([]float64, 0, len(logits))
|
||||||
|
indices := make([]int, 0, len(logits))
|
||||||
|
for i, logit := range logits64 {
|
||||||
|
if !math.IsInf(logit, -1) {
|
||||||
|
logitsCopy = append(logitsCopy, logit)
|
||||||
|
indices = append(indices, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(logitsCopy) == 0 {
|
||||||
|
return -1, errors.New("no valid logits found for weighed sampling")
|
||||||
|
}
|
||||||
|
|
||||||
|
probs := softmax(logitsCopy)
|
||||||
|
w := sampleuv.NewWeighted(probs, s.src)
|
||||||
|
if idx, ok := w.Take(); ok {
|
||||||
|
return int32(indices[idx]), nil
|
||||||
|
}
|
||||||
|
return -1, errors.New("weighed sampler failed, no valid token found")
|
||||||
|
}
|
||||||
|
|
||||||
|
type greedy struct {
|
||||||
|
transforms []Transform
|
||||||
|
}
|
||||||
|
|
||||||
|
func Greedy(transforms ...Transform) Sampler {
|
||||||
|
return greedy{transforms: transforms}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s greedy) Sample(logits []float32) (int32, error) {
|
||||||
|
logits64 := make([]float64, len(logits))
|
||||||
|
for i, v := range logits {
|
||||||
|
logits64[i] = float64(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range s.transforms {
|
||||||
|
logits64 = t.Apply(logits64)
|
||||||
|
}
|
||||||
|
|
||||||
|
var maxIdx int
|
||||||
|
var maxLogit float64
|
||||||
|
for i, logit := range logits64 {
|
||||||
|
if logit > maxLogit {
|
||||||
|
maxLogit = logit
|
||||||
|
maxIdx = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if maxLogit == math.Inf(-1) {
|
||||||
|
return -1, errors.New("no valid logits found for greedy sampling")
|
||||||
|
}
|
||||||
|
|
||||||
|
return int32(maxIdx), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||||
|
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) {
|
||||||
|
transforms := []Transform{}
|
||||||
|
if temperature < 0 || temperature > 2 {
|
||||||
|
return nil, errors.New("temperature must be between 0 and 2")
|
||||||
|
}
|
||||||
|
|
||||||
|
if temperature != 0 {
|
||||||
|
transforms = append(transforms, Temperature(temperature))
|
||||||
|
}
|
||||||
|
|
||||||
|
if topK != 0 {
|
||||||
|
if topK <= 0 {
|
||||||
|
return nil, errors.New("topK must be greater than 0")
|
||||||
|
}
|
||||||
|
transforms = append(transforms, TopK(topK))
|
||||||
|
}
|
||||||
|
|
||||||
|
if topP != 0 {
|
||||||
|
if topP < 0 || topP >= 1 {
|
||||||
|
return nil, errors.New("topP must be between 0 and 1")
|
||||||
|
}
|
||||||
|
transforms = append(transforms, TopP(topP))
|
||||||
|
}
|
||||||
|
|
||||||
|
if minP != 0 {
|
||||||
|
if minP < 0 || minP >= 1 {
|
||||||
|
return nil, errors.New("minP must be between 0 and 1")
|
||||||
|
}
|
||||||
|
transforms = append(transforms, MinP(minP))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(transforms) == 0 {
|
||||||
|
return nil, errors.New("at least one transform is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if temperature == 0 {
|
||||||
|
return Greedy(transforms...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if seed != 0 {
|
||||||
|
seed64 := uint64(seed)
|
||||||
|
return Weighted(&seed64, transforms...), nil
|
||||||
|
}
|
||||||
|
return Weighted(nil, transforms...), nil
|
||||||
|
}
|
238
sample/samplers_test.go
Normal file
238
sample/samplers_test.go
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"math/rand/v2"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWeighted(t *testing.T) {
|
||||||
|
got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
want := int32(1)
|
||||||
|
if want != got {
|
||||||
|
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for no valid tokens, got index", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
seed := uint64(42)
|
||||||
|
got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// With seed 42, we expect a consistent sample
|
||||||
|
want = int32(3) // This will be deterministic due to the seed
|
||||||
|
if want != got {
|
||||||
|
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testTransform struct {
|
||||||
|
id int
|
||||||
|
callOrder *[]int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *testTransform) Apply(logits []float64) []float64 {
|
||||||
|
if ts.callOrder != nil {
|
||||||
|
*ts.callOrder = append(*ts.callOrder, ts.id)
|
||||||
|
}
|
||||||
|
return logits
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSample(t *testing.T) {
|
||||||
|
input := []float32{1, 2, 3, 4}
|
||||||
|
|
||||||
|
var callOrder []int
|
||||||
|
mock1 := &testTransform{
|
||||||
|
id: 1,
|
||||||
|
callOrder: &callOrder,
|
||||||
|
}
|
||||||
|
mock2 := &testTransform{
|
||||||
|
id: 2,
|
||||||
|
callOrder: &callOrder,
|
||||||
|
}
|
||||||
|
mock3 := &testTransform{
|
||||||
|
id: 3,
|
||||||
|
callOrder: &callOrder,
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := Greedy(mock1, mock2, mock3).Sample(input)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
want := int32(3) // Greedy sampler should pick highest logit
|
||||||
|
if want != got {
|
||||||
|
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||||
|
}
|
||||||
|
wantOrder := []int{1, 2, 3}
|
||||||
|
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
|
||||||
|
t.Errorf("call order mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
callOrder = nil
|
||||||
|
|
||||||
|
_, err = Weighted(nil, mock1, mock2, mock3).Sample(input)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
wantOrder = []int{1, 2, 3}
|
||||||
|
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
|
||||||
|
t.Errorf("call order mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSampler(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
temperature float32
|
||||||
|
topK int
|
||||||
|
topP float32
|
||||||
|
minP float32
|
||||||
|
seed int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no transforms",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "temperature",
|
||||||
|
temperature: 0.5,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid temperature negative",
|
||||||
|
temperature: -1,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid temperature too high",
|
||||||
|
temperature: 2.1,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "top k",
|
||||||
|
topK: 10,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid top k negative",
|
||||||
|
topK: -1,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "top p",
|
||||||
|
topP: 0.9,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid top p negative",
|
||||||
|
topP: -0.1,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid top p one",
|
||||||
|
topP: 1.0,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "min p",
|
||||||
|
minP: 0.2,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid min p negative",
|
||||||
|
minP: -0.1,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid min p one",
|
||||||
|
minP: 1.0,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "seed",
|
||||||
|
seed: 42,
|
||||||
|
wantErr: true, // seed alone is not valid without other transforms
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default values",
|
||||||
|
temperature: 0.8,
|
||||||
|
topK: 40,
|
||||||
|
topP: 0.9,
|
||||||
|
minP: 0.0,
|
||||||
|
seed: 0,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all zeroes",
|
||||||
|
temperature: 0.0,
|
||||||
|
topK: 0,
|
||||||
|
topP: 0.0,
|
||||||
|
minP: 0.0,
|
||||||
|
seed: 0,
|
||||||
|
wantErr: true, // all zeroes means no transforms
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all transforms",
|
||||||
|
temperature: 0.8,
|
||||||
|
topK: 50,
|
||||||
|
topP: 0.95,
|
||||||
|
minP: 0.1,
|
||||||
|
seed: 42,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkSample(b *testing.B) {
|
||||||
|
transforms := []Transform{
|
||||||
|
Temperature(0.5),
|
||||||
|
TopK(10),
|
||||||
|
TopP(0.9),
|
||||||
|
MinP(0.2),
|
||||||
|
}
|
||||||
|
|
||||||
|
samplers := map[string]Sampler{
|
||||||
|
"Greedy": Greedy(transforms...),
|
||||||
|
"Weighted": Weighted(nil, transforms...),
|
||||||
|
}
|
||||||
|
|
||||||
|
logits := make([]float32, 1<<16)
|
||||||
|
for i := range logits {
|
||||||
|
logits[i] = rand.Float32()
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, s := range samplers {
|
||||||
|
b.Run(name, func(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
for range b.N {
|
||||||
|
if _, err := s.Sample(logits); err != nil {
|
||||||
|
b.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
120
sample/transforms.go
Normal file
120
sample/transforms.go
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"math"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
pq "github.com/emirpasic/gods/v2/queues/priorityqueue"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Transform interface {
|
||||||
|
Apply([]float64) []float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(parthsareen): potentially cache softmax values
|
||||||
|
func softmax(logits []float64) []float64 {
|
||||||
|
var sum float64
|
||||||
|
probs := make([]float64, len(logits))
|
||||||
|
for i, v := range logits {
|
||||||
|
probs[i] = math.Exp(v)
|
||||||
|
sum += probs[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range probs {
|
||||||
|
probs[i] /= sum
|
||||||
|
}
|
||||||
|
|
||||||
|
return probs
|
||||||
|
}
|
||||||
|
|
||||||
|
type Temperature float64
|
||||||
|
|
||||||
|
func (t Temperature) Apply(logits []float64) []float64 {
|
||||||
|
temp := math.Max(float64(t), 1e-7)
|
||||||
|
|
||||||
|
// subtracting max logit to avoid under/overflow
|
||||||
|
maxLogit := slices.Max(logits)
|
||||||
|
for i := range logits {
|
||||||
|
logits[i] = (logits[i] - maxLogit) / temp
|
||||||
|
}
|
||||||
|
|
||||||
|
return logits
|
||||||
|
}
|
||||||
|
|
||||||
|
type logitMap struct {
|
||||||
|
index int
|
||||||
|
logit float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type TopK int
|
||||||
|
|
||||||
|
// TODO(parthsareen): avoid having to check all logits after this transform
|
||||||
|
func (k TopK) Apply(logits []float64) []float64 {
|
||||||
|
if int(k) >= len(logits) {
|
||||||
|
return logits
|
||||||
|
}
|
||||||
|
q := pq.NewWith(func(a, b logitMap) int {
|
||||||
|
return -cmp.Compare(a.logit, b.logit)
|
||||||
|
})
|
||||||
|
|
||||||
|
for i, logit := range logits {
|
||||||
|
q.Enqueue(logitMap{index: i, logit: logit})
|
||||||
|
}
|
||||||
|
|
||||||
|
validLogits := make(map[int]float64)
|
||||||
|
for range k {
|
||||||
|
logitMap, _ := q.Dequeue()
|
||||||
|
validLogits[logitMap.index] = logitMap.logit
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range logits {
|
||||||
|
if _, ok := validLogits[i]; !ok {
|
||||||
|
logits[i] = math.Inf(-1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return logits
|
||||||
|
}
|
||||||
|
|
||||||
|
type TopP float64
|
||||||
|
|
||||||
|
func (p TopP) Apply(logits []float64) []float64 {
|
||||||
|
probs := softmax(logits)
|
||||||
|
indices := make([]int, len(probs))
|
||||||
|
for i := range indices {
|
||||||
|
indices[i] = i
|
||||||
|
}
|
||||||
|
|
||||||
|
// sort in descending order
|
||||||
|
slices.SortFunc(indices, func(i, j int) int {
|
||||||
|
return cmp.Compare(probs[j], probs[i])
|
||||||
|
})
|
||||||
|
|
||||||
|
var sum float64
|
||||||
|
for i, idx := range indices {
|
||||||
|
sum += probs[idx]
|
||||||
|
if sum > float64(p) {
|
||||||
|
for _, idx := range indices[i+1:] {
|
||||||
|
logits[idx] = math.Inf(-1)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return logits
|
||||||
|
}
|
||||||
|
|
||||||
|
type MinP float64
|
||||||
|
|
||||||
|
func (p MinP) Apply(logits []float64) []float64 {
|
||||||
|
probs := softmax(logits)
|
||||||
|
threshold := slices.Max(probs) * float64(p)
|
||||||
|
|
||||||
|
for i, prob := range probs {
|
||||||
|
if prob < threshold {
|
||||||
|
logits[i] = math.Inf(-1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return logits
|
||||||
|
}
|
80
sample/transforms_test.go
Normal file
80
sample/transforms_test.go
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"math/rand/v2"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTemperature(t *testing.T) {
|
||||||
|
got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
|
||||||
|
want := []float64{-4, -10, 0, -14, -6, -12, -8}
|
||||||
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
|
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoftmax(t *testing.T) {
|
||||||
|
got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||||
|
|
||||||
|
want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
|
||||||
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
|
t.Errorf("probs mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTopK(t *testing.T) {
|
||||||
|
got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||||
|
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
|
||||||
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
|
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||||
|
|
||||||
|
want = []float64{-3, -2, -1, 0, 1, 2, 4}
|
||||||
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
|
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTopP(t *testing.T) {
|
||||||
|
got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||||
|
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
|
||||||
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
|
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMinP(t *testing.T) {
|
||||||
|
got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
|
||||||
|
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
|
||||||
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
|
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkTransform(b *testing.B) {
|
||||||
|
transforms := map[string]Transform{
|
||||||
|
"Temperature": Temperature(0.5),
|
||||||
|
"TopK": TopK(10),
|
||||||
|
"TopP": TopP(0.9),
|
||||||
|
"MinP": MinP(0.2),
|
||||||
|
}
|
||||||
|
|
||||||
|
logits := make([]float64, 1<<16)
|
||||||
|
for i := range logits {
|
||||||
|
logits[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, transform := range transforms {
|
||||||
|
b.Run(name, func(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
for range b.N {
|
||||||
|
transform.Apply(logits)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user