json checkpoint

This commit is contained in:
ParthSareen 2025-01-21 17:48:12 -08:00
parent 089bbb537d
commit a7c8cc06da
2 changed files with 365 additions and 0 deletions

207
sample/fast_json.go Normal file
View File

@ -0,0 +1,207 @@
package sample
import (
"errors"
"fmt"
"math"
"slices"
"github.com/ollama/ollama/model"
)
// JSONState identifies the kind of a node in the JSON sampling state machine.
type JSONState int

// The JSON sampling states. Their numeric values are assigned by iota, so the
// order of this list is part of the contract and must not be rearranged.
const (
	StateStart JSONState = iota
	StateInObject
	StateInObjectKey
	StateNewline
	StateTab
	StateSpace
	StateInString
	StateInInt
	StateInFloat
	StateInBool
	StateInNull
	StateInArray
	StateInColon
	StateInComma
	StateInStringEnd
	StateInObjectKeyEnd
	StateTerminate
	StateEnd
)

// jsonStateNames holds the text String reports for every declared state.
var jsonStateNames = map[JSONState]string{
	StateStart:          "StateStart",
	StateInObject:       "StateInObject",
	StateInObjectKey:    "StateInObjectKey",
	StateNewline:        "StateNewline",
	StateTab:            "StateTab",
	StateSpace:          "StateSpace",
	StateInString:       "StateInString",
	StateInInt:          "StateInInt",
	StateInFloat:        "StateInFloat",
	StateInBool:         "StateInBool",
	StateInNull:         "StateInNull",
	StateInArray:        "StateInArray",
	StateInColon:        "StateInColon",
	StateInComma:        "StateInComma",
	StateInStringEnd:    "StateInStringEnd",
	StateInObjectKeyEnd: "StateInObjectKeyEnd",
	StateTerminate:      "StateTerminate",
	StateEnd:            "StateEnd",
}

// String returns the constant's identifier (for example "StateInInt"). A value
// that is not one of the declared constants is rendered as
// "Unknown state: <n>".
func (s JSONState) String() string {
	if name, known := jsonStateNames[s]; known {
		return name
	}
	return fmt.Sprintf("Unknown state: %d", s)
}
// JSONSampler constrains token sampling with the JSON state machine: Sample
// masks logits according to the current node and UpdateState moves to the next
// node for a chosen token.
type JSONSampler struct {
// curNode is the state-machine node the sampler is currently positioned on.
curNode *Node
// proc is used to encode text into token ids and to recognise special tokens.
proc model.TextProcessor
// stack is not read or written by any code visible in this file.
// NOTE(review): presumably reserved for tracking nested containers — confirm.
stack []*Node
}
// NewJSONSampler builds the JSON state machine for proc and returns a sampler
// positioned on the machine's start node. An error from buildStateMachine is
// returned unchanged together with a nil sampler.
func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
	root, err := buildStateMachine(proc)
	if err != nil {
		return nil, err
	}
	return &JSONSampler{curNode: root, proc: proc}, nil
}
// errInvalidToken is reported by UpdateState when no outgoing edge of the
// current node accepts the token. Match it with errors.Is.
var errInvalidToken = errors.New("invalid token")

