Compare commits

...

1 Commits

Author SHA1 Message Date
ParthSareen
0de5bbd0fe sample: use json unmarshal for sampling params 2025-03-20 15:03:42 -04:00
4 changed files with 75 additions and 43 deletions

View File

@ -561,14 +561,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
}
sampler := sample.NewSampler(
req.Options.Temperature,
req.Options.TopK,
req.Options.TopP,
req.Options.MinP,
req.Options.Seed,
grammar,
)
sampler := sample.NewSampler(req.Options, grammar)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.Options.NumPredict,

View File

@ -1,12 +1,14 @@
package sample
import (
"encoding/json"
"errors"
"math"
"math/rand/v2"
"slices"
"sync"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llama"
)
@ -126,40 +128,65 @@ func (s *Sampler) sample(tokens []token) (token, error) {
return tokens[idx], nil
}
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
// SamplerParams contains the validated and normalized parameters for a sampler
type SamplerParams struct {
Temperature float32 `json:"temperature"`
TopK int `json:"top_k"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
Seed int `json:"seed"`
}
// UnmarshalJSON implements json.Unmarshaler to handle validation during JSON unmarshaling
func (p *SamplerParams) UnmarshalJSON(data []byte) error {
type rawParams SamplerParams
if err := json.Unmarshal(data, (*rawParams)(p)); err != nil {
return err
}
// Validate and normalize after unmarshaling
if p.Temperature < 0.0 {
p.Temperature = 0.0
}
if p.TopP < 0.0 {
p.TopP = 0.0
}
if p.TopP >= 1.0 {
p.TopP = 1.0
}
if p.MinP < 0.0 {
p.MinP = 0.0
}
if p.MinP >= 1.0 {
p.MinP = 1.0
}
return nil
}
// NewSampler creates a new sampler with the given options
func NewSampler(opts *api.Options, grammar *Grammar) Sampler {
var params SamplerParams
data, _ := json.Marshal(opts)
_ = json.Unmarshal(data, &params)
var rng *rand.Rand
if seed != -1 {
if params.Seed != -1 {
// PCG requires two parameters: sequence and stream
// Use original seed for sequence
sequence := uint64(seed)
sequence := uint64(params.Seed)
// Use golden ratio hash to generate statistically independent seeds
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
}
if temperature < 0.0 {
temperature = 0.0
}
if topP < 0.0 {
topP = 0.0
}
if topP >= 1.0 {
topP = 1.0
}
if minP < 0.0 {
minP = 0.0
}
if minP >= 1.0 {
minP = 1.0
}
return Sampler{
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
topK: params.TopK,
topP: params.TopP,
minP: params.MinP,
temperature: params.Temperature,
grammar: grammar,
}
}

View File

@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
sampler := NewSampler(createSamplerOptions(0.8, 0, 0, 0, 42), nil)
b.ResetTimer()
for b.Loop() {
sampler.Sample(logits)
@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
for _, tc := range configs {
b.Run("Config"+tc.name, func(b *testing.B) {
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
sampler := NewSampler(createSamplerOptions(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed), nil)
sampler.Sample(logits)
b.ResetTimer()
@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
// Test with combined transforms separately - topK influences performance greatly
b.Run("TransformCombined", func(b *testing.B) {
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
sampler := NewSampler(createSamplerOptions(0.8, 50, 0.9, 0.05, 42), nil)
b.ResetTimer()
for b.Loop() {
@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0, -1, 0, 0, -1, nil)
sampler := NewSampler(createSamplerOptions(0, -1, 0, 0, -1), nil)
b.ResetTimer()
for b.Loop() {

View File

@ -4,11 +4,23 @@ import (
"math"
"math/rand/v2"
"testing"
"github.com/ollama/ollama/api"
)
func createSamplerOptions(temperature float32, topK int, topP float32, minP float32, seed int) *api.Options {
return &api.Options{
Temperature: temperature,
TopK: topK,
TopP: topP,
MinP: minP,
Seed: seed,
}
}
func TestWeighted(t *testing.T) {
logits := []float32{-10, 3, -10, -10}
sampler := NewSampler(0, 0, 0, 0, 0, nil)
sampler := NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil)
got, err := sampler.Sample(logits)
if err != nil {
t.Error(err)
@ -20,7 +32,7 @@ func TestWeighted(t *testing.T) {
}
logits = []float32{-100, -10, 0, 10}
sampler = NewSampler(0, 0, 0, 0, 0, nil)
sampler = NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
@ -34,7 +46,7 @@ func TestWeighted(t *testing.T) {
// Test very high p
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
// Use extremely small topP to filter out all tokens
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
sampler = NewSampler(createSamplerOptions(1.0, 0, 1e-10, 0, 0), nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
@ -47,7 +59,7 @@ func TestWeighted(t *testing.T) {
}
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
sampler = NewSampler(createSamplerOptions(1, 0, 0.95, 0.05, 0), nil)
got, err = sampler.Sample(logits)
if err == nil {
t.Errorf("expected error, got %d", got)
@ -57,8 +69,8 @@ func TestWeighted(t *testing.T) {
func BenchmarkSample(b *testing.B) {
samplers := map[string]Sampler{
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
"Greedy": NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil), // Use NewSampler with temp=0 for greedy
"Weighted": NewSampler(createSamplerOptions(0.5, 10, 0.9, 0.2, -1), nil),
}
// Generate random logits for benchmarking