WIP
This commit is contained in:
parent
a2a73ce5e0
commit
e93db4d20e
@ -22,12 +22,18 @@ const (
|
|||||||
StateInFloat
|
StateInFloat
|
||||||
StateInBool
|
StateInBool
|
||||||
StateInNull
|
StateInNull
|
||||||
StateInArray
|
|
||||||
StateInColon
|
StateInColon
|
||||||
StateInComma
|
StateInComma
|
||||||
StateInTab
|
StateInTab
|
||||||
StateInSpace
|
StateInSpace
|
||||||
|
StateInObjSpace
|
||||||
|
StateInList
|
||||||
|
StateInListComma
|
||||||
|
StateListEnd
|
||||||
|
StateInListEnd
|
||||||
StateInNewline
|
StateInNewline
|
||||||
|
StateInNumber
|
||||||
|
StateInNumberEnd
|
||||||
StateInStringEnd
|
StateInStringEnd
|
||||||
StateInObjectKeyEnd
|
StateInObjectKeyEnd
|
||||||
StateTerminate
|
StateTerminate
|
||||||
@ -42,42 +48,54 @@ func (s JSONState) String() string {
|
|||||||
return "StateInObject"
|
return "StateInObject"
|
||||||
case StateInObjectKey:
|
case StateInObjectKey:
|
||||||
return "StateInObjectKey"
|
return "StateInObjectKey"
|
||||||
case StateInString:
|
|
||||||
return "StateInString"
|
|
||||||
case StateNewline:
|
case StateNewline:
|
||||||
return "StateNewline"
|
return "StateNewline"
|
||||||
case StateTab:
|
case StateTab:
|
||||||
return "StateTab"
|
return "StateTab"
|
||||||
case StateSpace:
|
case StateSpace:
|
||||||
return "StateSpace"
|
return "StateSpace"
|
||||||
|
case StateInString:
|
||||||
|
return "StateInString"
|
||||||
case StateInInt:
|
case StateInInt:
|
||||||
return "StateInInt"
|
return "StateInInt"
|
||||||
case StateInFloat:
|
case StateInFloat:
|
||||||
return "StateInFloat"
|
return "StateInFloat"
|
||||||
case StateInColon:
|
|
||||||
return "StateInColon"
|
|
||||||
case StateInBool:
|
case StateInBool:
|
||||||
return "StateInBool"
|
return "StateInBool"
|
||||||
case StateInNull:
|
case StateInNull:
|
||||||
return "StateInNull"
|
return "StateInNull"
|
||||||
case StateInArray:
|
case StateInColon:
|
||||||
return "StateInArray"
|
return "StateInColon"
|
||||||
case StateInObjectEnd:
|
|
||||||
return "StateInObjectEnd"
|
|
||||||
case StateInComma:
|
case StateInComma:
|
||||||
return "StateInComma"
|
return "StateInComma"
|
||||||
case StateInTab:
|
case StateInTab:
|
||||||
return "StateInTab"
|
return "StateInTab"
|
||||||
case StateInObjectKeyEnd:
|
|
||||||
return "StateInObjectKeyEnd"
|
|
||||||
case StateInNewline:
|
|
||||||
return "StateInNewline"
|
|
||||||
case StateInSpace:
|
case StateInSpace:
|
||||||
return "StateInSpace"
|
return "StateInSpace"
|
||||||
case StateTerminate:
|
case StateInObjSpace:
|
||||||
return "StateTerminate"
|
return "StateInObjSpace"
|
||||||
|
case StateInList:
|
||||||
|
return "StateInList"
|
||||||
|
case StateInListComma:
|
||||||
|
return "StateInListComma"
|
||||||
|
case StateListEnd:
|
||||||
|
return "StateListEnd"
|
||||||
|
case StateInListEnd:
|
||||||
|
return "StateInListEnd"
|
||||||
|
case StateInNewline:
|
||||||
|
return "StateInNewline"
|
||||||
|
case StateInNumber:
|
||||||
|
return "StateInNumber"
|
||||||
|
case StateInNumberEnd:
|
||||||
|
return "StateInNumberEnd"
|
||||||
case StateInStringEnd:
|
case StateInStringEnd:
|
||||||
return "StateInStringEnd"
|
return "StateInStringEnd"
|
||||||
|
case StateInObjectKeyEnd:
|
||||||
|
return "StateInObjectKeyEnd"
|
||||||
|
case StateTerminate:
|
||||||
|
return "StateTerminate"
|
||||||
|
case StateInObjectEnd:
|
||||||
|
return "StateInObjectEnd"
|
||||||
default:
|
default:
|
||||||
return fmt.Sprintf("Unknown state: %d", s)
|
return fmt.Sprintf("Unknown state: %d", s)
|
||||||
}
|
}
|
||||||
@ -264,6 +282,7 @@ func getValidStates(node *Node) []int32 {
|
|||||||
|
|
||||||
func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
|
func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
|
||||||
// fmt.Printf("Masking logits with valid states: %v\n", validStates)
|
// fmt.Printf("Masking logits with valid states: %v\n", validStates)
|
||||||
|
// todo: this can prob be more efficient
|
||||||
for i := range logits {
|
for i := range logits {
|
||||||
isValid := false
|
isValid := false
|
||||||
for _, token := range validStates {
|
for _, token := range validStates {
|
||||||
|
@ -8,6 +8,15 @@ import (
|
|||||||
|
|
||||||
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
|
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
|
||||||
|
|
||||||
|
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
||||||
|
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
||||||
|
|
||||||
|
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
|
||||||
|
|
||||||
|
var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
|
||||||
|
|
||||||
|
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
|
||||||
|
|
||||||
type PDANode struct {
|
type PDANode struct {
|
||||||
State JSONState
|
State JSONState
|
||||||
TransitionEdges map[rune]*PDANode
|
TransitionEdges map[rune]*PDANode
|
||||||
@ -52,6 +61,9 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
|||||||
spaceNode := NewPDANode(StateInSpace)
|
spaceNode := NewPDANode(StateInSpace)
|
||||||
stateToNodeMap[StateInSpace] = spaceNode
|
stateToNodeMap[StateInSpace] = spaceNode
|
||||||
|
|
||||||
|
spaceObjNode := NewPDANode(StateInObjSpace)
|
||||||
|
stateToNodeMap[StateInObjSpace] = spaceObjNode
|
||||||
|
|
||||||
tabNode := NewPDANode(StateInTab)
|
tabNode := NewPDANode(StateInTab)
|
||||||
stateToNodeMap[StateInTab] = tabNode
|
stateToNodeMap[StateInTab] = tabNode
|
||||||
|
|
||||||
@ -61,7 +73,31 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
|||||||
stringEndNode := NewPDANode(StateInStringEnd)
|
stringEndNode := NewPDANode(StateInStringEnd)
|
||||||
stateToNodeMap[StateInStringEnd] = stringEndNode
|
stateToNodeMap[StateInStringEnd] = stringEndNode
|
||||||
|
|
||||||
// terminateNode := NewNode(StateTerminate)
|
listNode := NewPDANode(StateInList)
|
||||||
|
stateToNodeMap[StateInList] = listNode
|
||||||
|
|
||||||
|
listCommaNode := NewPDANode(StateInListComma)
|
||||||
|
stateToNodeMap[StateInListComma] = listCommaNode
|
||||||
|
|
||||||
|
listEndNode := NewPDANode(StateListEnd)
|
||||||
|
stateToNodeMap[StateListEnd] = listEndNode
|
||||||
|
|
||||||
|
numberNode := NewPDANode(StateInNumber)
|
||||||
|
stateToNodeMap[StateInNumber] = numberNode
|
||||||
|
|
||||||
|
boolNode := NewPDANode(StateInBool)
|
||||||
|
stateToNodeMap[StateInBool] = boolNode
|
||||||
|
|
||||||
|
nullNode := NewPDANode(StateInNull)
|
||||||
|
stateToNodeMap[StateInNull] = nullNode
|
||||||
|
|
||||||
|
// Defined with structured outputs only
|
||||||
|
intNode := NewPDANode(StateInInt)
|
||||||
|
stateToNodeMap[StateInInt] = intNode
|
||||||
|
|
||||||
|
// TODO:
|
||||||
|
// consider adding a node to just point to values, could be good to compute that
|
||||||
|
// mask rather than many different nodes
|
||||||
|
|
||||||
// Connect nodes
|
// Connect nodes
|
||||||
// TODO: if all are single tokens then this can just be connected instead of defining the token
|
// TODO: if all are single tokens then this can just be connected instead of defining the token
|
||||||
@ -69,34 +105,119 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
|||||||
|
|
||||||
objNode.TransitionEdges['"'] = objKeyNode
|
objNode.TransitionEdges['"'] = objKeyNode
|
||||||
objNode.TransitionEdges['\n'] = newlineNode
|
objNode.TransitionEdges['\n'] = newlineNode
|
||||||
|
// objNode.TransitionEdges['\t'] = tabNode
|
||||||
|
|
||||||
newlineNode.TransitionEdges['"'] = objKeyNode
|
newlineNode.TransitionEdges['"'] = objKeyNode
|
||||||
newlineNode.TransitionEdges['\t'] = tabNode
|
newlineNode.TransitionEdges['\t'] = tabNode
|
||||||
|
|
||||||
tabNode.TransitionEdges['"'] = objKeyNode
|
tabNode.TransitionEdges['"'] = objKeyNode
|
||||||
|
// tabNode.TransitionEdges['\t'] = tabNode
|
||||||
spaceNode.TransitionEdges['"'] = stringNode
|
|
||||||
|
|
||||||
objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
|
objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
|
||||||
objKeyNode.TransitionEdges['"'] = objKeyEndNode
|
objKeyNode.TransitionEdges['"'] = objKeyEndNode
|
||||||
objKeyNode.TransitionEdges[' '] = spaceNode
|
|
||||||
// objKeyNode.TransitionEdges['\t'] = tabNode
|
|
||||||
|
|
||||||
objKeyEndNode.TransitionEdges[':'] = colonNode
|
objKeyEndNode.TransitionEdges[':'] = colonNode
|
||||||
|
objEndNode.TransitionEdges[' '] = spaceNode
|
||||||
|
|
||||||
colonNode.TransitionEdges['"'] = stringNode
|
// where values should be
|
||||||
|
// this could be combined but the probs might change, we're alr doing a skip ahead
|
||||||
colonNode.TransitionEdges[' '] = spaceNode
|
colonNode.TransitionEdges[' '] = spaceNode
|
||||||
|
|
||||||
|
// Leads to a value
|
||||||
|
spaceNode.TransitionEdges['"'] = stringNode
|
||||||
|
spaceNode.TransitionEdges['['] = listNode
|
||||||
|
spaceNode.TransitionEdges['{'] = objNode
|
||||||
|
|
||||||
|
for _, r := range validNumberRunes {
|
||||||
|
spaceNode.TransitionEdges[r] = numberNode
|
||||||
|
}
|
||||||
|
for _, r := range validBoolRunes {
|
||||||
|
spaceNode.TransitionEdges[r] = boolNode
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range validNullRunes {
|
||||||
|
spaceNode.TransitionEdges[r] = nullNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// Values
|
||||||
|
// string node
|
||||||
stringNode.TransitionEdges[rune(-1)] = stringNode
|
stringNode.TransitionEdges[rune(-1)] = stringNode
|
||||||
stringNode.TransitionEdges['"'] = stringEndNode
|
stringNode.TransitionEdges['"'] = stringEndNode
|
||||||
|
|
||||||
stringEndNode.TransitionEdges[','] = commaNode
|
stringEndNode.TransitionEdges[','] = commaNode
|
||||||
stringEndNode.TransitionEdges['}'] = objEndNode
|
stringEndNode.TransitionEdges['}'] = objEndNode
|
||||||
|
stringEndNode.TransitionEdges[']'] = listEndNode
|
||||||
|
|
||||||
|
// TODO: add counters for allowable number of decimals, e, E, etc
|
||||||
|
// number node
|
||||||
|
for _, r := range validNumberRunes {
|
||||||
|
numberNode.TransitionEdges[r] = numberNode
|
||||||
|
}
|
||||||
|
numberNode.TransitionEdges[','] = commaNode
|
||||||
|
numberNode.TransitionEdges['}'] = objEndNode
|
||||||
|
numberNode.TransitionEdges[']'] = listEndNode
|
||||||
|
|
||||||
|
for _, r := range validBoolRunes {
|
||||||
|
boolNode.TransitionEdges[r] = boolNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// list node
|
||||||
|
listNode.TransitionEdges[','] = commaNode
|
||||||
|
listNode.TransitionEdges['"'] = stringNode
|
||||||
|
// squash states to a value
|
||||||
|
for _, r := range validNumberRunes {
|
||||||
|
listNode.TransitionEdges[r] = numberNode
|
||||||
|
}
|
||||||
|
for _, r := range validBoolRunes {
|
||||||
|
listNode.TransitionEdges[r] = boolNode
|
||||||
|
}
|
||||||
|
for _, r := range validNullRunes {
|
||||||
|
listNode.TransitionEdges[r] = nullNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// null node
|
||||||
|
for _, r := range validNullRunes {
|
||||||
|
nullNode.TransitionEdges[r] = nullNode
|
||||||
|
}
|
||||||
|
nullNode.TransitionEdges[','] = commaNode
|
||||||
|
nullNode.TransitionEdges['}'] = objEndNode
|
||||||
|
nullNode.TransitionEdges[']'] = listEndNode
|
||||||
|
|
||||||
|
// list comma
|
||||||
|
// should point to values
|
||||||
|
listCommaNode.TransitionEdges['"'] = stringNode
|
||||||
|
listCommaNode.TransitionEdges[' '] = listCommaNode
|
||||||
|
listCommaNode.TransitionEdges['{'] = objNode
|
||||||
|
listCommaNode.TransitionEdges['\n'] = newlineNode
|
||||||
|
|
||||||
|
for _, r := range validNumberRunes {
|
||||||
|
listCommaNode.TransitionEdges[r] = numberNode
|
||||||
|
}
|
||||||
|
for _, r := range validBoolRunes {
|
||||||
|
listCommaNode.TransitionEdges[r] = boolNode
|
||||||
|
}
|
||||||
|
for _, r := range validNullRunes {
|
||||||
|
listCommaNode.TransitionEdges[r] = nullNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// bool node
|
||||||
|
for _, r := range validBoolRunes {
|
||||||
|
boolNode.TransitionEdges[r] = boolNode
|
||||||
|
}
|
||||||
|
boolNode.TransitionEdges['}'] = objEndNode
|
||||||
|
boolNode.TransitionEdges[']'] = listEndNode
|
||||||
|
boolNode.TransitionEdges[','] = commaNode
|
||||||
|
|
||||||
|
listEndNode.TransitionEdges['}'] = objEndNode
|
||||||
|
listEndNode.TransitionEdges[','] = commaNode
|
||||||
|
|
||||||
commaNode.TransitionEdges['{'] = objNode
|
commaNode.TransitionEdges['{'] = objNode
|
||||||
commaNode.TransitionEdges['\n'] = newlineNode
|
commaNode.TransitionEdges['\n'] = newlineNode
|
||||||
commaNode.TransitionEdges['\t'] = tabNode
|
commaNode.TransitionEdges['\t'] = tabNode
|
||||||
commaNode.TransitionEdges['"'] = objKeyNode
|
commaNode.TransitionEdges['"'] = objKeyNode
|
||||||
|
commaNode.TransitionEdges[' '] = spaceObjNode
|
||||||
|
|
||||||
|
spaceObjNode.TransitionEdges['"'] = objKeyNode
|
||||||
|
|
||||||
return startNode, stateToNodeMap, nil
|
return startNode, stateToNodeMap, nil
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,8 @@ package sample
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
"runtime"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
)
|
)
|
||||||
@ -13,9 +15,17 @@ type PushdownSampler struct {
|
|||||||
proc model.TextProcessor
|
proc model.TextProcessor
|
||||||
stateToNodeMap map[JSONState]*PDANode
|
stateToNodeMap map[JSONState]*PDANode
|
||||||
braceStack []rune
|
braceStack []rune
|
||||||
|
stateCounter uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
var m runtime.MemStats
|
||||||
|
runtime.ReadMemStats(&m)
|
||||||
|
before := m.Alloc
|
||||||
|
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
|
||||||
|
|
||||||
startNode, stateToNodeMap, err := BuildGraph(proc)
|
startNode, stateToNodeMap, err := BuildGraph(proc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
@ -24,6 +34,11 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
runtime.ReadMemStats(&m)
|
||||||
|
after := m.Alloc
|
||||||
|
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
|
||||||
|
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||||
|
fmt.Printf("Graph build time = %v\n", time.Since(start))
|
||||||
// for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode {
|
// for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode {
|
||||||
// token, err := proc.Decode([]int32{int32(id)})
|
// token, err := proc.Decode([]int32{int32(id)})
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
@ -37,6 +52,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
|||||||
proc: proc,
|
proc: proc,
|
||||||
stateToNodeMap: stateToNodeMap,
|
stateToNodeMap: stateToNodeMap,
|
||||||
braceStack: []rune{},
|
braceStack: []rune{},
|
||||||
|
stateCounter: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,7 +85,19 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
return logits, nil
|
return logits, nil
|
||||||
// return logits, nil
|
|
||||||
|
case StateInComma:
|
||||||
|
peek := s.braceStack[len(s.braceStack)-1]
|
||||||
|
if peek == rune('[') {
|
||||||
|
s.curNode = s.stateToNodeMap[StateInListComma]
|
||||||
|
fmt.Println("switching to list comma", s.curNode.State)
|
||||||
|
}
|
||||||
|
logits, err := s.maskLogits(logits, s.curNode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
|
||||||
case StateTerminate:
|
case StateTerminate:
|
||||||
for i := range logits {
|
for i := range logits {
|
||||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||||
@ -80,9 +108,6 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
|||||||
}
|
}
|
||||||
return logits, nil
|
return logits, nil
|
||||||
|
|
||||||
// case StateInStringEnd:
|
|
||||||
|
|
||||||
// return logits, nil
|
|
||||||
default:
|
default:
|
||||||
fmt.Println("masking logits current state", s.curNode.State)
|
fmt.Println("masking logits current state", s.curNode.State)
|
||||||
logits, err := s.maskLogits(logits, s.curNode)
|
logits, err := s.maskLogits(logits, s.curNode)
|
||||||
@ -96,7 +121,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
|||||||
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||||
fmt.Println("update state", s.curNode.State)
|
fmt.Println("update state", s.curNode.State)
|
||||||
|
|
||||||
// TODO: need to handle end states and entering object case
|
// TODO: need to handle end states and entering object case, and list case
|
||||||
if s.curNode.State == StateInObjectEnd {
|
if s.curNode.State == StateInObjectEnd {
|
||||||
fmt.Println("in object end")
|
fmt.Println("in object end")
|
||||||
if len(s.braceStack) > 0 {
|
if len(s.braceStack) > 0 {
|
||||||
@ -111,25 +136,45 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
// TODO: should force closing for all braces
|
||||||
for _, r := range mappedString {
|
for _, r := range mappedString {
|
||||||
if r == rune('{') {
|
if r == rune('{') {
|
||||||
s.braceStack = append(s.braceStack, r)
|
s.braceStack = append(s.braceStack, r)
|
||||||
}
|
}
|
||||||
|
if r == rune('[') {
|
||||||
|
s.braceStack = append(s.braceStack, r)
|
||||||
|
}
|
||||||
if r == rune('}') {
|
if r == rune('}') {
|
||||||
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') {
|
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') {
|
||||||
return fmt.Errorf("unmatched closing brace")
|
return fmt.Errorf("unmatched closing brace")
|
||||||
}
|
}
|
||||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||||
|
fmt.Println("popping brace stack", s.braceStack)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r == rune(']') {
|
||||||
|
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('[') {
|
||||||
|
return fmt.Errorf("unmatched closing brace")
|
||||||
|
}
|
||||||
|
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||||
|
fmt.Println("popping brace stack", s.braceStack)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, tokenID := range tokenSlice {
|
for _, tokenID := range tokenSlice {
|
||||||
// transition to the next node
|
// transition to the next node
|
||||||
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("invalid token: %q", mappedString)
|
return fmt.Errorf("invalid token: %q", mappedString)
|
||||||
}
|
}
|
||||||
fmt.Println("transitioning to", nextNode)
|
fmt.Println("transitioning to", nextNodeState)
|
||||||
s.curNode = s.stateToNodeMap[nextNode]
|
|
||||||
|
// TODO: add a penalty for staying in the same state too long
|
||||||
|
if nextNodeState == s.curNode.State {
|
||||||
|
s.stateCounter++
|
||||||
|
} else {
|
||||||
|
s.stateCounter = 0
|
||||||
|
}
|
||||||
|
s.curNode = s.stateToNodeMap[nextNodeState]
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user