From 6ba557f25b4e42bc92b7510122e0125991581e7b Mon Sep 17 00:00:00 2001
From: ParthSareen
Date: Thu, 23 Jan 2025 09:46:14 -0800
Subject: [PATCH] checkpoint

---
 sample/fast_json.go     |  96 ++++++++++++++++++++++--
 sample/state_machine.go | 158 +++++++++++++++++++++++++++-------------
 2 files changed, 198 insertions(+), 56 deletions(-)

diff --git a/sample/fast_json.go b/sample/fast_json.go
index 9601ff731..886efb1f7 100644
--- a/sample/fast_json.go
+++ b/sample/fast_json.go
@@ -4,7 +4,6 @@ import (
 	"errors"
 	"fmt"
 	"math"
-	"slices"
 
 	"github.com/ollama/ollama/model"
 )
@@ -76,9 +75,10 @@ func (s JSONState) String() string {
 }
 
 type JSONSampler struct {
-	curNode *Node
-	proc    model.TextProcessor
-	stack   []*Node
+	curNode        *Node
+	proc           model.TextProcessor
+	stack          []*Node
+	bracketCounter int
 }
 
 func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
@@ -88,23 +88,68 @@ func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
 		return nil, err
 	}
 	js := &JSONSampler{
-		curNode: startNode,
-		proc:    proc,
+		curNode:        startNode,
+		proc:           proc,
+		stack:          []*Node{},
+		bracketCounter: 0,
 	}
 	return js, nil
 }
 
+func isTokenSubset(subset, superset []int32) bool {
+	freq1 := make(map[int32]int)
+	freq2 := make(map[int32]int)
+
+	for _, v := range subset {
+		freq1[v]++
+	}
+	for _, v := range superset {
+		freq2[v]++
+	}
+	isSubset := true
+	for k, count1 := range freq1 {
+		count2 := freq2[k]
+		if count1 > count2 {
+			isSubset = false
+			break
+		}
+	}
+	return isSubset
+}
+
 func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
 	// fmt.Printf("Updating state with token: %v\n", tokenSlice)
 	// fmt.Printf("Current state: %s\n", s.curNode.State)
 	// fmt.Println("tokenSlice", tokenSlice)
 	// todo: account for strings here
+	objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
+	if err != nil {
+		return err
+	}
+
+	// only move to terminate state if stack is empty
+	if s.curNode.State == StateEnd {
+		fmt.Println("debug: node.State", s.curNode.State)
+		if len(s.stack) > 0 {
+			s.stack = s.stack[:len(s.stack)-1]
+			fmt.Println("popped and cur state", s.curNode.State)
+			return nil
+		}
+		return nil
+	}
+
 	for node, edge := range s.curNode.TransitionEdges {
 		for _, validToken := range edge {
-			if slices.Equal(tokenSlice, validToken) {
+			if isTokenSubset(tokenSlice, validToken) {
 				s.curNode = node
+				for _, token := range objectTokens {
+					if isTokenSubset(tokenSlice, token) {
+						fmt.Println("Appending to stack", s.curNode.State)
+						s.stack = append(s.stack, s.curNode)
+					}
+				}
 				// fmt.Printf("Transitioned to state: %s\n", node.State)
 				return nil
 			}
 		}
@@ -120,6 +165,11 @@ func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
 		}
 	}
 	fmt.Println("invalid token ", tokenSlice)
+	dec, err := s.proc.Decode(tokenSlice)
+	if err != nil {
+		return err
+	}
+	fmt.Println("decoded token ", dec)
 	return errors.New("invalid token")
 }
 
@@ -164,6 +214,24 @@ func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
 		}
 		return logits, nil
 
+	case StateInString:
+		penalizeNewlineVariants := []string{"\n", " \"\n"}
+		penalizeNewlineToks, err := ComputeTokenVariants(penalizeNewlineVariants, s.proc)
+		if err != nil {
+			return nil, err
+		}
+		penalizeNewlineToks = append(penalizeNewlineToks, []int32{702})
+		logits, err = s.maskSpecificLogits(logits, penalizeNewlineToks)
+		if err != nil {
+			return nil, err
+		}
+		validStates := getValidStates(s.curNode)
+		logits, err = s.maskLogits(logits, validStates)
+		if err != nil {
+			return nil, err
+		}
+		return logits, nil
+
 	default:
 		validStates := getValidStates(s.curNode)
 		logits, err = s.maskLogits(logits, validStates)
@@ -205,3 +273,17 @@ func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
 	}
 	return logits, nil
 }
+
+func (s *JSONSampler) maskSpecificLogits(logits []float64, tokensToMask []token) ([]float64, error) {
+	// fmt.Printf("Masking specific logits: %v\n", tokensToMask)
+	for i := range logits {
+		for _, token := range tokensToMask {
+			for _, chunked := range token {
+				if int(chunked) == i {
+					logits[i] = math.NaN()
+				}
+			}
+		}
+	}
+	return logits, nil
+}
diff --git a/sample/state_machine.go b/sample/state_machine.go
index c85b09f1e..a5e8779fe 100644
--- a/sample/state_machine.go
+++ b/sample/state_machine.go
@@ -21,39 +21,78 @@ func NewNode(state JSONState) *Node {
 }
 
 var (
-	startToken     token
-	endToken       token
-	stringToken    token
-	objectKeyToken token
-	tabToken       token
-	spaceToken     token
-	newlineToken   token
-	newlineSpace   token
-	commaToken     token
-	commaToken2    token
-	commaToken3    token
-	colonToken     token
-	colonToken2    token
+	// startToken token
+	startTokenVariants []token
+	// endToken token
+	// stringToken token
+	// objectKeyToken token
+	tabToken     token
+	spaceToken   token
+	newlineToken token
+	newlineSpace token
+	// commaToken token
+	// commaToken2 token
+	// commaToken3 token
+	// colonToken token
+	// colonToken2 token
+	colonTokenVariants           []token
+	commaTokenVariants           []token
+	stringTokenVariants          []token
+	endTokenVariants             []token
+	objectKeyTokenVariants       []token
+	objKeyToColonVariants        []token
+	stringToObjectKeyVariants    []token
+	stringToCommaVariants        []token
+	stringToObjectVariants       []token
+	stringEndToObjectEndVariants []token
+	stringEndToCommaVariants     []token
 )
 
+func ComputeTokenVariants(variants []string, proc model.TextProcessor) ([]token, error) {
+	var allTokens token
+	for _, variant := range variants {
+		if t, err := proc.Encode(variant); err == nil {
+			allTokens = append(allTokens, t...)
+		}
+	}
+	if len(allTokens) == 0 {
+		return nil, fmt.Errorf("no valid tokens found for variants")
+	}
+	return []token{allTokens}, nil
+}
+
 func initTokens(proc model.TextProcessor) error {
 	var err error
-	startToken, err = proc.Encode("{")
+
+	s, err := proc.Decode([]int32{761})
+	fmt.Printf("761 decoded %q\n", s)
+
+	// Compute start token variants
+	startVariants := []string{"{", " {", "{\n", " {\n"}
+	startTokenVariants, err = ComputeTokenVariants(startVariants, proc)
 	if err != nil {
 		return err
 	}
-	endToken, err = proc.Encode("}")
+	// Compute end token variants
+	endVariants := []string{"}", " }", "}\n", " }\n"}
+	endTokenVariants, err = ComputeTokenVariants(endVariants, proc)
 	if err != nil {
 		return err
 	}
-	stringToken, err = proc.Encode("\"")
+
+	// Compute string token variants
+	// TODO: removed \n
+	stringVariants := []string{"\"", " \""}
+	stringTokenVariants, err = ComputeTokenVariants(stringVariants, proc)
 	if err != nil {
 		return err
 	}
-	objectKeyToken, err = proc.Encode("\"")
+	stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\",\n"}, proc)
 	if err != nil {
 		return err
 	}
+	// objectKeyTokenVariants = []token{stringTokenVariants[0], stringTokenVariants[1]}
+	objectKeyTokenVariants = stringTokenVariants
+	// Compute whitespace tokens
 	tabToken, err = proc.Encode("\t")
 	if err != nil {
 		return err
@@ -70,29 +109,35 @@ func initTokens(proc model.TextProcessor) error {
 	if err != nil {
 		return err
 	}
-	// TODO: figure out how to encode colon correctly
-	colonToken, err = proc.Encode("\":")
+
+	// Compute colon variants
+	colonVariants := []string{":"}
+	colonTokenVariants, err = ComputeTokenVariants(colonVariants, proc)
 	if err != nil {
 		return err
 	}
-	fmt.Println("colonToken", colonToken)
-	colonToken2, err = proc.Encode(":")
+	objKeyToColonVariants, err = ComputeTokenVariants([]string{"\":"}, proc)
 	if err != nil {
 		return err
 	}
-	commaToken, err = proc.Encode(",")
+
+	// Compute comma variants
+	commaVariants := []string{",", " ,", ",\n", "\",", "\", "}
+	commaTokenVariants, err = ComputeTokenVariants(commaVariants, proc)
 	if err != nil {
 		return err
 	}
-	commaToken2, err = proc.Encode("\",")
-	if err != nil {
-		return err
-	}
-	fmt.Println("commaToken2", commaToken2)
-	commaToken3, err = proc.Encode("\",\"")
+	fmt.Printf("commaTokenVariants: %v\n", commaTokenVariants)
+	stringToCommaVariants, err = ComputeTokenVariants([]string{"\",", "\","}, proc)
 	if err != nil {
 		return err
 	}
+
+	stringEndToCommaVariants, err = ComputeTokenVariants([]string{",", ",\n"}, proc)
+	stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\","}, proc)
+	stringToObjectVariants, err = ComputeTokenVariants([]string{"\",\n"}, proc)
+	stringEndToObjectEndVariants, err = ComputeTokenVariants([]string{"\n"}, proc)
+
 	return nil
 }
 
@@ -106,7 +151,7 @@ func buildStateMachine(proc model.TextProcessor) (*Node, error) {
 	objectKeyNode := NewNode(StateInObjectKey)
 	objectKeyEndNode := NewNode(StateInObjectKeyEnd)
 	stringNode := NewNode(StateInString)
-	intNode := NewNode(StateInInt)
+	// intNode := NewNode(StateInInt)
 	commaNode := NewNode(StateInComma)
 	colonNode := NewNode(StateInColon)
 	stringEndNode := NewNode(StateInStringEnd)
@@ -114,44 +159,59 @@ func buildStateMachine(proc model.TextProcessor) (*Node, error) {
 	terminateNode := NewNode(StateTerminate)
 
 	sentinelToken := token([]int32{-1})
-	intSentinelToken := token([]int32{-2})
+	// intSentinelToken := token([]int32{-2})
 
-	startNode.TransitionEdges[objectNode] = []token{startToken}
+	// TODO: cleanup connections of rules
+	startNode.TransitionEdges[objectNode] = startTokenVariants
+
+	objectNode.TransitionEdges[objectKeyNode] = stringTokenVariants
+	objectNode.TransitionEdges[objectNode] = []token{newlineToken}
+	objectNode.TransitionEdges[objectNode] = []token{spaceToken}
 
-	objectNode.TransitionEdges[objectKeyNode] = []token{stringToken}
 	// objectNode.TransitionEdges[objectNode] = []token{newlineToken}
 	// objectNode.TransitionEdges[objectNode] = []token{spaceToken}
 
 	objectKeyNode.TransitionEdges[objectKeyNode] = []token{sentinelToken}
-	objectKeyNode.TransitionEdges[colonNode] = []token{colonToken, colonToken2}
 	// characterize end of object key
-	objectKeyNode.TransitionEdges[objectKeyEndNode] = []token{stringToken}
+	objectKeyNode.TransitionEdges[objectKeyEndNode] = stringTokenVariants
+	objectKeyNode.TransitionEdges[colonNode] = objKeyToColonVariants
 
-	objectKeyEndNode.TransitionEdges[colonNode] = []token{colonToken}
+	// TODO: enable this - key -> object
+	// objectKeyNode.TransitionEdges[objectNode] = startTokenVariants
 	// objectKeyNode.TransitionEdges[intNode] = []token{sentinelToken}
 
-	intNode.TransitionEdges[intNode] = []token{intSentinelToken}
-	intNode.TransitionEdges[commaNode] = []token{commaToken, commaToken2}
-	intNode.TransitionEdges[terminateNode] = []token{endToken}
+	// intNode.TransitionEdges[intNode] = []token{intSentinelToken}
+	// intNode.TransitionEdges[commaNode] = commaTokenVariants
+	// TODO: handle
+	// intNode.TransitionEdges[terminateNode] = endTokenVariants
 
-	commaNode.TransitionEdges[objectKeyNode] = []token{newlineToken}
+	commaNode.TransitionEdges[objectKeyNode] = stringTokenVariants
+	// commaNode.TransitionEdges[objectNode] = startTokenVariants
 
-	colonNode.TransitionEdges[stringNode] = []token{stringToken}
-	colonNode.TransitionEdges[intNode] = []token{intSentinelToken}
+	colonNode.TransitionEdges[stringNode] = stringTokenVariants
+	//TODO: enable
+	// colonNode.TransitionEdges[intNode] = []token{intSentinelToken}
+	colonNode.TransitionEdges[objectNode] = startTokenVariants
 
 	stringNode.TransitionEdges[stringNode] = []token{sentinelToken}
-	stringNode.TransitionEdges[stringEndNode] = []token{stringToken}
-	// "\""," Case
-	stringNode.TransitionEdges[commaNode] = []token{commaToken2}
+	stringNode.TransitionEdges[stringEndNode] = stringTokenVariants
+	// TODO: "\""," Case not accounted for
+	stringNode.TransitionEdges[commaNode] = stringToCommaVariants
 
-	// "\"",\"" Case
-	stringNode.TransitionEdges[objectKeyNode] = []token{commaToken3}
+	// TODO: "\"",\"" Case not accounted for
+	stringNode.TransitionEdges[objectNode] = stringToObjectVariants
 
-	stringEndNode.TransitionEdges[commaNode] = []token{commaToken, commaToken2}
-	stringEndNode.TransitionEdges[terminateNode] = []token{endToken}
+	stringEndNode.TransitionEdges[commaNode] = stringEndToCommaVariants
+	stringEndNode.TransitionEdges[objectNode] = stringToObjectKeyVariants
+	stringEndNode.TransitionEdges[endNode] = stringEndToObjectEndVariants
+	// stringEndNode.TransitionEdges[terminateNode] = endTokenVariants
 
-	endNode.TransitionEdges[terminateNode] = []token{endToken}
+	// Should be obj end
+	// TODO: handle
+	endNode.TransitionEdges[terminateNode] = []token{}
+
+	endNode.TransitionEdges[commaNode] = commaTokenVariants
 
 	terminateNode.TransitionEdges[terminateNode] = []token{}
 	return startNode, nil
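
Outside the patch itself, here is a minimal standalone sketch of the multiset-style token check the change introduces as isTokenSubset in sample/fast_json.go: a sampled token slice matches a transition edge when every token ID in it occurs at least as often in the edge's pooled token list. The token IDs below are hypothetical placeholders; in the real sampler they would come from the model's TextProcessor encodings.

package main

import "fmt"

// isTokenSubset mirrors the frequency-count comparison added by the patch:
// every token ID in `subset` must appear at least as many times in `superset`.
func isTokenSubset(subset, superset []int32) bool {
	freqSub := make(map[int32]int)
	freqSuper := make(map[int32]int)
	for _, v := range subset {
		freqSub[v]++
	}
	for _, v := range superset {
		freqSuper[v]++
	}
	for id, n := range freqSub {
		if n > freqSuper[id] {
			return false
		}
	}
	return true
}

func main() {
	// Hypothetical token IDs standing in for encodings of "{", " {", "{\n".
	edge := []int32{90, 256, 198, 90}

	fmt.Println(isTokenSubset([]int32{90}, edge))     // true: one "{" is covered
	fmt.Println(isTokenSubset([]int32{90, 90}, edge)) // true: both occurrences fit
	fmt.Println(isTokenSubset([]int32{90, 11}, edge)) // false: 11 is not on the edge
}

This looser check appears to be why the patch drops slices.Equal: ComputeTokenVariants pools the encodings of several surface forms (for example "{", " {", "{\n") into one token list per edge, so exact slice equality against a single encoding would be too strict.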