checkpoint

ParthSareen 2025-01-23 09:46:14 -08:00
parent a7c8cc06da
commit 6ba557f25b
2 changed files with 198 additions and 56 deletions

View File

@@ -4,7 +4,6 @@ import (
    "errors"
    "fmt"
    "math"
-   "slices"

    "github.com/ollama/ollama/model"
)
@@ -76,9 +75,10 @@ func (s JSONState) String() string {
}

type JSONSampler struct {
    curNode *Node
    proc    model.TextProcessor
    stack   []*Node
+   bracketCounter int
}

func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
@@ -88,23 +88,68 @@ func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
        return nil, err
    }
    js := &JSONSampler{
        curNode: startNode,
        proc: proc,
+       stack: []*Node{},
+       bracketCounter: 0,
    }
    return js, nil
}

+ func isTokenSubset(subset, superset []int32) bool {
+     freq1 := make(map[int32]int)
+     freq2 := make(map[int32]int)
+     for _, v := range subset {
+         freq1[v]++
+     }
+     for _, v := range superset {
+         freq2[v]++
+     }
+     isSubset := true
+     for k, count1 := range freq1 {
+         count2 := freq2[k]
+         if count1 > count2 {
+             isSubset = false
+             break
+         }
+     }
+     return isSubset
+ }

func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
    // fmt.Printf("Updating state with token: %v\n", tokenSlice)
    // fmt.Printf("Current state: %s\n", s.curNode.State)
    // fmt.Println("tokenSlice", tokenSlice)
    // todo: account for strings here

+   objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
+   if err != nil {
+       return err
+   }
+   // only move to terminate state if stack is empty
+   if s.curNode.State == StateEnd {
+       fmt.Println("debug: node.State", s.curNode.State)
+       if len(s.stack) > 0 {
+           s.stack = s.stack[:len(s.stack)-1]
+           fmt.Println("popped and cur state", s.curNode.State)
+           return nil
+       }
+       return nil
+   }

    for node, edge := range s.curNode.TransitionEdges {
        for _, validToken := range edge {
-           if slices.Equal(tokenSlice, validToken) {
+           if isTokenSubset(tokenSlice, validToken) {
                s.curNode = node
+               for _, token := range objectTokens {
+                   if isTokenSubset(tokenSlice, token) {
+                       fmt.Println("Appending to stack", s.curNode.State)
+                       s.stack = append(s.stack, s.curNode)
+                   }
+               }
                // fmt.Printf("Transitioned to state: %s\n", node.State)
                return nil
            }
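Reviewer note: the new isTokenSubset helper treats both slices as multisets. It counts how often each id appears and reports whether every id in subset occurs at least as often in superset, so (unlike the old slices.Equal check) ordering no longer matters. A minimal standalone sketch of the same frequency-count idea; the function and variable names here are illustrative, not from the diff:

package main

import "fmt"

// containsAllTokens reports whether every token in sub occurs in super at
// least as many times (multiset containment), mirroring isTokenSubset.
func containsAllTokens(sub, super []int32) bool {
	counts := make(map[int32]int)
	for _, t := range super {
		counts[t]++
	}
	for _, t := range sub {
		counts[t]--
		if counts[t] < 0 {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(containsAllTokens([]int32{90}, []int32{90, 198, 90}))     // true
	fmt.Println(containsAllTokens([]int32{90, 90, 90}, []int32{90, 198})) // false
}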
@@ -120,6 +165,11 @@ func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
        }
    }
    fmt.Println("invalid token ", tokenSlice)
+   dec, err := s.proc.Decode(tokenSlice)
+   if err != nil {
+       return err
+   }
+   fmt.Println("decoded token ", dec)
    return errors.New("invalid token")
}
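Together with the stack field added to JSONSampler, UpdateState now behaves like a small pushdown automaton: it pushes the current node whenever the consumed token matches one of the opening-brace variants, and pops an entry when the machine sits in StateEnd with entries still on the stack. A compact sketch of that push/pop bookkeeping in isolation, assuming a plain depth tracker rather than the sampler's node stack:

package main

import "fmt"

// braceDepth tracks how many JSON objects are currently open; the sampler's
// stack does the same bookkeeping with graph nodes instead of strings.
type braceDepth struct{ stack []string }

func (b *braceDepth) push(state string) { b.stack = append(b.stack, state) }

func (b *braceDepth) pop() {
	if len(b.stack) > 0 {
		b.stack = b.stack[:len(b.stack)-1]
	}
}

func main() {
	var d braceDepth
	d.push("StateInObject") // saw "{"
	d.push("StateInObject") // saw a nested "{"
	d.pop()                 // saw "}" closing the inner object
	fmt.Println("open objects:", len(d.stack)) // open objects: 1
}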
@@ -164,6 +214,24 @@ func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
        }
        return logits, nil
+   case StateInString:
+       penalizeNewlineVariants := []string{"\n", " \"\n"}
+       penalizeNewlineToks, err := ComputeTokenVariants(penalizeNewlineVariants, s.proc)
+       if err != nil {
+           return nil, err
+       }
+       penalizeNewlineToks = append(penalizeNewlineToks, []int32{702})
+       logits, err = s.maskSpecificLogits(logits, penalizeNewlineToks)
+       if err != nil {
+           return nil, err
+       }
+       validStates := getValidStates(s.curNode)
+       logits, err = s.maskLogits(logits, validStates)
+       if err != nil {
+           return nil, err
+       }
+       return logits, nil
    default:
        validStates := getValidStates(s.curNode)
        logits, err = s.maskLogits(logits, validStates)
@@ -205,3 +273,17 @@ func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
    }
    return logits, nil
}
+
+ func (s *JSONSampler) maskSpecificLogits(logits []float64, tokensToMask []token) ([]float64, error) {
+     // fmt.Printf("Masking specific logits: %v\n", tokensToMask)
+     for i := range logits {
+         for _, token := range tokensToMask {
+             for _, chunked := range token {
+                 if int(chunked) == i {
+                     logits[i] = math.NaN()
+                 }
+             }
+         }
+     }
+     return logits, nil
+ }
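maskSpecificLogits leaves the logits slice at full length and overwrites the positions named in tokensToMask with NaN, so token ids keep lining up with indices and a later step can drop or ignore the masked entries. A rough self-contained sketch of that pattern; indexing directly by token id (instead of the diff's scan over every logit) is one possible simplification, and the helper name is illustrative:

package main

import (
	"fmt"
	"math"
)

// maskIndices sets the logits at the given token ids to NaN, leaving the
// slice length unchanged so token ids keep matching positions.
func maskIndices(logits []float64, ids []int32) []float64 {
	for _, id := range ids {
		if int(id) >= 0 && int(id) < len(logits) {
			logits[int(id)] = math.NaN()
		}
	}
	return logits
}

func main() {
	logits := []float64{0.1, 2.3, -0.7, 1.5}
	masked := maskIndices(logits, []int32{1, 3})
	fmt.Println(masked) // [0.1 NaN -0.7 NaN]
}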

View File

@@ -21,39 +21,78 @@ func NewNode(state JSONState) *Node {
}

var (
-   startToken token
-   endToken token
-   stringToken token
-   objectKeyToken token
-   tabToken token
-   spaceToken token
-   newlineToken token
-   newlineSpace token
-   commaToken token
-   commaToken2 token
-   commaToken3 token
-   colonToken token
-   colonToken2 token
+   // startToken token
+   startTokenVariants []token
+   // endToken token
+   // stringToken token
+   // objectKeyToken token
+   tabToken token
+   spaceToken token
+   newlineToken token
+   newlineSpace token
+   // commaToken token
+   // commaToken2 token
+   // commaToken3 token
+   // colonToken token
+   // colonToken2 token
+   colonTokenVariants []token
+   commaTokenVariants []token
+   stringTokenVariants []token
+   endTokenVariants []token
+   objectKeyTokenVariants []token
+   objKeyToColonVariants []token
+   stringToObjectKeyVariants []token
+   stringToCommaVariants []token
+   stringToObjectVariants []token
+   stringEndToObjectEndVariants []token
+   stringEndToCommaVariants []token
)
+ func ComputeTokenVariants(variants []string, proc model.TextProcessor) ([]token, error) {
+     var allTokens token
+     for _, variant := range variants {
+         if t, err := proc.Encode(variant); err == nil {
+             allTokens = append(allTokens, t...)
+         }
+     }
+     if len(allTokens) == 0 {
+         return nil, fmt.Errorf("no valid tokens found for variants")
+     }
+     return []token{allTokens}, nil
+ }

func initTokens(proc model.TextProcessor) error {
    var err error
-   startToken, err = proc.Encode("{")
+   s, err := proc.Decode([]int32{761})
+   fmt.Printf("761 decoded %q\n", s)
+   // Compute start token variants
+   startVariants := []string{"{", " {", "{\n", " {\n"}
+   startTokenVariants, err = ComputeTokenVariants(startVariants, proc)
    if err != nil {
        return err
    }
-   endToken, err = proc.Encode("}")
+   // Compute end token variants
+   endVariants := []string{"}", " }", "}\n", " }\n"}
+   endTokenVariants, err = ComputeTokenVariants(endVariants, proc)
    if err != nil {
        return err
    }
-   stringToken, err = proc.Encode("\"")
+   // Compute string token variants
+   // TODO: removed \n
+   stringVariants := []string{"\"", " \""}
+   stringTokenVariants, err = ComputeTokenVariants(stringVariants, proc)
    if err != nil {
        return err
    }
-   objectKeyToken, err = proc.Encode("\"")
+   stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\",\n"}, proc)
    if err != nil {
        return err
    }
+   // objectKeyTokenVariants = []token{stringTokenVariants[0], stringTokenVariants[1]}
+   objectKeyTokenVariants = stringTokenVariants
+   // Compute whitespace tokens
    tabToken, err = proc.Encode("\t")
    if err != nil {
        return err
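ComputeTokenVariants, as added above, encodes each surface form ("{", " {", "{\n", ...) and concatenates all the resulting ids into one flat token, returned as a single-element []token; combined with the multiset check in UpdateState, any emitted token whose ids appear in that flat list counts as a match for the edge. A toy sketch of that flattening behaviour; the fakeProcessor type and its Encode signature are stand-ins for illustration, not the real model.TextProcessor interface:

package main

import "fmt"

type token []int32

// fakeProcessor stands in for a tokenizer: it maps a few strings to ids.
type fakeProcessor struct{ vocab map[string][]int32 }

func (p fakeProcessor) Encode(s string) ([]int32, error) { return p.vocab[s], nil }

// flattenVariants mirrors ComputeTokenVariants: encode every variant and
// append all ids into a single flat token slice.
func flattenVariants(variants []string, p fakeProcessor) []token {
	var all token
	for _, v := range variants {
		ids, _ := p.Encode(v)
		all = append(all, ids...)
	}
	return []token{all}
}

func main() {
	p := fakeProcessor{vocab: map[string][]int32{"{": {90}, " {": {371}, "{\n": {90, 198}}}
	fmt.Println(flattenVariants([]string{"{", " {", "{\n"}, p)) // [[90 371 90 198]]
}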
@@ -70,29 +109,35 @@ func initTokens(proc model.TextProcessor) error {
    if err != nil {
        return err
    }
-   // TODO: figure out how to encode colon correctly
-   colonToken, err = proc.Encode("\":")
+   // Compute colon variants
+   colonVariants := []string{":"}
+   colonTokenVariants, err = ComputeTokenVariants(colonVariants, proc)
    if err != nil {
        return err
    }
-   fmt.Println("colonToken", colonToken)
-   colonToken2, err = proc.Encode(":")
+   objKeyToColonVariants, err = ComputeTokenVariants([]string{"\":"}, proc)
    if err != nil {
        return err
    }
-   commaToken, err = proc.Encode(",")
+   // Compute comma variants
+   commaVariants := []string{",", " ,", ",\n", "\",", "\", "}
+   commaTokenVariants, err = ComputeTokenVariants(commaVariants, proc)
    if err != nil {
        return err
    }
-   commaToken2, err = proc.Encode("\",")
-   if err != nil {
-       return err
-   }
-   fmt.Println("commaToken2", commaToken2)
-   commaToken3, err = proc.Encode("\",\"")
+   fmt.Printf("commaTokenVariants: %v\n", commaTokenVariants)
+   stringToCommaVariants, err = ComputeTokenVariants([]string{"\",", "\","}, proc)
    if err != nil {
        return err
    }
+   stringEndToCommaVariants, err = ComputeTokenVariants([]string{",", ",\n"}, proc)
+   stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\","}, proc)
+   stringToObjectVariants, err = ComputeTokenVariants([]string{"\",\n"}, proc)
+   stringEndToObjectEndVariants, err = ComputeTokenVariants([]string{"\n"}, proc)
    return nil
}
@@ -106,7 +151,7 @@ func buildStateMachine(proc model.TextProcessor) (*Node, error) {
    objectKeyNode := NewNode(StateInObjectKey)
    objectKeyEndNode := NewNode(StateInObjectKeyEnd)
    stringNode := NewNode(StateInString)
-   intNode := NewNode(StateInInt)
+   // intNode := NewNode(StateInInt)
    commaNode := NewNode(StateInComma)
    colonNode := NewNode(StateInColon)
    stringEndNode := NewNode(StateInStringEnd)
@@ -114,44 +159,59 @@
    terminateNode := NewNode(StateTerminate)

    sentinelToken := token([]int32{-1})
-   intSentinelToken := token([]int32{-2})
+   // intSentinelToken := token([]int32{-2})

-   startNode.TransitionEdges[objectNode] = []token{startToken}
+   // TODO: cleanup connections of rules
+   startNode.TransitionEdges[objectNode] = startTokenVariants
+   objectNode.TransitionEdges[objectKeyNode] = stringTokenVariants
-   objectNode.TransitionEdges[objectNode] = []token{newlineToken}
-   objectNode.TransitionEdges[objectNode] = []token{spaceToken}
-   objectNode.TransitionEdges[objectKeyNode] = []token{stringToken}
    // objectNode.TransitionEdges[objectNode] = []token{newlineToken}
    // objectNode.TransitionEdges[objectNode] = []token{spaceToken}

    objectKeyNode.TransitionEdges[objectKeyNode] = []token{sentinelToken}
-   objectKeyNode.TransitionEdges[colonNode] = []token{colonToken, colonToken2}
    // characterize end of object key
-   objectKeyNode.TransitionEdges[objectKeyEndNode] = []token{stringToken}
+   objectKeyNode.TransitionEdges[objectKeyEndNode] = stringTokenVariants
+   objectKeyNode.TransitionEdges[colonNode] = objKeyToColonVariants
-   objectKeyEndNode.TransitionEdges[colonNode] = []token{colonToken}
+   // TODO: enable this - key -> object
+   // objectKeyNode.TransitionEdges[objectNode] = startTokenVariants
    // objectKeyNode.TransitionEdges[intNode] = []token{sentinelToken}

-   intNode.TransitionEdges[intNode] = []token{intSentinelToken}
+   // intNode.TransitionEdges[intNode] = []token{intSentinelToken}
-   intNode.TransitionEdges[commaNode] = []token{commaToken, commaToken2}
+   // intNode.TransitionEdges[commaNode] = commaTokenVariants
-   intNode.TransitionEdges[terminateNode] = []token{endToken}
+   // TODO: handle
+   // intNode.TransitionEdges[terminateNode] = endTokenVariants

-   commaNode.TransitionEdges[objectKeyNode] = []token{newlineToken}
+   commaNode.TransitionEdges[objectKeyNode] = stringTokenVariants
+   // commaNode.TransitionEdges[objectNode] = startTokenVariants

-   colonNode.TransitionEdges[stringNode] = []token{stringToken}
+   colonNode.TransitionEdges[stringNode] = stringTokenVariants
-   colonNode.TransitionEdges[intNode] = []token{intSentinelToken}
+   // TODO: enable
+   // colonNode.TransitionEdges[intNode] = []token{intSentinelToken}
+   colonNode.TransitionEdges[objectNode] = startTokenVariants

    stringNode.TransitionEdges[stringNode] = []token{sentinelToken}
-   stringNode.TransitionEdges[stringEndNode] = []token{stringToken}
+   stringNode.TransitionEdges[stringEndNode] = stringTokenVariants
-   // "\""," Case
+   // TODO: "\""," Case not accounted for
-   stringNode.TransitionEdges[commaNode] = []token{commaToken2}
+   stringNode.TransitionEdges[commaNode] = stringToCommaVariants
-   // "\"",\"" Case
+   // TODO: "\"",\"" Case not accounted for
-   stringNode.TransitionEdges[objectKeyNode] = []token{commaToken3}
+   stringNode.TransitionEdges[objectNode] = stringToObjectVariants

-   stringEndNode.TransitionEdges[commaNode] = []token{commaToken, commaToken2}
+   stringEndNode.TransitionEdges[commaNode] = stringEndToCommaVariants
-   stringEndNode.TransitionEdges[terminateNode] = []token{endToken}
+   stringEndNode.TransitionEdges[objectNode] = stringToObjectKeyVariants
+   stringEndNode.TransitionEdges[endNode] = stringEndToObjectEndVariants
+   // stringEndNode.TransitionEdges[terminateNode] = endTokenVariants

-   endNode.TransitionEdges[terminateNode] = []token{endToken}
+   // Should be obj end
+   // TODO: handle
+   endNode.TransitionEdges[terminateNode] = []token{}
+   endNode.TransitionEdges[commaNode] = commaTokenVariants

    terminateNode.TransitionEdges[terminateNode] = []token{}

    return startNode, nil
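The wiring above builds a graph in which each Node maps a destination node to the token variants that can trigger the move; Sample masks logits down to the valid destinations and UpdateState follows the first edge whose variant list contains the emitted token. A reduced sketch of that edge-following step on a two-node graph; the Node shape mirrors the diff, while the helper names and token ids are illustrative:

package main

import "fmt"

type token []int32

type Node struct {
	State           string
	TransitionEdges map[*Node][]token
}

func newNode(state string) *Node {
	return &Node{State: state, TransitionEdges: make(map[*Node][]token)}
}

// step returns the destination node whose variant list contains id, or nil.
func step(cur *Node, id int32) *Node {
	for dest, variants := range cur.TransitionEdges {
		for _, v := range variants {
			for _, t := range v {
				if t == id {
					return dest
				}
			}
		}
	}
	return nil
}

func main() {
	start := newNode("StateStart")
	object := newNode("StateInObject")
	start.TransitionEdges[object] = []token{{90, 371}} // ids for "{" variants (illustrative)
	if next := step(start, 371); next != nil {
		fmt.Println("transitioned to", next.State) // transitioned to StateInObject
	}
}

Note that TransitionEdges is a map keyed by destination node, so assigning the same destination twice (as the removed duplicate objectNode self-edges did) silently overwrites the earlier token list.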