Compare commits

...

1 Commits

Author SHA1 Message Date
ParthSareen
0de5bbd0fe sample: use json unmarshal for sampling params 2025-03-20 15:03:42 -04:00
4 changed files with 75 additions and 43 deletions

View File

@ -561,14 +561,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
} }
} }
sampler := sample.NewSampler( sampler := sample.NewSampler(req.Options, grammar)
req.Options.Temperature,
req.Options.TopK,
req.Options.TopP,
req.Options.MinP,
req.Options.Seed,
grammar,
)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.Options.NumPredict, numPredict: req.Options.NumPredict,

View File

@ -1,12 +1,14 @@
package sample package sample
import ( import (
"encoding/json"
"errors" "errors"
"math" "math"
"math/rand/v2" "math/rand/v2"
"slices" "slices"
"sync" "sync"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llama" "github.com/ollama/ollama/llama"
) )
@ -126,40 +128,65 @@ func (s *Sampler) sample(tokens []token) (token, error) {
return tokens[idx], nil return tokens[idx], nil
} }
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278 // SamplerParams contains the validated and normalized parameters for a sampler
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler { type SamplerParams struct {
Temperature float32 `json:"temperature"`
TopK int `json:"top_k"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
Seed int `json:"seed"`
}
// UnmarshalJSON implements json.Unmarshaler to handle validation during JSON unmarshaling
func (p *SamplerParams) UnmarshalJSON(data []byte) error {
type rawParams SamplerParams
if err := json.Unmarshal(data, (*rawParams)(p)); err != nil {
return err
}
// Validate and normalize after unmarshaling
if p.Temperature < 0.0 {
p.Temperature = 0.0
}
if p.TopP < 0.0 {
p.TopP = 0.0
}
if p.TopP >= 1.0 {
p.TopP = 1.0
}
if p.MinP < 0.0 {
p.MinP = 0.0
}
if p.MinP >= 1.0 {
p.MinP = 1.0
}
return nil
}
// NewSampler creates a new sampler with the given options
func NewSampler(opts *api.Options, grammar *Grammar) Sampler {
var params SamplerParams
data, _ := json.Marshal(opts)
_ = json.Unmarshal(data, &params)
var rng *rand.Rand var rng *rand.Rand
if seed != -1 { if params.Seed != -1 {
// PCG requires two parameters: sequence and stream // PCG requires two parameters: sequence and stream
// Use original seed for sequence // Use original seed for sequence
sequence := uint64(seed) sequence := uint64(params.Seed)
// Use golden ratio hash to generate statistically independent seeds // Use golden ratio hash to generate statistically independent seeds
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9)) rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
} }
if temperature < 0.0 {
temperature = 0.0
}
if topP < 0.0 {
topP = 0.0
}
if topP >= 1.0 {
topP = 1.0
}
if minP < 0.0 {
minP = 0.0
}
if minP >= 1.0 {
minP = 1.0
}
return Sampler{ return Sampler{
rng: rng, rng: rng,
topK: topK, topK: params.TopK,
topP: topP, topP: params.TopP,
minP: minP, minP: params.MinP,
temperature: temperature, temperature: params.Temperature,
grammar: grammar, grammar: grammar,
} }
} }

View File

@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
logits[i] = float32(rand.Float64()*10 - 5) logits[i] = float32(rand.Float64()*10 - 5)
} }
sampler := NewSampler(0.8, 0, 0, 0, 42, nil) sampler := NewSampler(createSamplerOptions(0.8, 0, 0, 0, 42), nil)
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
sampler.Sample(logits) sampler.Sample(logits)
@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
for _, tc := range configs { for _, tc := range configs {
b.Run("Config"+tc.name, func(b *testing.B) { b.Run("Config"+tc.name, func(b *testing.B) {
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil) sampler := NewSampler(createSamplerOptions(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed), nil)
sampler.Sample(logits) sampler.Sample(logits)
b.ResetTimer() b.ResetTimer()
@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
// Test with combined transforms separately - topK influences performance greatly // Test with combined transforms separately - topK influences performance greatly
b.Run("TransformCombined", func(b *testing.B) { b.Run("TransformCombined", func(b *testing.B) {
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil) sampler := NewSampler(createSamplerOptions(0.8, 50, 0.9, 0.05, 42), nil)
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
logits[i] = float32(rand.Float64()*10 - 5) logits[i] = float32(rand.Float64()*10 - 5)
} }
sampler := NewSampler(0, -1, 0, 0, -1, nil) sampler := NewSampler(createSamplerOptions(0, -1, 0, 0, -1), nil)
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {

View File

@ -4,11 +4,23 @@ import (
"math" "math"
"math/rand/v2" "math/rand/v2"
"testing" "testing"
"github.com/ollama/ollama/api"
) )
func createSamplerOptions(temperature float32, topK int, topP float32, minP float32, seed int) *api.Options {
return &api.Options{
Temperature: temperature,
TopK: topK,
TopP: topP,
MinP: minP,
Seed: seed,
}
}
func TestWeighted(t *testing.T) { func TestWeighted(t *testing.T) {
logits := []float32{-10, 3, -10, -10} logits := []float32{-10, 3, -10, -10}
sampler := NewSampler(0, 0, 0, 0, 0, nil) sampler := NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil)
got, err := sampler.Sample(logits) got, err := sampler.Sample(logits)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -20,7 +32,7 @@ func TestWeighted(t *testing.T) {
} }
logits = []float32{-100, -10, 0, 10} logits = []float32{-100, -10, 0, 10}
sampler = NewSampler(0, 0, 0, 0, 0, nil) sampler = NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil)
got, err = sampler.Sample(logits) got, err = sampler.Sample(logits)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -34,7 +46,7 @@ func TestWeighted(t *testing.T) {
// Test very high p // Test very high p
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1} logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
// Use extremely small topP to filter out all tokens // Use extremely small topP to filter out all tokens
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil) sampler = NewSampler(createSamplerOptions(1.0, 0, 1e-10, 0, 0), nil)
got, err = sampler.Sample(logits) got, err = sampler.Sample(logits)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -47,7 +59,7 @@ func TestWeighted(t *testing.T) {
} }
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())} logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil) sampler = NewSampler(createSamplerOptions(1, 0, 0.95, 0.05, 0), nil)
got, err = sampler.Sample(logits) got, err = sampler.Sample(logits)
if err == nil { if err == nil {
t.Errorf("expected error, got %d", got) t.Errorf("expected error, got %d", got)
@ -57,8 +69,8 @@ func TestWeighted(t *testing.T) {
func BenchmarkSample(b *testing.B) { func BenchmarkSample(b *testing.B) {
samplers := map[string]Sampler{ samplers := map[string]Sampler{
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy "Greedy": NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil), // Use NewSampler with temp=0 for greedy
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil), "Weighted": NewSampler(createSamplerOptions(0.5, 10, 0.9, 0.2, -1), nil),
} }
// Generate random logits for benchmarking // Generate random logits for benchmarking