Working json sampler
This commit is contained in:
parent
5b19d4941a
commit
b91487f289
104
sample/json_sampler.go
Normal file
104
sample/json_sampler.go
Normal file
@ -0,0 +1,104 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type JSONState int
|
||||
|
||||
const (
|
||||
StateStart JSONState = iota // Initial state
|
||||
StateInObject // Inside an object {}
|
||||
StateInArray // Inside an array []
|
||||
StateInString // Inside a string ""
|
||||
StateAfterKey // After object key, expecting :
|
||||
StateAfterColon // After :, expecting value
|
||||
StateAfterValue // After value, expecting , or closing bracket
|
||||
StateDone // JSON parsing complete
|
||||
)
|
||||
|
||||
type JSONSampler struct {
|
||||
state JSONState
|
||||
stack []string
|
||||
proc model.TextProcessor
|
||||
}
|
||||
|
||||
func NewJSONSampler(proc model.TextProcessor) *JSONSampler {
|
||||
return &JSONSampler{
|
||||
state: StateStart,
|
||||
proc: proc,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
|
||||
// Pre-decode valid tokens for current state
|
||||
validTokens := make(map[uint32]bool)
|
||||
|
||||
// Always allow EOS token in any state
|
||||
// TODO: Check for other special tokens if needed
|
||||
for i := range logits {
|
||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||
validTokens[uint32(i)] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Build set of valid tokens based on current state
|
||||
switch s.state {
|
||||
case StateStart:
|
||||
// Only allow opening brace
|
||||
for i := range logits {
|
||||
text, err := s.proc.Decode([]int32{int32(i)})
|
||||
if err == nil && text == "{" {
|
||||
validTokens[uint32(i)] = true
|
||||
}
|
||||
}
|
||||
case StateInObject, StateInArray:
|
||||
// Allow any token
|
||||
for i := range logits {
|
||||
validTokens[uint32(i)] = true
|
||||
}
|
||||
case StateInString:
|
||||
// Allow any token except closing brace
|
||||
for i := range logits {
|
||||
text, err := s.proc.Decode([]int32{int32(i)})
|
||||
if err == nil && text != "}" {
|
||||
validTokens[uint32(i)] = true
|
||||
}
|
||||
}
|
||||
case StateDone:
|
||||
// No tokens allowed
|
||||
}
|
||||
|
||||
// Mark invalid tokens as NaN in one pass
|
||||
for i := range logits {
|
||||
if !validTokens[uint32(i)] {
|
||||
logits[i] = math.NaN()
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
|
||||
func (s *JSONSampler) UpdateState(tokenID int) error {
|
||||
text, err := s.proc.Decode([]int32{int32(tokenID)})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode token: %w", err)
|
||||
}
|
||||
|
||||
switch s.state {
|
||||
case StateStart:
|
||||
if text != "{" {
|
||||
return fmt.Errorf("expected {, got %s", text)
|
||||
}
|
||||
s.state = StateInObject
|
||||
case StateInObject:
|
||||
if text == "}" {
|
||||
s.state = StateDone
|
||||
}
|
||||
case StateDone:
|
||||
return fmt.Errorf("unexpected token after closing bracket: %s", text)
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user