Working json sampler

This commit is contained in:
ParthSareen 2025-01-15 14:08:54 -08:00
parent 5b19d4941a
commit b91487f289

104
sample/json_sampler.go Normal file
View File

@ -0,0 +1,104 @@
package sample
import (
"fmt"
"math"
"github.com/ollama/ollama/model"
)
type JSONState int
const (
StateStart JSONState = iota // Initial state
StateInObject // Inside an object {}
StateInArray // Inside an array []
StateInString // Inside a string ""
StateAfterKey // After object key, expecting :
StateAfterColon // After :, expecting value
StateAfterValue // After value, expecting , or closing bracket
StateDone // JSON parsing complete
)
type JSONSampler struct {
state JSONState
stack []string
proc model.TextProcessor
}
func NewJSONSampler(proc model.TextProcessor) *JSONSampler {
return &JSONSampler{
state: StateStart,
proc: proc,
}
}
func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
// Pre-decode valid tokens for current state
validTokens := make(map[uint32]bool)
// Always allow EOS token in any state
// TODO: Check for other special tokens if needed
for i := range logits {
if s.proc.Is(uint32(i), model.SpecialEOS) {
validTokens[uint32(i)] = true
}
}
// Build set of valid tokens based on current state
switch s.state {
case StateStart:
// Only allow opening brace
for i := range logits {
text, err := s.proc.Decode([]int32{int32(i)})
if err == nil && text == "{" {
validTokens[uint32(i)] = true
}
}
case StateInObject, StateInArray:
// Allow any token
for i := range logits {
validTokens[uint32(i)] = true
}
case StateInString:
// Allow any token except closing brace
for i := range logits {
text, err := s.proc.Decode([]int32{int32(i)})
if err == nil && text != "}" {
validTokens[uint32(i)] = true
}
}
case StateDone:
// No tokens allowed
}
// Mark invalid tokens as NaN in one pass
for i := range logits {
if !validTokens[uint32(i)] {
logits[i] = math.NaN()
}
}
return logits, nil
}
func (s *JSONSampler) UpdateState(tokenID int) error {
text, err := s.proc.Decode([]int32{int32(tokenID)})
if err != nil {
return fmt.Errorf("failed to decode token: %w", err)
}
switch s.state {
case StateStart:
if text != "{" {
return fmt.Errorf("expected {, got %s", text)
}
s.state = StateInObject
case StateInObject:
if text == "}" {
s.state = StateDone
}
case StateDone:
return fmt.Errorf("unexpected token after closing bracket: %s", text)
}
return nil
}