diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 9a1a549cd..9b68dc56d 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -561,14 +561,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } } - sampler := sample.NewSampler( - req.Options.Temperature, - req.Options.TopK, - req.Options.TopP, - req.Options.MinP, - req.Options.Seed, - grammar, - ) + sampler := sample.NewSampler(req.Options, grammar) seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ numPredict: req.Options.NumPredict, diff --git a/sample/samplers.go b/sample/samplers.go index ef8033691..5cd43b9f5 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -1,12 +1,14 @@ package sample import ( + "encoding/json" "errors" "math" "math/rand/v2" "slices" "sync" + "github.com/ollama/ollama/api" "github.com/ollama/ollama/llama" ) @@ -126,40 +128,65 @@ func (s *Sampler) sample(tokens []token) (token, error) { return tokens[idx], nil } -// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278 -func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler { +// SamplerParams contains the validated and normalized parameters for a sampler +type SamplerParams struct { + Temperature float32 `json:"temperature"` + TopK int `json:"top_k"` + TopP float32 `json:"top_p"` + MinP float32 `json:"min_p"` + Seed int `json:"seed"` +} + +// UnmarshalJSON implements json.Unmarshaler to handle validation during JSON unmarshaling +func (p *SamplerParams) UnmarshalJSON(data []byte) error { + type rawParams SamplerParams + if err := json.Unmarshal(data, (*rawParams)(p)); err != nil { + return err + } + + // Validate and normalize after unmarshaling + if p.Temperature < 0.0 { + p.Temperature = 0.0 + } + + if p.TopP < 0.0 { + p.TopP = 0.0 + } + if p.TopP >= 1.0 { + p.TopP = 1.0 + } + + if p.MinP < 0.0 { + p.MinP = 0.0 + } + if 
p.MinP >= 1.0 { + p.MinP = 1.0 + } + + return nil +} + +// NewSampler creates a new sampler with the given options +func NewSampler(opts *api.Options, grammar *Grammar) Sampler { + var params SamplerParams + data, _ := json.Marshal(opts) + _ = json.Unmarshal(data, &params) + var rng *rand.Rand - if seed != -1 { + if params.Seed != -1 { // PCG requires two parameters: sequence and stream // Use original seed for sequence - sequence := uint64(seed) + sequence := uint64(params.Seed) // Use golden ratio hash to generate statistically independent seeds rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9)) } - if temperature < 0.0 { - temperature = 0.0 - } - - if topP < 0.0 { - topP = 0.0 - } - if topP >= 1.0 { - topP = 1.0 - } - - if minP < 0.0 { - minP = 0.0 - } - if minP >= 1.0 { - minP = 1.0 - } return Sampler{ rng: rng, - topK: topK, - topP: topP, - minP: minP, - temperature: temperature, + topK: params.TopK, + topP: params.TopP, + minP: params.MinP, + temperature: params.Temperature, grammar: grammar, } } diff --git a/sample/samplers_benchmark_test.go b/sample/samplers_benchmark_test.go index cd1380141..1f15747b2 100644 --- a/sample/samplers_benchmark_test.go +++ b/sample/samplers_benchmark_test.go @@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) { logits[i] = float32(rand.Float64()*10 - 5) } - sampler := NewSampler(0.8, 0, 0, 0, 42, nil) + sampler := NewSampler(createSamplerOptions(0.8, 0, 0, 0, 42), nil) b.ResetTimer() for b.Loop() { sampler.Sample(logits) @@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) { for _, tc := range configs { b.Run("Config"+tc.name, func(b *testing.B) { - sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil) + sampler := NewSampler(createSamplerOptions(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed), nil) sampler.Sample(logits) b.ResetTimer() @@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) { // Test with combined transforms separately - topK influences performance
greatly b.Run("TransformCombined", func(b *testing.B) { - sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil) + sampler := NewSampler(createSamplerOptions(0.8, 50, 0.9, 0.05, 42), nil) b.ResetTimer() for b.Loop() { @@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) { logits[i] = float32(rand.Float64()*10 - 5) } - sampler := NewSampler(0, -1, 0, 0, -1, nil) + sampler := NewSampler(createSamplerOptions(0, -1, 0, 0, -1), nil) b.ResetTimer() for b.Loop() { diff --git a/sample/samplers_test.go b/sample/samplers_test.go index d79dce474..0bc37890c 100644 --- a/sample/samplers_test.go +++ b/sample/samplers_test.go @@ -4,11 +4,23 @@ import ( "math" "math/rand/v2" "testing" + + "github.com/ollama/ollama/api" ) +func createSamplerOptions(temperature float32, topK int, topP float32, minP float32, seed int) *api.Options { + return &api.Options{ + Temperature: temperature, + TopK: topK, + TopP: topP, + MinP: minP, + Seed: seed, + } +} + func TestWeighted(t *testing.T) { logits := []float32{-10, 3, -10, -10} - sampler := NewSampler(0, 0, 0, 0, 0, nil) + sampler := NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil) got, err := sampler.Sample(logits) if err != nil { t.Error(err) @@ -20,7 +32,7 @@ func TestWeighted(t *testing.T) { } logits = []float32{-100, -10, 0, 10} - sampler = NewSampler(0, 0, 0, 0, 0, nil) + sampler = NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil) got, err = sampler.Sample(logits) if err != nil { t.Error(err) @@ -34,7 +46,7 @@ func TestWeighted(t *testing.T) { // Test very high p logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1} // Use extremely small topP to filter out all tokens - sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil) + sampler = NewSampler(createSamplerOptions(1.0, 0, 1e-10, 0, 0), nil) got, err = sampler.Sample(logits) if err != nil { t.Error(err) @@ -47,7 +59,7 @@ func TestWeighted(t *testing.T) { } logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())} - sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil) 
+ sampler = NewSampler(createSamplerOptions(1, 0, 0.95, 0.05, 0), nil) got, err = sampler.Sample(logits) if err == nil { t.Errorf("expected error, got %d", got) @@ -57,8 +69,8 @@ func TestWeighted(t *testing.T) { func BenchmarkSample(b *testing.B) { samplers := map[string]Sampler{ - "Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy - "Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil), + "Greedy": NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil), // Use NewSampler with temp=0 for greedy + "Weighted": NewSampler(createSamplerOptions(0.5, 10, 0.9, 0.2, -1), nil), } // Generate random logits for benchmarking