This commit is contained in:
ParthSareen 2025-01-27 16:33:55 -08:00
parent a2a73ce5e0
commit e93db4d20e
3 changed files with 214 additions and 29 deletions

View File

@ -22,12 +22,18 @@ const (
StateInFloat StateInFloat
StateInBool StateInBool
StateInNull StateInNull
StateInArray
StateInColon StateInColon
StateInComma StateInComma
StateInTab StateInTab
StateInSpace StateInSpace
StateInObjSpace
StateInList
StateInListComma
StateListEnd
StateInListEnd
StateInNewline StateInNewline
StateInNumber
StateInNumberEnd
StateInStringEnd StateInStringEnd
StateInObjectKeyEnd StateInObjectKeyEnd
StateTerminate StateTerminate
@ -42,42 +48,54 @@ func (s JSONState) String() string {
return "StateInObject" return "StateInObject"
case StateInObjectKey: case StateInObjectKey:
return "StateInObjectKey" return "StateInObjectKey"
case StateInString:
return "StateInString"
case StateNewline: case StateNewline:
return "StateNewline" return "StateNewline"
case StateTab: case StateTab:
return "StateTab" return "StateTab"
case StateSpace: case StateSpace:
return "StateSpace" return "StateSpace"
case StateInString:
return "StateInString"
case StateInInt: case StateInInt:
return "StateInInt" return "StateInInt"
case StateInFloat: case StateInFloat:
return "StateInFloat" return "StateInFloat"
case StateInColon:
return "StateInColon"
case StateInBool: case StateInBool:
return "StateInBool" return "StateInBool"
case StateInNull: case StateInNull:
return "StateInNull" return "StateInNull"
case StateInArray: case StateInColon:
return "StateInArray" return "StateInColon"
case StateInObjectEnd:
return "StateInObjectEnd"
case StateInComma: case StateInComma:
return "StateInComma" return "StateInComma"
case StateInTab: case StateInTab:
return "StateInTab" return "StateInTab"
case StateInObjectKeyEnd:
return "StateInObjectKeyEnd"
case StateInNewline:
return "StateInNewline"
case StateInSpace: case StateInSpace:
return "StateInSpace" return "StateInSpace"
case StateTerminate: case StateInObjSpace:
return "StateTerminate" return "StateInObjSpace"
case StateInList:
return "StateInList"
case StateInListComma:
return "StateInListComma"
case StateListEnd:
return "StateListEnd"
case StateInListEnd:
return "StateInListEnd"
case StateInNewline:
return "StateInNewline"
case StateInNumber:
return "StateInNumber"
case StateInNumberEnd:
return "StateInNumberEnd"
case StateInStringEnd: case StateInStringEnd:
return "StateInStringEnd" return "StateInStringEnd"
case StateInObjectKeyEnd:
return "StateInObjectKeyEnd"
case StateTerminate:
return "StateTerminate"
case StateInObjectEnd:
return "StateInObjectEnd"
default: default:
return fmt.Sprintf("Unknown state: %d", s) return fmt.Sprintf("Unknown state: %d", s)
} }
@ -264,6 +282,7 @@ func getValidStates(node *Node) []int32 {
func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) { func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
// fmt.Printf("Masking logits with valid states: %v\n", validStates) // fmt.Printf("Masking logits with valid states: %v\n", validStates)
// todo: this can prob be more efficient
for i := range logits { for i := range logits {
isValid := false isValid := false
for _, token := range validStates { for _, token := range validStates {

View File

@ -8,6 +8,15 @@ import (
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','} var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
type PDANode struct { type PDANode struct {
State JSONState State JSONState
TransitionEdges map[rune]*PDANode TransitionEdges map[rune]*PDANode
@ -52,6 +61,9 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
spaceNode := NewPDANode(StateInSpace) spaceNode := NewPDANode(StateInSpace)
stateToNodeMap[StateInSpace] = spaceNode stateToNodeMap[StateInSpace] = spaceNode
spaceObjNode := NewPDANode(StateInObjSpace)
stateToNodeMap[StateInObjSpace] = spaceObjNode
tabNode := NewPDANode(StateInTab) tabNode := NewPDANode(StateInTab)
stateToNodeMap[StateInTab] = tabNode stateToNodeMap[StateInTab] = tabNode
@ -61,7 +73,31 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
stringEndNode := NewPDANode(StateInStringEnd) stringEndNode := NewPDANode(StateInStringEnd)
stateToNodeMap[StateInStringEnd] = stringEndNode stateToNodeMap[StateInStringEnd] = stringEndNode
// terminateNode := NewNode(StateTerminate) listNode := NewPDANode(StateInList)
stateToNodeMap[StateInList] = listNode
listCommaNode := NewPDANode(StateInListComma)
stateToNodeMap[StateInListComma] = listCommaNode
listEndNode := NewPDANode(StateListEnd)
stateToNodeMap[StateListEnd] = listEndNode
numberNode := NewPDANode(StateInNumber)
stateToNodeMap[StateInNumber] = numberNode
boolNode := NewPDANode(StateInBool)
stateToNodeMap[StateInBool] = boolNode
nullNode := NewPDANode(StateInNull)
stateToNodeMap[StateInNull] = nullNode
// Defined with structured outputs only
intNode := NewPDANode(StateInInt)
stateToNodeMap[StateInInt] = intNode
// TODO:
// consider adding a node to just point to values, could be good to compute that
// mask rather than many different nodes
// Connect nodes // Connect nodes
// TODO: if all are single tokens then this can just be connected instead of defining the token // TODO: if all are single tokens then this can just be connected instead of defining the token
@ -69,34 +105,119 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
objNode.TransitionEdges['"'] = objKeyNode objNode.TransitionEdges['"'] = objKeyNode
objNode.TransitionEdges['\n'] = newlineNode objNode.TransitionEdges['\n'] = newlineNode
// objNode.TransitionEdges['\t'] = tabNode
newlineNode.TransitionEdges['"'] = objKeyNode newlineNode.TransitionEdges['"'] = objKeyNode
newlineNode.TransitionEdges['\t'] = tabNode newlineNode.TransitionEdges['\t'] = tabNode
tabNode.TransitionEdges['"'] = objKeyNode tabNode.TransitionEdges['"'] = objKeyNode
// tabNode.TransitionEdges['\t'] = tabNode
spaceNode.TransitionEdges['"'] = stringNode
objKeyNode.TransitionEdges[rune(-1)] = objKeyNode objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
objKeyNode.TransitionEdges['"'] = objKeyEndNode objKeyNode.TransitionEdges['"'] = objKeyEndNode
objKeyNode.TransitionEdges[' '] = spaceNode
// objKeyNode.TransitionEdges['\t'] = tabNode
objKeyEndNode.TransitionEdges[':'] = colonNode objKeyEndNode.TransitionEdges[':'] = colonNode
objEndNode.TransitionEdges[' '] = spaceNode
colonNode.TransitionEdges['"'] = stringNode // where values should be
// this could be combined but the probs might change, we're alr doing a skip ahead
colonNode.TransitionEdges[' '] = spaceNode colonNode.TransitionEdges[' '] = spaceNode
// Leads to a value
spaceNode.TransitionEdges['"'] = stringNode
spaceNode.TransitionEdges['['] = listNode
spaceNode.TransitionEdges['{'] = objNode
for _, r := range validNumberRunes {
spaceNode.TransitionEdges[r] = numberNode
}
for _, r := range validBoolRunes {
spaceNode.TransitionEdges[r] = boolNode
}
for _, r := range validNullRunes {
spaceNode.TransitionEdges[r] = nullNode
}
// Values
// string node
stringNode.TransitionEdges[rune(-1)] = stringNode stringNode.TransitionEdges[rune(-1)] = stringNode
stringNode.TransitionEdges['"'] = stringEndNode stringNode.TransitionEdges['"'] = stringEndNode
stringEndNode.TransitionEdges[','] = commaNode stringEndNode.TransitionEdges[','] = commaNode
stringEndNode.TransitionEdges['}'] = objEndNode stringEndNode.TransitionEdges['}'] = objEndNode
stringEndNode.TransitionEdges[']'] = listEndNode
// TODO: add counters for allowable number of decimals, e, E, etc
// number node
for _, r := range validNumberRunes {
numberNode.TransitionEdges[r] = numberNode
}
numberNode.TransitionEdges[','] = commaNode
numberNode.TransitionEdges['}'] = objEndNode
numberNode.TransitionEdges[']'] = listEndNode
for _, r := range validBoolRunes {
boolNode.TransitionEdges[r] = boolNode
}
// list node
listNode.TransitionEdges[','] = commaNode
listNode.TransitionEdges['"'] = stringNode
// squash states to a value
for _, r := range validNumberRunes {
listNode.TransitionEdges[r] = numberNode
}
for _, r := range validBoolRunes {
listNode.TransitionEdges[r] = boolNode
}
for _, r := range validNullRunes {
listNode.TransitionEdges[r] = nullNode
}
// null node
for _, r := range validNullRunes {
nullNode.TransitionEdges[r] = nullNode
}
nullNode.TransitionEdges[','] = commaNode
nullNode.TransitionEdges['}'] = objEndNode
nullNode.TransitionEdges[']'] = listEndNode
// list comma
// should point to values
listCommaNode.TransitionEdges['"'] = stringNode
listCommaNode.TransitionEdges[' '] = listCommaNode
listCommaNode.TransitionEdges['{'] = objNode
listCommaNode.TransitionEdges['\n'] = newlineNode
for _, r := range validNumberRunes {
listCommaNode.TransitionEdges[r] = numberNode
}
for _, r := range validBoolRunes {
listCommaNode.TransitionEdges[r] = boolNode
}
for _, r := range validNullRunes {
listCommaNode.TransitionEdges[r] = nullNode
}
// bool node
for _, r := range validBoolRunes {
boolNode.TransitionEdges[r] = boolNode
}
boolNode.TransitionEdges['}'] = objEndNode
boolNode.TransitionEdges[']'] = listEndNode
boolNode.TransitionEdges[','] = commaNode
listEndNode.TransitionEdges['}'] = objEndNode
listEndNode.TransitionEdges[','] = commaNode
commaNode.TransitionEdges['{'] = objNode commaNode.TransitionEdges['{'] = objNode
commaNode.TransitionEdges['\n'] = newlineNode commaNode.TransitionEdges['\n'] = newlineNode
commaNode.TransitionEdges['\t'] = tabNode commaNode.TransitionEdges['\t'] = tabNode
commaNode.TransitionEdges['"'] = objKeyNode commaNode.TransitionEdges['"'] = objKeyNode
commaNode.TransitionEdges[' '] = spaceObjNode
spaceObjNode.TransitionEdges['"'] = objKeyNode
return startNode, stateToNodeMap, nil return startNode, stateToNodeMap, nil
} }

View File

@ -3,6 +3,8 @@ package sample
import ( import (
"fmt" "fmt"
"math" "math"
"runtime"
"time"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
) )
@ -13,9 +15,17 @@ type PushdownSampler struct {
proc model.TextProcessor proc model.TextProcessor
stateToNodeMap map[JSONState]*PDANode stateToNodeMap map[JSONState]*PDANode
braceStack []rune braceStack []rune
stateCounter uint32
} }
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
start := time.Now()
var m runtime.MemStats
runtime.ReadMemStats(&m)
before := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
startNode, stateToNodeMap, err := BuildGraph(proc) startNode, stateToNodeMap, err := BuildGraph(proc)
if err != nil { if err != nil {
panic(err) panic(err)
@ -24,6 +34,11 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
if err != nil { if err != nil {
panic(err) panic(err)
} }
runtime.ReadMemStats(&m)
after := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Graph build time = %v\n", time.Since(start))
// for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode { // for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode {
// token, err := proc.Decode([]int32{int32(id)}) // token, err := proc.Decode([]int32{int32(id)})
// if err != nil { // if err != nil {
@ -37,6 +52,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
proc: proc, proc: proc,
stateToNodeMap: stateToNodeMap, stateToNodeMap: stateToNodeMap,
braceStack: []rune{}, braceStack: []rune{},
stateCounter: 0,
} }
} }
@ -69,7 +85,19 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
} }
} }
return logits, nil return logits, nil
// return logits, nil
case StateInComma:
peek := s.braceStack[len(s.braceStack)-1]
if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListComma]
fmt.Println("switching to list comma", s.curNode.State)
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
case StateTerminate: case StateTerminate:
for i := range logits { for i := range logits {
if s.proc.Is(uint32(i), model.SpecialEOS) { if s.proc.Is(uint32(i), model.SpecialEOS) {
@ -80,9 +108,6 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
} }
return logits, nil return logits, nil
// case StateInStringEnd:
// return logits, nil
default: default:
fmt.Println("masking logits current state", s.curNode.State) fmt.Println("masking logits current state", s.curNode.State)
logits, err := s.maskLogits(logits, s.curNode) logits, err := s.maskLogits(logits, s.curNode)
@ -96,7 +121,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
fmt.Println("update state", s.curNode.State) fmt.Println("update state", s.curNode.State)
// TODO: need to handle end states and entering object case // TODO: need to handle end states and entering object case, and list case
if s.curNode.State == StateInObjectEnd { if s.curNode.State == StateInObjectEnd {
fmt.Println("in object end") fmt.Println("in object end")
if len(s.braceStack) > 0 { if len(s.braceStack) > 0 {
@ -111,25 +136,45 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
if err != nil { if err != nil {
return err return err
} }
// TODO: should force closing for all braces
for _, r := range mappedString { for _, r := range mappedString {
if r == rune('{') { if r == rune('{') {
s.braceStack = append(s.braceStack, r) s.braceStack = append(s.braceStack, r)
} }
if r == rune('[') {
s.braceStack = append(s.braceStack, r)
}
if r == rune('}') { if r == rune('}') {
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') { if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') {
return fmt.Errorf("unmatched closing brace") return fmt.Errorf("unmatched closing brace")
} }
s.braceStack = s.braceStack[:len(s.braceStack)-1] s.braceStack = s.braceStack[:len(s.braceStack)-1]
fmt.Println("popping brace stack", s.braceStack)
}
if r == rune(']') {
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('[') {
return fmt.Errorf("unmatched closing brace")
}
s.braceStack = s.braceStack[:len(s.braceStack)-1]
fmt.Println("popping brace stack", s.braceStack)
} }
} }
for _, tokenID := range tokenSlice { for _, tokenID := range tokenSlice {
// transition to the next node // transition to the next node
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID] nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID]
if !ok { if !ok {
return fmt.Errorf("invalid token: %q", mappedString) return fmt.Errorf("invalid token: %q", mappedString)
} }
fmt.Println("transitioning to", nextNode) fmt.Println("transitioning to", nextNodeState)
s.curNode = s.stateToNodeMap[nextNode]
// TODO: add a penalty for staying in the same state too long
if nextNodeState == s.curNode.State {
s.stateCounter++
} else {
s.stateCounter = 0
}
s.curNode = s.stateToNodeMap[nextNodeState]
} }
return nil return nil
} }