checkpoint
This commit is contained in:
parent
a7c8cc06da
commit
6ba557f25b
@ -4,7 +4,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"slices"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
)
|
)
|
||||||
@ -76,9 +75,10 @@ func (s JSONState) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type JSONSampler struct {
|
type JSONSampler struct {
|
||||||
curNode *Node
|
curNode *Node
|
||||||
proc model.TextProcessor
|
proc model.TextProcessor
|
||||||
stack []*Node
|
stack []*Node
|
||||||
|
bracketCounter int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
|
func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
|
||||||
@ -88,23 +88,68 @@ func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
js := &JSONSampler{
|
js := &JSONSampler{
|
||||||
curNode: startNode,
|
curNode: startNode,
|
||||||
proc: proc,
|
proc: proc,
|
||||||
|
stack: []*Node{},
|
||||||
|
bracketCounter: 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
return js, nil
|
return js, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isTokenSubset(subset, superset []int32) bool {
|
||||||
|
freq1 := make(map[int32]int)
|
||||||
|
freq2 := make(map[int32]int)
|
||||||
|
|
||||||
|
for _, v := range subset {
|
||||||
|
freq1[v]++
|
||||||
|
}
|
||||||
|
for _, v := range superset {
|
||||||
|
freq2[v]++
|
||||||
|
}
|
||||||
|
isSubset := true
|
||||||
|
for k, count1 := range freq1 {
|
||||||
|
count2 := freq2[k]
|
||||||
|
if count1 > count2 {
|
||||||
|
isSubset = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return isSubset
|
||||||
|
}
|
||||||
|
|
||||||
func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
|
func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
|
||||||
// fmt.Printf("Updating state with token: %v\n", tokenSlice)
|
// fmt.Printf("Updating state with token: %v\n", tokenSlice)
|
||||||
// fmt.Printf("Current state: %s\n", s.curNode.State)
|
// fmt.Printf("Current state: %s\n", s.curNode.State)
|
||||||
|
|
||||||
// fmt.Println("tokenSlice", tokenSlice)
|
// fmt.Println("tokenSlice", tokenSlice)
|
||||||
// todo: account for strings here
|
// todo: account for strings here
|
||||||
|
objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// only move to terminate state if stack is empty
|
||||||
|
if s.curNode.State == StateEnd {
|
||||||
|
fmt.Println("debug: node.State", s.curNode.State)
|
||||||
|
if len(s.stack) > 0 {
|
||||||
|
s.stack = s.stack[:len(s.stack)-1]
|
||||||
|
fmt.Println("popped and cur state", s.curNode.State)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
for node, edge := range s.curNode.TransitionEdges {
|
for node, edge := range s.curNode.TransitionEdges {
|
||||||
for _, validToken := range edge {
|
for _, validToken := range edge {
|
||||||
if slices.Equal(tokenSlice, validToken) {
|
if isTokenSubset(tokenSlice, validToken) {
|
||||||
s.curNode = node
|
s.curNode = node
|
||||||
|
for _, token := range objectTokens {
|
||||||
|
if isTokenSubset(tokenSlice, token) {
|
||||||
|
fmt.Println("Appending to stack", s.curNode.State)
|
||||||
|
s.stack = append(s.stack, s.curNode)
|
||||||
|
}
|
||||||
|
}
|
||||||
// fmt.Printf("Transitioned to state: %s\n", node.State)
|
// fmt.Printf("Transitioned to state: %s\n", node.State)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -120,6 +165,11 @@ func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
fmt.Println("invalid token ", tokenSlice)
|
fmt.Println("invalid token ", tokenSlice)
|
||||||
|
dec, err := s.proc.Decode(tokenSlice)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Println("decoded token ", dec)
|
||||||
return errors.New("invalid token")
|
return errors.New("invalid token")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -164,6 +214,24 @@ func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
|
|||||||
}
|
}
|
||||||
return logits, nil
|
return logits, nil
|
||||||
|
|
||||||
|
case StateInString:
|
||||||
|
penalizeNewlineVariants := []string{"\n", " \"\n"}
|
||||||
|
penalizeNewlineToks, err := ComputeTokenVariants(penalizeNewlineVariants, s.proc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
penalizeNewlineToks = append(penalizeNewlineToks, []int32{702})
|
||||||
|
logits, err = s.maskSpecificLogits(logits, penalizeNewlineToks)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
validStates := getValidStates(s.curNode)
|
||||||
|
logits, err = s.maskLogits(logits, validStates)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
validStates := getValidStates(s.curNode)
|
validStates := getValidStates(s.curNode)
|
||||||
logits, err = s.maskLogits(logits, validStates)
|
logits, err = s.maskLogits(logits, validStates)
|
||||||
@ -205,3 +273,17 @@ func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float
|
|||||||
}
|
}
|
||||||
return logits, nil
|
return logits, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *JSONSampler) maskSpecificLogits(logits []float64, tokensToMask []token) ([]float64, error) {
|
||||||
|
// fmt.Printf("Masking specific logits: %v\n", tokensToMask)
|
||||||
|
for i := range logits {
|
||||||
|
for _, token := range tokensToMask {
|
||||||
|
for _, chunked := range token {
|
||||||
|
if int(chunked) == i {
|
||||||
|
logits[i] = math.NaN()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
}
|
||||||
|
@ -21,39 +21,78 @@ func NewNode(state JSONState) *Node {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
startToken token
|
// startToken token
|
||||||
endToken token
|
startTokenVariants []token
|
||||||
stringToken token
|
// endToken token
|
||||||
objectKeyToken token
|
// stringToken token
|
||||||
tabToken token
|
// objectKeyToken token
|
||||||
spaceToken token
|
tabToken token
|
||||||
newlineToken token
|
spaceToken token
|
||||||
newlineSpace token
|
newlineToken token
|
||||||
commaToken token
|
newlineSpace token
|
||||||
commaToken2 token
|
// commaToken token
|
||||||
commaToken3 token
|
// commaToken2 token
|
||||||
colonToken token
|
// commaToken3 token
|
||||||
colonToken2 token
|
// colonToken token
|
||||||
|
// colonToken2 token
|
||||||
|
colonTokenVariants []token
|
||||||
|
commaTokenVariants []token
|
||||||
|
stringTokenVariants []token
|
||||||
|
endTokenVariants []token
|
||||||
|
objectKeyTokenVariants []token
|
||||||
|
objKeyToColonVariants []token
|
||||||
|
stringToObjectKeyVariants []token
|
||||||
|
stringToCommaVariants []token
|
||||||
|
stringToObjectVariants []token
|
||||||
|
stringEndToObjectEndVariants []token
|
||||||
|
stringEndToCommaVariants []token
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func ComputeTokenVariants(variants []string, proc model.TextProcessor) ([]token, error) {
|
||||||
|
var allTokens token
|
||||||
|
for _, variant := range variants {
|
||||||
|
if t, err := proc.Encode(variant); err == nil {
|
||||||
|
allTokens = append(allTokens, t...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(allTokens) == 0 {
|
||||||
|
return nil, fmt.Errorf("no valid tokens found for variants")
|
||||||
|
}
|
||||||
|
return []token{allTokens}, nil
|
||||||
|
}
|
||||||
func initTokens(proc model.TextProcessor) error {
|
func initTokens(proc model.TextProcessor) error {
|
||||||
var err error
|
var err error
|
||||||
startToken, err = proc.Encode("{")
|
|
||||||
|
s, err := proc.Decode([]int32{761})
|
||||||
|
fmt.Printf("761 decoded %q\n", s)
|
||||||
|
|
||||||
|
// Compute start token variants
|
||||||
|
startVariants := []string{"{", " {", "{\n", " {\n"}
|
||||||
|
startTokenVariants, err = ComputeTokenVariants(startVariants, proc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
endToken, err = proc.Encode("}")
|
// Compute end token variants
|
||||||
|
endVariants := []string{"}", " }", "}\n", " }\n"}
|
||||||
|
endTokenVariants, err = ComputeTokenVariants(endVariants, proc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stringToken, err = proc.Encode("\"")
|
|
||||||
|
// Compute string token variants
|
||||||
|
// TODO: removed \n
|
||||||
|
stringVariants := []string{"\"", " \""}
|
||||||
|
stringTokenVariants, err = ComputeTokenVariants(stringVariants, proc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
objectKeyToken, err = proc.Encode("\"")
|
stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\",\n"}, proc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
// objectKeyTokenVariants = []token{stringTokenVariants[0], stringTokenVariants[1]}
|
||||||
|
objectKeyTokenVariants = stringTokenVariants
|
||||||
|
// Compute whitespace tokens
|
||||||
tabToken, err = proc.Encode("\t")
|
tabToken, err = proc.Encode("\t")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -70,29 +109,35 @@ func initTokens(proc model.TextProcessor) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// TODO: figure out how to encode colon correctly
|
|
||||||
colonToken, err = proc.Encode("\":")
|
// Compute colon variants
|
||||||
|
colonVariants := []string{":"}
|
||||||
|
colonTokenVariants, err = ComputeTokenVariants(colonVariants, proc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fmt.Println("colonToken", colonToken)
|
objKeyToColonVariants, err = ComputeTokenVariants([]string{"\":"}, proc)
|
||||||
colonToken2, err = proc.Encode(":")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
commaToken, err = proc.Encode(",")
|
|
||||||
|
// Compute comma variants
|
||||||
|
commaVariants := []string{",", " ,", ",\n", "\",", "\", "}
|
||||||
|
commaTokenVariants, err = ComputeTokenVariants(commaVariants, proc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
commaToken2, err = proc.Encode("\",")
|
fmt.Printf("commaTokenVariants: %v\n", commaTokenVariants)
|
||||||
if err != nil {
|
stringToCommaVariants, err = ComputeTokenVariants([]string{"\",", "\","}, proc)
|
||||||
return err
|
|
||||||
}
|
|
||||||
fmt.Println("commaToken2", commaToken2)
|
|
||||||
commaToken3, err = proc.Encode("\",\"")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stringEndToCommaVariants, err = ComputeTokenVariants([]string{",", ",\n"}, proc)
|
||||||
|
stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\","}, proc)
|
||||||
|
stringToObjectVariants, err = ComputeTokenVariants([]string{"\",\n"}, proc)
|
||||||
|
stringEndToObjectEndVariants, err = ComputeTokenVariants([]string{"\n"}, proc)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,7 +151,7 @@ func buildStateMachine(proc model.TextProcessor) (*Node, error) {
|
|||||||
objectKeyNode := NewNode(StateInObjectKey)
|
objectKeyNode := NewNode(StateInObjectKey)
|
||||||
objectKeyEndNode := NewNode(StateInObjectKeyEnd)
|
objectKeyEndNode := NewNode(StateInObjectKeyEnd)
|
||||||
stringNode := NewNode(StateInString)
|
stringNode := NewNode(StateInString)
|
||||||
intNode := NewNode(StateInInt)
|
// intNode := NewNode(StateInInt)
|
||||||
commaNode := NewNode(StateInComma)
|
commaNode := NewNode(StateInComma)
|
||||||
colonNode := NewNode(StateInColon)
|
colonNode := NewNode(StateInColon)
|
||||||
stringEndNode := NewNode(StateInStringEnd)
|
stringEndNode := NewNode(StateInStringEnd)
|
||||||
@ -114,44 +159,59 @@ func buildStateMachine(proc model.TextProcessor) (*Node, error) {
|
|||||||
terminateNode := NewNode(StateTerminate)
|
terminateNode := NewNode(StateTerminate)
|
||||||
|
|
||||||
sentinelToken := token([]int32{-1})
|
sentinelToken := token([]int32{-1})
|
||||||
intSentinelToken := token([]int32{-2})
|
// intSentinelToken := token([]int32{-2})
|
||||||
|
|
||||||
startNode.TransitionEdges[objectNode] = []token{startToken}
|
// TODO: cleanup connections of rules
|
||||||
|
startNode.TransitionEdges[objectNode] = startTokenVariants
|
||||||
|
|
||||||
|
objectNode.TransitionEdges[objectKeyNode] = stringTokenVariants
|
||||||
|
objectNode.TransitionEdges[objectNode] = []token{newlineToken}
|
||||||
|
objectNode.TransitionEdges[objectNode] = []token{spaceToken}
|
||||||
|
|
||||||
objectNode.TransitionEdges[objectKeyNode] = []token{stringToken}
|
|
||||||
// objectNode.TransitionEdges[objectNode] = []token{newlineToken}
|
// objectNode.TransitionEdges[objectNode] = []token{newlineToken}
|
||||||
// objectNode.TransitionEdges[objectNode] = []token{spaceToken}
|
// objectNode.TransitionEdges[objectNode] = []token{spaceToken}
|
||||||
|
|
||||||
objectKeyNode.TransitionEdges[objectKeyNode] = []token{sentinelToken}
|
objectKeyNode.TransitionEdges[objectKeyNode] = []token{sentinelToken}
|
||||||
objectKeyNode.TransitionEdges[colonNode] = []token{colonToken, colonToken2}
|
|
||||||
// characterize end of object key
|
// characterize end of object key
|
||||||
objectKeyNode.TransitionEdges[objectKeyEndNode] = []token{stringToken}
|
objectKeyNode.TransitionEdges[objectKeyEndNode] = stringTokenVariants
|
||||||
|
objectKeyNode.TransitionEdges[colonNode] = objKeyToColonVariants
|
||||||
|
|
||||||
objectKeyEndNode.TransitionEdges[colonNode] = []token{colonToken}
|
// TODO: enable this - key -> object
|
||||||
|
// objectKeyNode.TransitionEdges[objectNode] = startTokenVariants
|
||||||
|
|
||||||
// objectKeyNode.TransitionEdges[intNode] = []token{sentinelToken}
|
// objectKeyNode.TransitionEdges[intNode] = []token{sentinelToken}
|
||||||
|
|
||||||
intNode.TransitionEdges[intNode] = []token{intSentinelToken}
|
// intNode.TransitionEdges[intNode] = []token{intSentinelToken}
|
||||||
intNode.TransitionEdges[commaNode] = []token{commaToken, commaToken2}
|
// intNode.TransitionEdges[commaNode] = commaTokenVariants
|
||||||
intNode.TransitionEdges[terminateNode] = []token{endToken}
|
// TODO: handle
|
||||||
|
// intNode.TransitionEdges[terminateNode] = endTokenVariants
|
||||||
|
|
||||||
commaNode.TransitionEdges[objectKeyNode] = []token{newlineToken}
|
commaNode.TransitionEdges[objectKeyNode] = stringTokenVariants
|
||||||
|
// commaNode.TransitionEdges[objectNode] = startTokenVariants
|
||||||
|
|
||||||
colonNode.TransitionEdges[stringNode] = []token{stringToken}
|
colonNode.TransitionEdges[stringNode] = stringTokenVariants
|
||||||
colonNode.TransitionEdges[intNode] = []token{intSentinelToken}
|
//TODO: enable
|
||||||
|
// colonNode.TransitionEdges[intNode] = []token{intSentinelToken}
|
||||||
|
colonNode.TransitionEdges[objectNode] = startTokenVariants
|
||||||
|
|
||||||
stringNode.TransitionEdges[stringNode] = []token{sentinelToken}
|
stringNode.TransitionEdges[stringNode] = []token{sentinelToken}
|
||||||
stringNode.TransitionEdges[stringEndNode] = []token{stringToken}
|
stringNode.TransitionEdges[stringEndNode] = stringTokenVariants
|
||||||
// "\""," Case
|
// TODO: "\""," Case not accounted for
|
||||||
stringNode.TransitionEdges[commaNode] = []token{commaToken2}
|
stringNode.TransitionEdges[commaNode] = stringToCommaVariants
|
||||||
|
|
||||||
// "\"",\"" Case
|
// TODO: "\"",\"" Case not accounted for
|
||||||
stringNode.TransitionEdges[objectKeyNode] = []token{commaToken3}
|
stringNode.TransitionEdges[objectNode] = stringToObjectVariants
|
||||||
|
|
||||||
stringEndNode.TransitionEdges[commaNode] = []token{commaToken, commaToken2}
|
stringEndNode.TransitionEdges[commaNode] = stringEndToCommaVariants
|
||||||
stringEndNode.TransitionEdges[terminateNode] = []token{endToken}
|
stringEndNode.TransitionEdges[objectNode] = stringToObjectKeyVariants
|
||||||
|
stringEndNode.TransitionEdges[endNode] = stringEndToObjectEndVariants
|
||||||
|
// stringEndNode.TransitionEdges[terminateNode] = endTokenVariants
|
||||||
|
|
||||||
endNode.TransitionEdges[terminateNode] = []token{endToken}
|
// Should be obj end
|
||||||
|
// TODO: handle
|
||||||
|
endNode.TransitionEdges[terminateNode] = []token{}
|
||||||
|
|
||||||
|
endNode.TransitionEdges[commaNode] = commaTokenVariants
|
||||||
|
|
||||||
terminateNode.TransitionEdges[terminateNode] = []token{}
|
terminateNode.TransitionEdges[terminateNode] = []token{}
|
||||||
return startNode, nil
|
return startNode, nil
|
||||||
|
Loading…
x
Reference in New Issue
Block a user