// UpdateState advances the sampler to the node reached by tokenSlice.
//
// Edges of the current node are checked in two passes:
//  1. an edge carrying a token equal to tokenSlice wins;
//  2. otherwise an edge carrying a single-id wildcard token (-1 or -2) accepts
//     any tokenSlice.
//
// If neither pass matches, the current node is left untouched and an error
// wrapping errInvalidToken (including the offending ids) is returned.
//
// TransitionEdges is a map, so when several edges would accept the same token
// the one taken depends on map iteration order.
func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
	// todo: account for strings here

	// Pass 1: exact matches take priority over wildcard edges.
	for node, edge := range s.curNode.TransitionEdges {
		for _, validToken := range edge {
			if slices.Equal(tokenSlice, validToken) {
				s.curNode = node
				return nil
			}
		}
	}

	// Pass 2: wildcard edges. The length check must guard both comparisons;
	// without the parentheses an empty token would be indexed and panic, and
	// a multi-id token that merely starts with -2 would count as a wildcard.
	for node, edge := range s.curNode.TransitionEdges {
		for _, validToken := range edge {
			if len(validToken) == 1 && (validToken[0] == -1 || validToken[0] == -2) {
				s.curNode = node
				return nil
			}
		}
	}

	// Report the token through the error instead of printing it to stdout.
	return fmt.Errorf("%w: %v", errInvalidToken, tokenSlice)
}
// Sample restricts logits, in place, to the tokens permitted in the sampler's
// current state and returns the same slice. Disallowed entries are set to NaN.
//
//   - StateTerminate: every token recognised by proc as SpecialEOS gets the
//     value 1.0; everything else becomes NaN.
//   - StateInInt: only "-", the digits "0".."9" (when proc encodes them as a
//     single id) and the ids on the node's outgoing edges stay available, so a
//     number can be continued or closed. Previously this set was computed and
//     then discarded, leaving the logits unmasked.
//   - any other state: the ids on the node's outgoing edges stay available;
//     maskLogits leaves logits untouched when an edge carries the -1 wildcard.
//
// An error from proc.Encode or maskLogits is returned with nil logits.
//
// NOTE(review): "-" is permitted at every position of an integer, not only at
// its start — confirm whether position tracking is wanted here.
func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
	switch s.curNode.State {
	case StateTerminate:
		for i := range logits {
			if s.proc.Is(uint32(i), model.SpecialEOS) {
				logits[i] = 1.0
			} else {
				logits[i] = math.NaN()
			}
		}
		return logits, nil

	case StateInInt:
		// Start from the ids that leave the int node (comma, closing brace);
		// the -2 wildcard on the self-edge never equals a logit index, so it
		// is harmless in the mask.
		validStates := getValidStates(s.curNode)
		for _, text := range [...]string{"-", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9"} {
			ids, err := s.proc.Encode(text)
			if err != nil {
				return nil, err
			}
			// Only single-id encodings can be matched against a logit index.
			if len(ids) == 1 {
				validStates = append(validStates, ids[0])
			}
		}
		masked, err := s.maskLogits(logits, validStates)
		if err != nil {
			return nil, err
		}
		return masked, nil

	default:
		masked, err := s.maskLogits(logits, getValidStates(s.curNode))
		if err != nil {
			return nil, err
		}
		return masked, nil
	}
}
// getValidStates flattens every token on every outgoing edge of node into one
// list of token ids. Despite the name, the result holds token ids (including
// any negative wildcard ids), not states. The ids of one token stay adjacent
// and in order; the order between edges follows map iteration and is therefore
// unspecified. The result is never nil.
func getValidStates(node *Node) []int32 {
	ids := make([]int32, 0)
	for _, tokens := range node.TransitionEdges {
		for i := range tokens {
			ids = append(ids, tokens[i]...)
		}
	}
	return ids
}
// maskLogits overwrites, in place, every logit whose index is not listed in
// validStates with NaN and returns the same slice. If validStates contains the
// wildcard id -1, logits is returned completely untouched. Ids that are
// negative or beyond len(logits) match no index. The returned error is always
// nil.
func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
	// The wildcard disables masking altogether, so look for it up front.
	for _, id := range validStates {
		if id == -1 {
			return logits, nil
		}
	}

	inRange := func(id int32) bool {
		return id >= 0 && int(id) < len(logits)
	}

	// Remember the permitted values, blank the whole slice, then put the
	// permitted values back.
	kept := make([]float64, len(validStates))
	for j, id := range validStates {
		if inRange(id) {
			kept[j] = logits[id]
		}
	}
	for i := range logits {
		logits[i] = math.NaN()
	}
	for j, id := range validStates {
		if inRange(id) {
			logits[id] = kept[j]
		}
	}
	return logits, nil
}

158
sample/state_machine.go Normal file
View File

