sample: use json unmarshal for sampling params
This commit is contained in:
parent
42a14f7f63
commit
0de5bbd0fe
@ -561,14 +561,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
sampler := sample.NewSampler(
|
||||
req.Options.Temperature,
|
||||
req.Options.TopK,
|
||||
req.Options.TopP,
|
||||
req.Options.MinP,
|
||||
req.Options.Seed,
|
||||
grammar,
|
||||
)
|
||||
sampler := sample.NewSampler(req.Options, grammar)
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
numPredict: req.Options.NumPredict,
|
||||
|
@ -1,12 +1,14 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llama"
|
||||
)
|
||||
|
||||
@ -126,40 +128,65 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
||||
return tokens[idx], nil
|
||||
}
|
||||
|
||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
|
||||
// SamplerParams contains the validated and normalized parameters for a sampler
|
||||
type SamplerParams struct {
|
||||
Temperature float32 `json:"temperature"`
|
||||
TopK int `json:"top_k"`
|
||||
TopP float32 `json:"top_p"`
|
||||
MinP float32 `json:"min_p"`
|
||||
Seed int `json:"seed"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler to handle validation during JSON unmarshaling
|
||||
func (p *SamplerParams) UnmarshalJSON(data []byte) error {
|
||||
type rawParams SamplerParams
|
||||
if err := json.Unmarshal(data, (*rawParams)(p)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate and normalize after unmarshaling
|
||||
if p.Temperature < 0.0 {
|
||||
p.Temperature = 0.0
|
||||
}
|
||||
|
||||
if p.TopP < 0.0 {
|
||||
p.TopP = 0.0
|
||||
}
|
||||
if p.TopP >= 1.0 {
|
||||
p.TopP = 1.0
|
||||
}
|
||||
|
||||
if p.MinP < 0.0 {
|
||||
p.MinP = 0.0
|
||||
}
|
||||
if p.MinP >= 1.0 {
|
||||
p.MinP = 1.0
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewSampler creates a new sampler with the given options
|
||||
func NewSampler(opts *api.Options, grammar *Grammar) Sampler {
|
||||
var params SamplerParams
|
||||
data, _ := json.Marshal(opts)
|
||||
_ = json.Unmarshal(data, ¶ms)
|
||||
|
||||
var rng *rand.Rand
|
||||
if seed != -1 {
|
||||
if params.Seed != -1 {
|
||||
// PCG requires two parameters: sequence and stream
|
||||
// Use original seed for sequence
|
||||
sequence := uint64(seed)
|
||||
sequence := uint64(params.Seed)
|
||||
// Use golden ratio hash to generate statistically independent seeds
|
||||
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
|
||||
}
|
||||
if temperature < 0.0 {
|
||||
temperature = 0.0
|
||||
}
|
||||
|
||||
if topP < 0.0 {
|
||||
topP = 0.0
|
||||
}
|
||||
if topP >= 1.0 {
|
||||
topP = 1.0
|
||||
}
|
||||
|
||||
if minP < 0.0 {
|
||||
minP = 0.0
|
||||
}
|
||||
if minP >= 1.0 {
|
||||
minP = 1.0
|
||||
}
|
||||
|
||||
return Sampler{
|
||||
rng: rng,
|
||||
topK: topK,
|
||||
topP: topP,
|
||||
minP: minP,
|
||||
temperature: temperature,
|
||||
topK: params.TopK,
|
||||
topP: params.TopP,
|
||||
minP: params.MinP,
|
||||
temperature: params.Temperature,
|
||||
grammar: grammar,
|
||||
}
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
||||
logits[i] = float32(rand.Float64()*10 - 5)
|
||||
}
|
||||
|
||||
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
|
||||
sampler := NewSampler(createSamplerOptions(0.8, 0, 0, 0, 42), nil)
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
sampler.Sample(logits)
|
||||
@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
||||
|
||||
for _, tc := range configs {
|
||||
b.Run("Config"+tc.name, func(b *testing.B) {
|
||||
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
|
||||
sampler := NewSampler(createSamplerOptions(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed), nil)
|
||||
sampler.Sample(logits)
|
||||
|
||||
b.ResetTimer()
|
||||
@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
||||
|
||||
// Test with combined transforms separately - topK influences performance greatly
|
||||
b.Run("TransformCombined", func(b *testing.B) {
|
||||
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
|
||||
sampler := NewSampler(createSamplerOptions(0.8, 50, 0.9, 0.05, 42), nil)
|
||||
b.ResetTimer()
|
||||
|
||||
for b.Loop() {
|
||||
@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
|
||||
logits[i] = float32(rand.Float64()*10 - 5)
|
||||
}
|
||||
|
||||
sampler := NewSampler(0, -1, 0, 0, -1, nil)
|
||||
sampler := NewSampler(createSamplerOptions(0, -1, 0, 0, -1), nil)
|
||||
b.ResetTimer()
|
||||
|
||||
for b.Loop() {
|
||||
|
@ -4,11 +4,23 @@ import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func createSamplerOptions(temperature float32, topK int, topP float32, minP float32, seed int) *api.Options {
|
||||
return &api.Options{
|
||||
Temperature: temperature,
|
||||
TopK: topK,
|
||||
TopP: topP,
|
||||
MinP: minP,
|
||||
Seed: seed,
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeighted(t *testing.T) {
|
||||
logits := []float32{-10, 3, -10, -10}
|
||||
sampler := NewSampler(0, 0, 0, 0, 0, nil)
|
||||
sampler := NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil)
|
||||
got, err := sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@ -20,7 +32,7 @@ func TestWeighted(t *testing.T) {
|
||||
}
|
||||
|
||||
logits = []float32{-100, -10, 0, 10}
|
||||
sampler = NewSampler(0, 0, 0, 0, 0, nil)
|
||||
sampler = NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@ -34,7 +46,7 @@ func TestWeighted(t *testing.T) {
|
||||
// Test very high p
|
||||
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
||||
// Use extremely small topP to filter out all tokens
|
||||
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
|
||||
sampler = NewSampler(createSamplerOptions(1.0, 0, 1e-10, 0, 0), nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@ -47,7 +59,7 @@ func TestWeighted(t *testing.T) {
|
||||
}
|
||||
|
||||
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
||||
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
|
||||
sampler = NewSampler(createSamplerOptions(1, 0, 0.95, 0.05, 0), nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err == nil {
|
||||
t.Errorf("expected error, got %d", got)
|
||||
@ -57,8 +69,8 @@ func TestWeighted(t *testing.T) {
|
||||
|
||||
func BenchmarkSample(b *testing.B) {
|
||||
samplers := map[string]Sampler{
|
||||
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
||||
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
|
||||
"Greedy": NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil), // Use NewSampler with temp=0 for greedy
|
||||
"Weighted": NewSampler(createSamplerOptions(0.5, 10, 0.9, 0.2, -1), nil),
|
||||
}
|
||||
|
||||
// Generate random logits for benchmarking
|
||||
|
Loading…
x
Reference in New Issue
Block a user