sample: use json unmarshal for sampling params
This commit is contained in:
parent
42a14f7f63
commit
0de5bbd0fe
@ -561,14 +561,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sampler := sample.NewSampler(
|
sampler := sample.NewSampler(req.Options, grammar)
|
||||||
req.Options.Temperature,
|
|
||||||
req.Options.TopK,
|
|
||||||
req.Options.TopP,
|
|
||||||
req.Options.MinP,
|
|
||||||
req.Options.Seed,
|
|
||||||
grammar,
|
|
||||||
)
|
|
||||||
|
|
||||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
numPredict: req.Options.NumPredict,
|
numPredict: req.Options.NumPredict,
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
package sample
|
package sample
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"math"
|
"math"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llama"
|
"github.com/ollama/ollama/llama"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -126,40 +128,65 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
|||||||
return tokens[idx], nil
|
return tokens[idx], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
// SamplerParams contains the validated and normalized parameters for a sampler
|
||||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
|
type SamplerParams struct {
|
||||||
|
Temperature float32 `json:"temperature"`
|
||||||
|
TopK int `json:"top_k"`
|
||||||
|
TopP float32 `json:"top_p"`
|
||||||
|
MinP float32 `json:"min_p"`
|
||||||
|
Seed int `json:"seed"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements json.Unmarshaler to handle validation during JSON unmarshaling
|
||||||
|
func (p *SamplerParams) UnmarshalJSON(data []byte) error {
|
||||||
|
type rawParams SamplerParams
|
||||||
|
if err := json.Unmarshal(data, (*rawParams)(p)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate and normalize after unmarshaling
|
||||||
|
if p.Temperature < 0.0 {
|
||||||
|
p.Temperature = 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.TopP < 0.0 {
|
||||||
|
p.TopP = 0.0
|
||||||
|
}
|
||||||
|
if p.TopP >= 1.0 {
|
||||||
|
p.TopP = 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.MinP < 0.0 {
|
||||||
|
p.MinP = 0.0
|
||||||
|
}
|
||||||
|
if p.MinP >= 1.0 {
|
||||||
|
p.MinP = 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSampler creates a new sampler with the given options
|
||||||
|
func NewSampler(opts *api.Options, grammar *Grammar) Sampler {
|
||||||
|
var params SamplerParams
|
||||||
|
data, _ := json.Marshal(opts)
|
||||||
|
_ = json.Unmarshal(data, &params)
|
||||||
|
|
||||||
var rng *rand.Rand
|
var rng *rand.Rand
|
||||||
if seed != -1 {
|
if params.Seed != -1 {
|
||||||
// PCG requires two parameters: sequence and stream
|
// PCG requires two parameters: sequence and stream
|
||||||
// Use original seed for sequence
|
// Use original seed for sequence
|
||||||
sequence := uint64(seed)
|
sequence := uint64(params.Seed)
|
||||||
// Use golden ratio hash to generate statistically independent seeds
|
// Use golden ratio hash to generate statistically independent seeds
|
||||||
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
|
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
|
||||||
}
|
}
|
||||||
if temperature < 0.0 {
|
|
||||||
temperature = 0.0
|
|
||||||
}
|
|
||||||
|
|
||||||
if topP < 0.0 {
|
|
||||||
topP = 0.0
|
|
||||||
}
|
|
||||||
if topP >= 1.0 {
|
|
||||||
topP = 1.0
|
|
||||||
}
|
|
||||||
|
|
||||||
if minP < 0.0 {
|
|
||||||
minP = 0.0
|
|
||||||
}
|
|
||||||
if minP >= 1.0 {
|
|
||||||
minP = 1.0
|
|
||||||
}
|
|
||||||
|
|
||||||
return Sampler{
|
return Sampler{
|
||||||
rng: rng,
|
rng: rng,
|
||||||
topK: topK,
|
topK: params.TopK,
|
||||||
topP: topP,
|
topP: params.TopP,
|
||||||
minP: minP,
|
minP: params.MinP,
|
||||||
temperature: temperature,
|
temperature: params.Temperature,
|
||||||
grammar: grammar,
|
grammar: grammar,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
|||||||
logits[i] = float32(rand.Float64()*10 - 5)
|
logits[i] = float32(rand.Float64()*10 - 5)
|
||||||
}
|
}
|
||||||
|
|
||||||
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
|
sampler := NewSampler(createSamplerOptions(0.8, 0, 0, 0, 42), nil)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
sampler.Sample(logits)
|
sampler.Sample(logits)
|
||||||
@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
|||||||
|
|
||||||
for _, tc := range configs {
|
for _, tc := range configs {
|
||||||
b.Run("Config"+tc.name, func(b *testing.B) {
|
b.Run("Config"+tc.name, func(b *testing.B) {
|
||||||
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
|
sampler := NewSampler(createSamplerOptions(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed), nil)
|
||||||
sampler.Sample(logits)
|
sampler.Sample(logits)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
|||||||
|
|
||||||
// Test with combined transforms separately - topK influences performance greatly
|
// Test with combined transforms separately - topK influences performance greatly
|
||||||
b.Run("TransformCombined", func(b *testing.B) {
|
b.Run("TransformCombined", func(b *testing.B) {
|
||||||
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
|
sampler := NewSampler(createSamplerOptions(0.8, 50, 0.9, 0.05, 42), nil)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
|
|||||||
logits[i] = float32(rand.Float64()*10 - 5)
|
logits[i] = float32(rand.Float64()*10 - 5)
|
||||||
}
|
}
|
||||||
|
|
||||||
sampler := NewSampler(0, -1, 0, 0, -1, nil)
|
sampler := NewSampler(createSamplerOptions(0, -1, 0, 0, -1), nil)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
|
@ -4,11 +4,23 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func createSamplerOptions(temperature float32, topK int, topP float32, minP float32, seed int) *api.Options {
|
||||||
|
return &api.Options{
|
||||||
|
Temperature: temperature,
|
||||||
|
TopK: topK,
|
||||||
|
TopP: topP,
|
||||||
|
MinP: minP,
|
||||||
|
Seed: seed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWeighted(t *testing.T) {
|
func TestWeighted(t *testing.T) {
|
||||||
logits := []float32{-10, 3, -10, -10}
|
logits := []float32{-10, 3, -10, -10}
|
||||||
sampler := NewSampler(0, 0, 0, 0, 0, nil)
|
sampler := NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil)
|
||||||
got, err := sampler.Sample(logits)
|
got, err := sampler.Sample(logits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@ -20,7 +32,7 @@ func TestWeighted(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logits = []float32{-100, -10, 0, 10}
|
logits = []float32{-100, -10, 0, 10}
|
||||||
sampler = NewSampler(0, 0, 0, 0, 0, nil)
|
sampler = NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil)
|
||||||
got, err = sampler.Sample(logits)
|
got, err = sampler.Sample(logits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@ -34,7 +46,7 @@ func TestWeighted(t *testing.T) {
|
|||||||
// Test very high p
|
// Test very high p
|
||||||
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
||||||
// Use extremely small topP to filter out all tokens
|
// Use extremely small topP to filter out all tokens
|
||||||
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
|
sampler = NewSampler(createSamplerOptions(1.0, 0, 1e-10, 0, 0), nil)
|
||||||
got, err = sampler.Sample(logits)
|
got, err = sampler.Sample(logits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@ -47,7 +59,7 @@ func TestWeighted(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
||||||
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
|
sampler = NewSampler(createSamplerOptions(1, 0, 0.95, 0.05, 0), nil)
|
||||||
got, err = sampler.Sample(logits)
|
got, err = sampler.Sample(logits)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("expected error, got %d", got)
|
t.Errorf("expected error, got %d", got)
|
||||||
@ -57,8 +69,8 @@ func TestWeighted(t *testing.T) {
|
|||||||
|
|
||||||
func BenchmarkSample(b *testing.B) {
|
func BenchmarkSample(b *testing.B) {
|
||||||
samplers := map[string]Sampler{
|
samplers := map[string]Sampler{
|
||||||
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
"Greedy": NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil), // Use NewSampler with temp=0 for greedy
|
||||||
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
|
"Weighted": NewSampler(createSamplerOptions(0.5, 10, 0.9, 0.2, -1), nil),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate random logits for benchmarking
|
// Generate random logits for benchmarking
|
||||||
|
Loading…
x
Reference in New Issue
Block a user