From a7c8cc06dac166778c09371de819bed30a4b6911 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Tue, 21 Jan 2025 17:48:12 -0800 Subject: [PATCH] json checkpoint --- sample/fast_json.go | 207 ++++++++++++++++++++++++++++++++++++++++ sample/state_machine.go | 158 ++++++++++++++++++++++++++++++ 2 files changed, 365 insertions(+) create mode 100644 sample/fast_json.go create mode 100644 sample/state_machine.go diff --git a/sample/fast_json.go b/sample/fast_json.go new file mode 100644 index 000000000..9601ff731 --- /dev/null +++ b/sample/fast_json.go @@ -0,0 +1,207 @@ +package sample + +import ( + "errors" + "fmt" + "math" + "slices" + + "github.com/ollama/ollama/model" +) + +type JSONState int + +const ( + StateStart JSONState = iota + StateInObject + StateInObjectKey + StateNewline + StateTab + StateSpace + StateInString + StateInInt + StateInFloat + StateInBool + StateInNull + StateInArray + StateInColon + StateInComma + StateInStringEnd + StateInObjectKeyEnd + StateTerminate + StateEnd +) + +func (s JSONState) String() string { + switch s { + case StateStart: + return "StateStart" + case StateInObject: + return "StateInObject" + case StateInObjectKey: + return "StateInObjectKey" + case StateInString: + return "StateInString" + case StateNewline: + return "StateNewline" + case StateTab: + return "StateTab" + case StateSpace: + return "StateSpace" + case StateInInt: + return "StateInInt" + case StateInFloat: + return "StateInFloat" + case StateInColon: + return "StateInColon" + case StateInBool: + return "StateInBool" + case StateInNull: + return "StateInNull" + case StateInArray: + return "StateInArray" + case StateEnd: + return "StateEnd" + case StateInComma: + return "StateInComma" + case StateInObjectKeyEnd: + return "StateInObjectKeyEnd" + case StateTerminate: + return "StateTerminate" + case StateInStringEnd: + return "StateInStringEnd" + default: + return fmt.Sprintf("Unknown state: %d", s) + } +} + +type JSONSampler struct { + curNode *Node + proc model.TextProcessor + stack []*Node +} + +func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) { + // fmt.Println("Creating new JSON sampler") + startNode, err := buildStateMachine(proc) + if err != nil { + return nil, err + } + js := &JSONSampler{ + curNode: startNode, + proc: proc, + } + + return js, nil +} + +func (s *JSONSampler) UpdateState(tokenSlice []int32) error { + // fmt.Printf("Updating state with token: %v\n", tokenSlice) + // fmt.Printf("Current state: %s\n", s.curNode.State) + + // fmt.Println("tokenSlice", tokenSlice) + // todo: account for strings here + for node, edge := range s.curNode.TransitionEdges { + for _, validToken := range edge { + if slices.Equal(tokenSlice, validToken) { + s.curNode = node + // fmt.Printf("Transitioned to state: %s\n", node.State) + return nil + } + } + } + for node, edge := range s.curNode.TransitionEdges { + for _, validToken := range edge { + if len(validToken) == 1 && validToken[0] == -1 || validToken[0] == -2 { + s.curNode = node + // fmt.Printf("Accepting any token, staying in state: %s\n", node.State) + return nil + } + } + } + fmt.Println("invalid token ", tokenSlice) + return errors.New("invalid token") +} + +func (s *JSONSampler) Sample(logits []float64) ([]float64, error) { + fmt.Printf("Sampling in state: %s\n", s.curNode.State) + var err error + + switch s.curNode.State { + case StateTerminate: + for i := range logits { + if s.proc.Is(uint32(i), model.SpecialEOS) { + logits[i] = 1.0 + } else { + logits[i] = math.NaN() + } + } + return logits, nil + + case StateInInt: + validStates := []int32{} + minus, err := s.proc.Encode("-") + if err != nil { + return nil, err + } + digits := make([][]int32, 10) + for i := 0; i < 10; i++ { + digits[i], err = s.proc.Encode(fmt.Sprintf("%d", i)) + if err != nil { + return nil, err + } + } + // Allow "-" and digits 0-9 at start + for i := range logits { + for _, d := range digits { + if len(d) == 1 && int32(i) == d[0] { + validStates = append(validStates, int32(i)) + } + } + if len(minus) == 1 && int32(i) == minus[0] { + validStates = append(validStates, int32(i)) + } + } + return logits, nil + + default: + validStates := getValidStates(s.curNode) + logits, err = s.maskLogits(logits, validStates) + if err != nil { + return nil, err + } + return logits, nil + } +} + +func getValidStates(node *Node) []int32 { + validStates := []int32{} + for _, edge := range node.TransitionEdges { + for _, token := range edge { + validStates = append(validStates, token...) + } + } + return validStates +} + +func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) { + // fmt.Printf("Masking logits with valid states: %v\n", validStates) + for i := range logits { + isValid := false + for _, token := range validStates { + if token == -1 { + // fmt.Println("Found sentinel token, returning unmasked logits") + return logits, nil + } + if i == int(token) { + // fmt.Printf("Found valid token: %d\n", token) + isValid = true + break + } + } + if !isValid { + logits[i] = math.NaN() + } + } + return logits, nil +} diff --git a/sample/state_machine.go b/sample/state_machine.go new file mode 100644 index 000000000..c85b09f1e --- /dev/null +++ b/sample/state_machine.go @@ -0,0 +1,158 @@ +package sample + +import ( + "fmt" + + "github.com/ollama/ollama/model" +) + +type token []int32 + +type Node struct { + State JSONState + TransitionEdges map[*Node][]token +} + +func NewNode(state JSONState) *Node { + return &Node{ + State: state, + TransitionEdges: make(map[*Node][]token), + } +} + +var ( + startToken token + endToken token + stringToken token + objectKeyToken token + tabToken token + spaceToken token + newlineToken token + newlineSpace token + commaToken token + commaToken2 token + commaToken3 token + colonToken token + colonToken2 token +) + +func initTokens(proc model.TextProcessor) error { + var err error + startToken, err = proc.Encode("{") + if err != nil { + return err + } + endToken, err = proc.Encode("}") + if err != nil { + return err + } + stringToken, err = proc.Encode("\"") + if err != nil { + return err + } + objectKeyToken, err = proc.Encode("\"") + if err != nil { + return err + } + tabToken, err = proc.Encode("\t") + if err != nil { + return err + } + spaceToken, err = proc.Encode(" ") + if err != nil { + return err + } + newlineToken, err = proc.Encode("\n") + if err != nil { + return err + } + newlineSpace, err = proc.Encode(" \n") + if err != nil { + return err + } + // TODO: figure out how to encode colon correctly + colonToken, err = proc.Encode("\":") + if err != nil { + return err + } + fmt.Println("colonToken", colonToken) + colonToken2, err = proc.Encode(":") + if err != nil { + return err + } + commaToken, err = proc.Encode(",") + if err != nil { + return err + } + commaToken2, err = proc.Encode("\",") + if err != nil { + return err + } + fmt.Println("commaToken2", commaToken2) + commaToken3, err = proc.Encode("\",\"") + if err != nil { + return err + } + return nil +} + +func buildStateMachine(proc model.TextProcessor) (*Node, error) { + if err := initTokens(proc); err != nil { + return nil, err + } + + startNode := NewNode(StateStart) + objectNode := NewNode(StateInObject) + objectKeyNode := NewNode(StateInObjectKey) + objectKeyEndNode := NewNode(StateInObjectKeyEnd) + stringNode := NewNode(StateInString) + intNode := NewNode(StateInInt) + commaNode := NewNode(StateInComma) + colonNode := NewNode(StateInColon) + stringEndNode := NewNode(StateInStringEnd) + endNode := NewNode(StateEnd) + terminateNode := NewNode(StateTerminate) + + sentinelToken := token([]int32{-1}) + intSentinelToken := token([]int32{-2}) + + startNode.TransitionEdges[objectNode] = []token{startToken} + + objectNode.TransitionEdges[objectKeyNode] = []token{stringToken} + // objectNode.TransitionEdges[objectNode] = []token{newlineToken} + // objectNode.TransitionEdges[objectNode] = []token{spaceToken} + + objectKeyNode.TransitionEdges[objectKeyNode] = []token{sentinelToken} + objectKeyNode.TransitionEdges[colonNode] = []token{colonToken, colonToken2} + // characterize end of object key + objectKeyNode.TransitionEdges[objectKeyEndNode] = []token{stringToken} + + objectKeyEndNode.TransitionEdges[colonNode] = []token{colonToken} + + // objectKeyNode.TransitionEdges[intNode] = []token{sentinelToken} + + intNode.TransitionEdges[intNode] = []token{intSentinelToken} + intNode.TransitionEdges[commaNode] = []token{commaToken, commaToken2} + intNode.TransitionEdges[terminateNode] = []token{endToken} + + commaNode.TransitionEdges[objectKeyNode] = []token{newlineToken} + + colonNode.TransitionEdges[stringNode] = []token{stringToken} + colonNode.TransitionEdges[intNode] = []token{intSentinelToken} + + stringNode.TransitionEdges[stringNode] = []token{sentinelToken} + stringNode.TransitionEdges[stringEndNode] = []token{stringToken} + // "\""," Case + stringNode.TransitionEdges[commaNode] = []token{commaToken2} + + // "\"",\"" Case + stringNode.TransitionEdges[objectKeyNode] = []token{commaToken3} + + stringEndNode.TransitionEdges[commaNode] = []token{commaToken, commaToken2} + stringEndNode.TransitionEdges[terminateNode] = []token{endToken} + + endNode.TransitionEdges[terminateNode] = []token{endToken} + + terminateNode.TransitionEdges[terminateNode] = []token{} + return startNode, nil +}