wip!
This commit is contained in:
parent
6ba557f25b
commit
a2a73ce5e0
@ -25,10 +25,13 @@ const (
|
|||||||
StateInArray
|
StateInArray
|
||||||
StateInColon
|
StateInColon
|
||||||
StateInComma
|
StateInComma
|
||||||
|
StateInTab
|
||||||
|
StateInSpace
|
||||||
|
StateInNewline
|
||||||
StateInStringEnd
|
StateInStringEnd
|
||||||
StateInObjectKeyEnd
|
StateInObjectKeyEnd
|
||||||
StateTerminate
|
StateTerminate
|
||||||
StateEnd
|
StateInObjectEnd
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s JSONState) String() string {
|
func (s JSONState) String() string {
|
||||||
@ -59,12 +62,18 @@ func (s JSONState) String() string {
|
|||||||
return "StateInNull"
|
return "StateInNull"
|
||||||
case StateInArray:
|
case StateInArray:
|
||||||
return "StateInArray"
|
return "StateInArray"
|
||||||
case StateEnd:
|
case StateInObjectEnd:
|
||||||
return "StateEnd"
|
return "StateInObjectEnd"
|
||||||
case StateInComma:
|
case StateInComma:
|
||||||
return "StateInComma"
|
return "StateInComma"
|
||||||
|
case StateInTab:
|
||||||
|
return "StateInTab"
|
||||||
case StateInObjectKeyEnd:
|
case StateInObjectKeyEnd:
|
||||||
return "StateInObjectKeyEnd"
|
return "StateInObjectKeyEnd"
|
||||||
|
case StateInNewline:
|
||||||
|
return "StateInNewline"
|
||||||
|
case StateInSpace:
|
||||||
|
return "StateInSpace"
|
||||||
case StateTerminate:
|
case StateTerminate:
|
||||||
return "StateTerminate"
|
return "StateTerminate"
|
||||||
case StateInStringEnd:
|
case StateInStringEnd:
|
||||||
@ -124,13 +133,14 @@ func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
|
|||||||
|
|
||||||
// fmt.Println("tokenSlice", tokenSlice)
|
// fmt.Println("tokenSlice", tokenSlice)
|
||||||
// todo: account for strings here
|
// todo: account for strings here
|
||||||
|
|
||||||
objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
|
objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// only move to terminate state if stack is empty
|
// only move to terminate state if stack is empty
|
||||||
if s.curNode.State == StateEnd {
|
if s.curNode.State == StateInObjectEnd {
|
||||||
fmt.Println("debug: node.State", s.curNode.State)
|
fmt.Println("debug: node.State", s.curNode.State)
|
||||||
if len(s.stack) > 0 {
|
if len(s.stack) > 0 {
|
||||||
s.stack = s.stack[:len(s.stack)-1]
|
s.stack = s.stack[:len(s.stack)-1]
|
||||||
|
175
sample/pushdown_automata.go
Normal file
175
sample/pushdown_automata.go
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
|
||||||
|
|
||||||
|
type PDANode struct {
|
||||||
|
State JSONState
|
||||||
|
TransitionEdges map[rune]*PDANode
|
||||||
|
MaskTokenIDToNode map[int32]JSONState
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPDANode(state JSONState) *PDANode {
|
||||||
|
return &PDANode{
|
||||||
|
State: state,
|
||||||
|
TransitionEdges: make(map[rune]*PDANode),
|
||||||
|
MaskTokenIDToNode: make(map[int32]JSONState),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
|
||||||
|
stateToNodeMap := make(map[JSONState]*PDANode)
|
||||||
|
|
||||||
|
startNode := NewPDANode(StateStart)
|
||||||
|
stateToNodeMap[StateStart] = startNode
|
||||||
|
|
||||||
|
objNode := NewPDANode(StateInObject)
|
||||||
|
stateToNodeMap[StateInObject] = objNode
|
||||||
|
|
||||||
|
objEndNode := NewPDANode(StateInObjectEnd)
|
||||||
|
stateToNodeMap[StateInObjectEnd] = objEndNode
|
||||||
|
|
||||||
|
objKeyNode := NewPDANode(StateInObjectKey)
|
||||||
|
stateToNodeMap[StateInObjectKey] = objKeyNode
|
||||||
|
|
||||||
|
objKeyEndNode := NewPDANode(StateInObjectKeyEnd)
|
||||||
|
stateToNodeMap[StateInObjectKeyEnd] = objKeyEndNode
|
||||||
|
|
||||||
|
colonNode := NewPDANode(StateInColon)
|
||||||
|
stateToNodeMap[StateInColon] = colonNode
|
||||||
|
|
||||||
|
commaNode := NewPDANode(StateInComma)
|
||||||
|
stateToNodeMap[StateInComma] = commaNode
|
||||||
|
|
||||||
|
newlineNode := NewPDANode(StateInNewline)
|
||||||
|
stateToNodeMap[StateInNewline] = newlineNode
|
||||||
|
|
||||||
|
spaceNode := NewPDANode(StateInSpace)
|
||||||
|
stateToNodeMap[StateInSpace] = spaceNode
|
||||||
|
|
||||||
|
tabNode := NewPDANode(StateInTab)
|
||||||
|
stateToNodeMap[StateInTab] = tabNode
|
||||||
|
|
||||||
|
stringNode := NewPDANode(StateInString)
|
||||||
|
stateToNodeMap[StateInString] = stringNode
|
||||||
|
|
||||||
|
stringEndNode := NewPDANode(StateInStringEnd)
|
||||||
|
stateToNodeMap[StateInStringEnd] = stringEndNode
|
||||||
|
|
||||||
|
// terminateNode := NewNode(StateTerminate)
|
||||||
|
|
||||||
|
// Connect nodes
|
||||||
|
// TODO: if all are single tokens then this can just be connected instead of defining the token
|
||||||
|
startNode.TransitionEdges['{'] = objNode
|
||||||
|
|
||||||
|
objNode.TransitionEdges['"'] = objKeyNode
|
||||||
|
objNode.TransitionEdges['\n'] = newlineNode
|
||||||
|
|
||||||
|
newlineNode.TransitionEdges['"'] = objKeyNode
|
||||||
|
newlineNode.TransitionEdges['\t'] = tabNode
|
||||||
|
|
||||||
|
tabNode.TransitionEdges['"'] = objKeyNode
|
||||||
|
|
||||||
|
spaceNode.TransitionEdges['"'] = stringNode
|
||||||
|
|
||||||
|
objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
|
||||||
|
objKeyNode.TransitionEdges['"'] = objKeyEndNode
|
||||||
|
objKeyNode.TransitionEdges[' '] = spaceNode
|
||||||
|
// objKeyNode.TransitionEdges['\t'] = tabNode
|
||||||
|
|
||||||
|
objKeyEndNode.TransitionEdges[':'] = colonNode
|
||||||
|
|
||||||
|
colonNode.TransitionEdges['"'] = stringNode
|
||||||
|
colonNode.TransitionEdges[' '] = spaceNode
|
||||||
|
|
||||||
|
stringNode.TransitionEdges[rune(-1)] = stringNode
|
||||||
|
stringNode.TransitionEdges['"'] = stringEndNode
|
||||||
|
|
||||||
|
stringEndNode.TransitionEdges[','] = commaNode
|
||||||
|
stringEndNode.TransitionEdges['}'] = objEndNode
|
||||||
|
|
||||||
|
commaNode.TransitionEdges['{'] = objNode
|
||||||
|
commaNode.TransitionEdges['\n'] = newlineNode
|
||||||
|
commaNode.TransitionEdges['\t'] = tabNode
|
||||||
|
commaNode.TransitionEdges['"'] = objKeyNode
|
||||||
|
|
||||||
|
return startNode, stateToNodeMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
|
||||||
|
|
||||||
|
vocab := proc.GetVocabulary()
|
||||||
|
|
||||||
|
decodedToks := make([]string, len(vocab.Values))
|
||||||
|
for i := range vocab.Values {
|
||||||
|
token, err := proc.Decode([]int32{int32(i)})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
decodedToks[i] = token
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
for _, node := range stateToNodeMap {
|
||||||
|
for i := range vocab.Values {
|
||||||
|
token := decodedToks[i]
|
||||||
|
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||||
|
if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
valid := true
|
||||||
|
curNode := node
|
||||||
|
consumedSpecialRunes := make(map[rune]bool)
|
||||||
|
for _, r := range token {
|
||||||
|
valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if valid {
|
||||||
|
node.MaskTokenIDToNode[int32(i)] = curNode.State
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
|
||||||
|
if consumedSpecialRunes[r] {
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
specialRune := slices.Contains(stringInvalidRunes, r)
|
||||||
|
if specialRune {
|
||||||
|
if curNode.State == StateInString || curNode.State == StateInObjectKey {
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for specific rune transition
|
||||||
|
if nextNode, ok := curNode.TransitionEdges[r]; ok {
|
||||||
|
if specialRune {
|
||||||
|
if curNode.State == nextNode.State {
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
// fmt.Println("special rune", r, "consumed")
|
||||||
|
consumedSpecialRunes[r] = true
|
||||||
|
}
|
||||||
|
return true, nextNode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for sentinel value - if present, any rune is valid
|
||||||
|
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
|
||||||
|
return true, nextNode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
147
sample/pushdown_runner.go
Normal file
147
sample/pushdown_runner.go
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PushdownSampler struct {
|
||||||
|
// stateful
|
||||||
|
curNode *PDANode
|
||||||
|
proc model.TextProcessor
|
||||||
|
stateToNodeMap map[JSONState]*PDANode
|
||||||
|
braceStack []rune
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||||
|
startNode, stateToNodeMap, err := BuildGraph(proc)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
err = PreComputeValidStates(stateToNodeMap, proc)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
// for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode {
|
||||||
|
// token, err := proc.Decode([]int32{int32(id)})
|
||||||
|
// if err != nil {
|
||||||
|
// panic(err)
|
||||||
|
// }
|
||||||
|
// fmt.Println("id", id, "node", node, "token", token)
|
||||||
|
// }
|
||||||
|
// time.Sleep(10 * time.Second)
|
||||||
|
return &PushdownSampler{
|
||||||
|
curNode: startNode,
|
||||||
|
proc: proc,
|
||||||
|
stateToNodeMap: stateToNodeMap,
|
||||||
|
braceStack: []rune{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||||
|
fmt.Println("sample:", s.curNode.State)
|
||||||
|
|
||||||
|
switch s.curNode.State {
|
||||||
|
case StateInObjectEnd:
|
||||||
|
// force finish if no braces left
|
||||||
|
if len(s.braceStack) == 0 {
|
||||||
|
s.curNode = NewPDANode(StateTerminate)
|
||||||
|
for i := range logits {
|
||||||
|
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||||
|
logits[i] = 1.0
|
||||||
|
} else {
|
||||||
|
logits[i] = math.NaN()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
}
|
||||||
|
valid, err := s.proc.Encode("}")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for i := range logits {
|
||||||
|
for _, token := range valid {
|
||||||
|
if i != int(token) {
|
||||||
|
logits[i] = math.NaN()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
// return logits, nil
|
||||||
|
case StateTerminate:
|
||||||
|
for i := range logits {
|
||||||
|
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||||
|
logits[i] = 1.0
|
||||||
|
} else {
|
||||||
|
logits[i] = math.NaN()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
|
||||||
|
// case StateInStringEnd:
|
||||||
|
|
||||||
|
// return logits, nil
|
||||||
|
default:
|
||||||
|
fmt.Println("masking logits current state", s.curNode.State)
|
||||||
|
logits, err := s.maskLogits(logits, s.curNode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||||
|
fmt.Println("update state", s.curNode.State)
|
||||||
|
|
||||||
|
// TODO: need to handle end states and entering object case
|
||||||
|
if s.curNode.State == StateInObjectEnd {
|
||||||
|
fmt.Println("in object end")
|
||||||
|
if len(s.braceStack) > 0 {
|
||||||
|
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.curNode = NewPDANode(StateTerminate)
|
||||||
|
// TODO: return here?
|
||||||
|
}
|
||||||
|
// need this cause there could be multiple transitions
|
||||||
|
mappedString, err := s.proc.Decode(tokenSlice)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, r := range mappedString {
|
||||||
|
if r == rune('{') {
|
||||||
|
s.braceStack = append(s.braceStack, r)
|
||||||
|
}
|
||||||
|
if r == rune('}') {
|
||||||
|
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') {
|
||||||
|
return fmt.Errorf("unmatched closing brace")
|
||||||
|
}
|
||||||
|
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, tokenID := range tokenSlice {
|
||||||
|
// transition to the next node
|
||||||
|
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("invalid token: %q", mappedString)
|
||||||
|
}
|
||||||
|
fmt.Println("transitioning to", nextNode)
|
||||||
|
s.curNode = s.stateToNodeMap[nextNode]
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
|
||||||
|
for i := range logits {
|
||||||
|
_, exists := node.MaskTokenIDToNode[int32(i)]
|
||||||
|
if !exists {
|
||||||
|
logits[i] = math.NaN()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: add penalties for string \n stuff
|
Loading…
x
Reference in New Issue
Block a user