@ -0,0 +1,158 @@
package sample
import (
"fmt"
"github.com/ollama/ollama/model"
)
// token is a sequence of token ids, as produced by encoding one piece of text.
// The single-id values {-1} and {-2} are used as wildcard sentinels on edges.
type token []int32
// Node is one state of the JSON sampling state machine together with its
// outgoing edges.
type Node struct {
// State is the kind of this node.
State JSONState
// TransitionEdges maps each successor node to the tokens that lead to it.
// Being a map, its iteration order is unspecified.
TransitionEdges map[*Node][]token
}
// NewNode returns a node of the given state with an empty, ready-to-use edge
// map.
func NewNode(state JSONState) *Node {
	node := new(Node)
	node.State = state
	node.TransitionEdges = map[*Node][]token{}
	return node
}
// Encoded forms of the fixed JSON fragments used as edge labels in the state
// machine. They are package-level variables assigned by initTokens, so every
// call to initTokens overwrites them for all users of the package.
var (
// startToken is the encoding of "{".
startToken token
// endToken is the encoding of "}".
endToken token
// stringToken is the encoding of a double quote.
stringToken token
// objectKeyToken is also the encoding of a double quote; it is not referenced
// by the state machine built in this file.
objectKeyToken token
// tabToken is the encoding of a tab character.
tabToken token
// spaceToken is the encoding of a single space.
spaceToken token
// newlineToken is the encoding of "\n".
newlineToken token
// newlineSpace is the encoding of a space followed by "\n".
// NOTE(review): the name suggests newline-then-space — confirm the intended order.
newlineSpace token
// commaToken is the encoding of ",".
commaToken token
// commaToken2 is the encoding of a double quote followed by ",".
commaToken2 token
// commaToken3 is the encoding of a double quote, "," and another double quote.
commaToken3 token
// colonToken is the encoding of a double quote followed by ":".
colonToken token
// colonToken2 is the encoding of ":".
colonToken2 token
)
// initTokens encodes the fixed JSON fragments with proc and stores the results
// in the package-level token variables. Fragments are encoded in a fixed order;
// on the first failure the function stops and returns the error wrapped with
// the text that could not be encoded, leaving later variables unassigned.
func initTokens(proc model.TextProcessor) error {
	fragments := []struct {
		dst  *token
		text string
	}{
		{&startToken, "{"},
		{&endToken, "}"},
		{&stringToken, "\""},
		{&objectKeyToken, "\""},
		{&tabToken, "\t"},
		{&spaceToken, " "},
		{&newlineToken, "\n"},
		{&newlineSpace, " \n"},
		// TODO: figure out how to encode colon correctly
		{&colonToken, "\":"},
		{&colonToken2, ":"},
		{&commaToken, ","},
		{&commaToken2, "\","},
		{&commaToken3, "\",\""},
	}

	for _, f := range fragments {
		ids, err := proc.Encode(f.text)
		if err != nil {
			return fmt.Errorf("encoding %q: %w", f.text, err)
		}
		*f.dst = ids
	}
	return nil
}
// buildStateMachine encodes the JSON fragments via initTokens, wires up the
// nodes of the JSON sampling state machine and returns its start node. An
// error from initTokens is returned unchanged with a nil node.
//
// Edges labelled with the single-id token {-1} accept any token; {-2} plays
// the same role for the integer state.
func buildStateMachine(proc model.TextProcessor) (*Node, error) {
	if err := initTokens(proc); err != nil {
		return nil, err
	}

	var (
		start        = NewNode(StateStart)
		object       = NewNode(StateInObject)
		objectKey    = NewNode(StateInObjectKey)
		objectKeyEnd = NewNode(StateInObjectKeyEnd)
		str          = NewNode(StateInString)
		integer      = NewNode(StateInInt)
		comma        = NewNode(StateInComma)
		colon        = NewNode(StateInColon)
		strEnd       = NewNode(StateInStringEnd)
		end          = NewNode(StateEnd)
		terminate    = NewNode(StateTerminate)
	)

	anyToken := token([]int32{-1})
	anyIntToken := token([]int32{-2})

	// link records that any of the given tokens moves the machine from -> to.
	link := func(from, to *Node, tokens ...token) {
		from.TransitionEdges[to] = tokens
	}

	// "{" opens the object; a quote then opens the first key.
	link(start, object, startToken)
	link(object, objectKey, stringToken)

	// Inside a key: anything may follow, a closing quote ends the key, and a
	// quote-colon or bare colon moves straight on to the value.
	link(objectKey, objectKey, anyToken)
	link(objectKey, colon, colonToken, colonToken2)
	link(objectKey, objectKeyEnd, stringToken)
	link(objectKeyEnd, colon, colonToken)

	// After the colon a value starts: a string or an integer.
	link(colon, str, stringToken)
	link(colon, integer, anyIntToken)

	// Integer value.
	link(integer, integer, anyIntToken)
	link(integer, comma, commaToken, commaToken2)
	link(integer, terminate, endToken)

	// String value, including the fused quote-comma and quote-comma-quote
	// tokens that close the string and move on in a single step.
	link(str, str, anyToken)
	link(str, strEnd, stringToken)
	link(str, comma, commaToken2)
	link(str, objectKey, commaToken3)
	link(strEnd, comma, commaToken, commaToken2)
	link(strEnd, terminate, endToken)

	// After a comma only a newline leads on to the next key.
	link(comma, objectKey, newlineToken)

	// end has no incoming edge in this function.
	link(end, terminate, endToken)

	// The terminate self-edge carries no tokens.
	terminate.TransitionEdges[terminate] = []token{}

	return start, nil
}