json checkpoint
This commit is contained in:
parent
089bbb537d
commit
a7c8cc06da
207
sample/fast_json.go
Normal file
207
sample/fast_json.go
Normal file
@ -0,0 +1,207 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type JSONState int
|
||||
|
||||
const (
|
||||
StateStart JSONState = iota
|
||||
StateInObject
|
||||
StateInObjectKey
|
||||
StateNewline
|
||||
StateTab
|
||||
StateSpace
|
||||
StateInString
|
||||
StateInInt
|
||||
StateInFloat
|
||||
StateInBool
|
||||
StateInNull
|
||||
StateInArray
|
||||
StateInColon
|
||||
StateInComma
|
||||
StateInStringEnd
|
||||
StateInObjectKeyEnd
|
||||
StateTerminate
|
||||
StateEnd
|
||||
)
|
||||
|
||||
func (s JSONState) String() string {
|
||||
switch s {
|
||||
case StateStart:
|
||||
return "StateStart"
|
||||
case StateInObject:
|
||||
return "StateInObject"
|
||||
case StateInObjectKey:
|
||||
return "StateInObjectKey"
|
||||
case StateInString:
|
||||
return "StateInString"
|
||||
case StateNewline:
|
||||
return "StateNewline"
|
||||
case StateTab:
|
||||
return "StateTab"
|
||||
case StateSpace:
|
||||
return "StateSpace"
|
||||
case StateInInt:
|
||||
return "StateInInt"
|
||||
case StateInFloat:
|
||||
return "StateInFloat"
|
||||
case StateInColon:
|
||||
return "StateInColon"
|
||||
case StateInBool:
|
||||
return "StateInBool"
|
||||
case StateInNull:
|
||||
return "StateInNull"
|
||||
case StateInArray:
|
||||
return "StateInArray"
|
||||
case StateEnd:
|
||||
return "StateEnd"
|
||||
case StateInComma:
|
||||
return "StateInComma"
|
||||
case StateInObjectKeyEnd:
|
||||
return "StateInObjectKeyEnd"
|
||||
case StateTerminate:
|
||||
return "StateTerminate"
|
||||
case StateInStringEnd:
|
||||
return "StateInStringEnd"
|
||||
default:
|
||||
return fmt.Sprintf("Unknown state: %d", s)
|
||||
}
|
||||
}
|
||||
|
||||
type JSONSampler struct {
|
||||
curNode *Node
|
||||
proc model.TextProcessor
|
||||
stack []*Node
|
||||
}
|
||||
|
||||
func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
|
||||
// fmt.Println("Creating new JSON sampler")
|
||||
startNode, err := buildStateMachine(proc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
js := &JSONSampler{
|
||||
curNode: startNode,
|
||||
proc: proc,
|
||||
}
|
||||
|
||||
return js, nil
|
||||
}
|
||||
|
||||
func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
|
||||
// fmt.Printf("Updating state with token: %v\n", tokenSlice)
|
||||
// fmt.Printf("Current state: %s\n", s.curNode.State)
|
||||
|
||||
// fmt.Println("tokenSlice", tokenSlice)
|
||||
// todo: account for strings here
|
||||
for node, edge := range s.curNode.TransitionEdges {
|
||||
for _, validToken := range edge {
|
||||
if slices.Equal(tokenSlice, validToken) {
|
||||
s.curNode = node
|
||||
// fmt.Printf("Transitioned to state: %s\n", node.State)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
for node, edge := range s.curNode.TransitionEdges {
|
||||
for _, validToken := range edge {
|
||||
if len(validToken) == 1 && validToken[0] == -1 || validToken[0] == -2 {
|
||||
s.curNode = node
|
||||
// fmt.Printf("Accepting any token, staying in state: %s\n", node.State)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Println("invalid token ", tokenSlice)
|
||||
return errors.New("invalid token")
|
||||
}
|
||||
|
||||
func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
|
||||
fmt.Printf("Sampling in state: %s\n", s.curNode.State)
|
||||
var err error
|
||||
|
||||
switch s.curNode.State {
|
||||
case StateTerminate:
|
||||
for i := range logits {
|
||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||
logits[i] = 1.0
|
||||
} else {
|
||||
logits[i] = math.NaN()
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
case StateInInt:
|
||||
validStates := []int32{}
|
||||
minus, err := s.proc.Encode("-")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
digits := make([][]int32, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
digits[i], err = s.proc.Encode(fmt.Sprintf("%d", i))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Allow "-" and digits 0-9 at start
|
||||
for i := range logits {
|
||||
for _, d := range digits {
|
||||
if len(d) == 1 && int32(i) == d[0] {
|
||||
validStates = append(validStates, int32(i))
|
||||
}
|
||||
}
|
||||
if len(minus) == 1 && int32(i) == minus[0] {
|
||||
validStates = append(validStates, int32(i))
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
default:
|
||||
validStates := getValidStates(s.curNode)
|
||||
logits, err = s.maskLogits(logits, validStates)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
}
|
||||
|
||||
func getValidStates(node *Node) []int32 {
|
||||
validStates := []int32{}
|
||||
for _, edge := range node.TransitionEdges {
|
||||
for _, token := range edge {
|
||||
validStates = append(validStates, token...)
|
||||
}
|
||||
}
|
||||
return validStates
|
||||
}
|
||||
|
||||
func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
|
||||
// fmt.Printf("Masking logits with valid states: %v\n", validStates)
|
||||
for i := range logits {
|
||||
isValid := false
|
||||
for _, token := range validStates {
|
||||
if token == -1 {
|
||||
// fmt.Println("Found sentinel token, returning unmasked logits")
|
||||
return logits, nil
|
||||
}
|
||||
if i == int(token) {
|
||||
// fmt.Printf("Found valid token: %d\n", token)
|
||||
isValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isValid {
|
||||
logits[i] = math.NaN()
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
}
|
158
sample/state_machine.go
Normal file
158
sample/state_machine.go
Normal file
@ -0,0 +1,158 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type token []int32
|
||||
|
||||
type Node struct {
|
||||
State JSONState
|
||||
TransitionEdges map[*Node][]token
|
||||
}
|
||||
|
||||
func NewNode(state JSONState) *Node {
|
||||
return &Node{
|
||||
State: state,
|
||||
TransitionEdges: make(map[*Node][]token),
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
startToken token
|
||||
endToken token
|
||||
stringToken token
|
||||
objectKeyToken token
|
||||
tabToken token
|
||||
spaceToken token
|
||||
newlineToken token
|
||||
newlineSpace token
|
||||
commaToken token
|
||||
commaToken2 token
|
||||
commaToken3 token
|
||||
colonToken token
|
||||
colonToken2 token
|
||||
)
|
||||
|
||||
func initTokens(proc model.TextProcessor) error {
|
||||
var err error
|
||||
startToken, err = proc.Encode("{")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
endToken, err = proc.Encode("}")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stringToken, err = proc.Encode("\"")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
objectKeyToken, err = proc.Encode("\"")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tabToken, err = proc.Encode("\t")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
spaceToken, err = proc.Encode(" ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newlineToken, err = proc.Encode("\n")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newlineSpace, err = proc.Encode(" \n")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: figure out how to encode colon correctly
|
||||
colonToken, err = proc.Encode("\":")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("colonToken", colonToken)
|
||||
colonToken2, err = proc.Encode(":")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
commaToken, err = proc.Encode(",")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
commaToken2, err = proc.Encode("\",")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("commaToken2", commaToken2)
|
||||
commaToken3, err = proc.Encode("\",\"")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildStateMachine(proc model.TextProcessor) (*Node, error) {
|
||||
if err := initTokens(proc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startNode := NewNode(StateStart)
|
||||
objectNode := NewNode(StateInObject)
|
||||
objectKeyNode := NewNode(StateInObjectKey)
|
||||
objectKeyEndNode := NewNode(StateInObjectKeyEnd)
|
||||
stringNode := NewNode(StateInString)
|
||||
intNode := NewNode(StateInInt)
|
||||
commaNode := NewNode(StateInComma)
|
||||
colonNode := NewNode(StateInColon)
|
||||
stringEndNode := NewNode(StateInStringEnd)
|
||||
endNode := NewNode(StateEnd)
|
||||
terminateNode := NewNode(StateTerminate)
|
||||
|
||||
sentinelToken := token([]int32{-1})
|
||||
intSentinelToken := token([]int32{-2})
|
||||
|
||||
startNode.TransitionEdges[objectNode] = []token{startToken}
|
||||
|
||||
objectNode.TransitionEdges[objectKeyNode] = []token{stringToken}
|
||||
// objectNode.TransitionEdges[objectNode] = []token{newlineToken}
|
||||
// objectNode.TransitionEdges[objectNode] = []token{spaceToken}
|
||||
|
||||
objectKeyNode.TransitionEdges[objectKeyNode] = []token{sentinelToken}
|
||||
objectKeyNode.TransitionEdges[colonNode] = []token{colonToken, colonToken2}
|
||||
// characterize end of object key
|
||||
objectKeyNode.TransitionEdges[objectKeyEndNode] = []token{stringToken}
|
||||
|
||||
objectKeyEndNode.TransitionEdges[colonNode] = []token{colonToken}
|
||||
|
||||
// objectKeyNode.TransitionEdges[intNode] = []token{sentinelToken}
|
||||
|
||||
intNode.TransitionEdges[intNode] = []token{intSentinelToken}
|
||||
intNode.TransitionEdges[commaNode] = []token{commaToken, commaToken2}
|
||||
intNode.TransitionEdges[terminateNode] = []token{endToken}
|
||||
|
||||
commaNode.TransitionEdges[objectKeyNode] = []token{newlineToken}
|
||||
|
||||
colonNode.TransitionEdges[stringNode] = []token{stringToken}
|
||||
colonNode.TransitionEdges[intNode] = []token{intSentinelToken}
|
||||
|
||||
stringNode.TransitionEdges[stringNode] = []token{sentinelToken}
|
||||
stringNode.TransitionEdges[stringEndNode] = []token{stringToken}
|
||||
// "\""," Case
|
||||
stringNode.TransitionEdges[commaNode] = []token{commaToken2}
|
||||
|
||||
// "\"",\"" Case
|
||||
stringNode.TransitionEdges[objectKeyNode] = []token{commaToken3}
|
||||
|
||||
stringEndNode.TransitionEdges[commaNode] = []token{commaToken, commaToken2}
|
||||
stringEndNode.TransitionEdges[terminateNode] = []token{endToken}
|
||||
|
||||
endNode.TransitionEdges[terminateNode] = []token{endToken}
|
||||
|
||||
terminateNode.TransitionEdges[terminateNode] = []token{}
|
||||
return startNode, nil
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user