From e93db4d20edd7f483122b3fe4b12d3433759f5df Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Mon, 27 Jan 2025 16:33:55 -0800 Subject: [PATCH] WIP --- sample/fast_json.go | 49 +++++++++---- sample/pushdown_automata.go | 133 ++++++++++++++++++++++++++++++++++-- sample/pushdown_runner.go | 61 ++++++++++++++--- 3 files changed, 214 insertions(+), 29 deletions(-) diff --git a/sample/fast_json.go b/sample/fast_json.go index b5f9088cc..6104b8b3e 100644 --- a/sample/fast_json.go +++ b/sample/fast_json.go @@ -22,12 +22,18 @@ const ( StateInFloat StateInBool StateInNull - StateInArray StateInColon StateInComma StateInTab StateInSpace + StateInObjSpace + StateInList + StateInListComma + StateListEnd + StateInListEnd StateInNewline + StateInNumber + StateInNumberEnd StateInStringEnd StateInObjectKeyEnd StateTerminate @@ -42,42 +48,54 @@ func (s JSONState) String() string { return "StateInObject" case StateInObjectKey: return "StateInObjectKey" - case StateInString: - return "StateInString" case StateNewline: return "StateNewline" case StateTab: return "StateTab" case StateSpace: return "StateSpace" + case StateInString: + return "StateInString" case StateInInt: return "StateInInt" case StateInFloat: return "StateInFloat" - case StateInColon: - return "StateInColon" case StateInBool: return "StateInBool" case StateInNull: return "StateInNull" - case StateInArray: - return "StateInArray" - case StateInObjectEnd: - return "StateInObjectEnd" + case StateInColon: + return "StateInColon" case StateInComma: return "StateInComma" case StateInTab: return "StateInTab" - case StateInObjectKeyEnd: - return "StateInObjectKeyEnd" - case StateInNewline: - return "StateInNewline" case StateInSpace: return "StateInSpace" - case StateTerminate: - return "StateTerminate" + case StateInObjSpace: + return "StateInObjSpace" + case StateInList: + return "StateInList" + case StateInListComma: + return "StateInListComma" + case StateListEnd: + return "StateListEnd" + case StateInListEnd: + return "StateInListEnd" + case StateInNewline: + return "StateInNewline" + case StateInNumber: + return "StateInNumber" + case StateInNumberEnd: + return "StateInNumberEnd" case StateInStringEnd: return "StateInStringEnd" + case StateInObjectKeyEnd: + return "StateInObjectKeyEnd" + case StateTerminate: + return "StateTerminate" + case StateInObjectEnd: + return "StateInObjectEnd" default: return fmt.Sprintf("Unknown state: %d", s) } @@ -264,6 +282,7 @@ func getValidStates(node *Node) []int32 { func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) { // fmt.Printf("Masking logits with valid states: %v\n", validStates) + // todo: this can prob be more efficient for i := range logits { isValid := false for _, token := range validStates { diff --git a/sample/pushdown_automata.go b/sample/pushdown_automata.go index 111a91037..d58f23cc4 100644 --- a/sample/pushdown_automata.go +++ b/sample/pushdown_automata.go @@ -8,6 +8,15 @@ import ( var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','} +var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'} +var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'} + +var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'} + +var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'} + +var validNullRunes = []rune{'n', 'u', 'l', 'l'} + type PDANode struct { State JSONState TransitionEdges map[rune]*PDANode @@ -52,6 +61,9 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err spaceNode := NewPDANode(StateInSpace) stateToNodeMap[StateInSpace] = spaceNode + spaceObjNode := NewPDANode(StateInObjSpace) + stateToNodeMap[StateInObjSpace] = spaceObjNode + tabNode := NewPDANode(StateInTab) stateToNodeMap[StateInTab] = tabNode @@ -61,7 +73,31 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err stringEndNode := NewPDANode(StateInStringEnd) stateToNodeMap[StateInStringEnd] = stringEndNode - // terminateNode := NewNode(StateTerminate) + listNode := NewPDANode(StateInList) + stateToNodeMap[StateInList] = listNode + + listCommaNode := NewPDANode(StateInListComma) + stateToNodeMap[StateInListComma] = listCommaNode + + listEndNode := NewPDANode(StateListEnd) + stateToNodeMap[StateListEnd] = listEndNode + + numberNode := NewPDANode(StateInNumber) + stateToNodeMap[StateInNumber] = numberNode + + boolNode := NewPDANode(StateInBool) + stateToNodeMap[StateInBool] = boolNode + + nullNode := NewPDANode(StateInNull) + stateToNodeMap[StateInNull] = nullNode + + // Defined with structured outputs only + intNode := NewPDANode(StateInInt) + stateToNodeMap[StateInInt] = intNode + + // TODO: + // consider adding a node to just point to values, could be good to compute that + // mask rather than many different nodes // Connect nodes // TODO: if all are single tokens then this can just be connected instead of defining the token @@ -69,34 +105,119 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err objNode.TransitionEdges['"'] = objKeyNode objNode.TransitionEdges['\n'] = newlineNode + // objNode.TransitionEdges['\t'] = tabNode newlineNode.TransitionEdges['"'] = objKeyNode newlineNode.TransitionEdges['\t'] = tabNode tabNode.TransitionEdges['"'] = objKeyNode - - spaceNode.TransitionEdges['"'] = stringNode + // tabNode.TransitionEdges['\t'] = tabNode objKeyNode.TransitionEdges[rune(-1)] = objKeyNode objKeyNode.TransitionEdges['"'] = objKeyEndNode - objKeyNode.TransitionEdges[' '] = spaceNode - // objKeyNode.TransitionEdges['\t'] = tabNode objKeyEndNode.TransitionEdges[':'] = colonNode + objEndNode.TransitionEdges[' '] = spaceNode - colonNode.TransitionEdges['"'] = stringNode + // where values should be + // this could be combined but the probs might change, we're alr doing a skip ahead colonNode.TransitionEdges[' '] = spaceNode + // Leads to a value + spaceNode.TransitionEdges['"'] = stringNode + spaceNode.TransitionEdges['['] = listNode + spaceNode.TransitionEdges['{'] = objNode + + for _, r := range validNumberRunes { + spaceNode.TransitionEdges[r] = numberNode + } + for _, r := range validBoolRunes { + spaceNode.TransitionEdges[r] = boolNode + } + + for _, r := range validNullRunes { + spaceNode.TransitionEdges[r] = nullNode + } + + // Values + // string node stringNode.TransitionEdges[rune(-1)] = stringNode stringNode.TransitionEdges['"'] = stringEndNode stringEndNode.TransitionEdges[','] = commaNode stringEndNode.TransitionEdges['}'] = objEndNode + stringEndNode.TransitionEdges[']'] = listEndNode + + // TODO: add counters for allowable number of decimals, e, E, etc + // number node + for _, r := range validNumberRunes { + numberNode.TransitionEdges[r] = numberNode + } + numberNode.TransitionEdges[','] = commaNode + numberNode.TransitionEdges['}'] = objEndNode + numberNode.TransitionEdges[']'] = listEndNode + + for _, r := range validBoolRunes { + boolNode.TransitionEdges[r] = boolNode + } + + // list node + listNode.TransitionEdges[','] = commaNode + listNode.TransitionEdges['"'] = stringNode + // squash states to a value + for _, r := range validNumberRunes { + listNode.TransitionEdges[r] = numberNode + } + for _, r := range validBoolRunes { + listNode.TransitionEdges[r] = boolNode + } + for _, r := range validNullRunes { + listNode.TransitionEdges[r] = nullNode + } + + // null node + for _, r := range validNullRunes { + nullNode.TransitionEdges[r] = nullNode + } + nullNode.TransitionEdges[','] = commaNode + nullNode.TransitionEdges['}'] = objEndNode + nullNode.TransitionEdges[']'] = listEndNode + + // list comma + // should point to values + listCommaNode.TransitionEdges['"'] = stringNode + listCommaNode.TransitionEdges[' '] = listCommaNode + listCommaNode.TransitionEdges['{'] = objNode + listCommaNode.TransitionEdges['\n'] = newlineNode + + for _, r := range validNumberRunes { + listCommaNode.TransitionEdges[r] = numberNode + } + for _, r := range validBoolRunes { + listCommaNode.TransitionEdges[r] = boolNode + } + for _, r := range validNullRunes { + listCommaNode.TransitionEdges[r] = nullNode + } + + // bool node + for _, r := range validBoolRunes { + boolNode.TransitionEdges[r] = boolNode + } + boolNode.TransitionEdges['}'] = objEndNode + boolNode.TransitionEdges[']'] = listEndNode + boolNode.TransitionEdges[','] = commaNode + + listEndNode.TransitionEdges['}'] = objEndNode + listEndNode.TransitionEdges[','] = commaNode commaNode.TransitionEdges['{'] = objNode commaNode.TransitionEdges['\n'] = newlineNode commaNode.TransitionEdges['\t'] = tabNode commaNode.TransitionEdges['"'] = objKeyNode + commaNode.TransitionEdges[' '] = spaceObjNode + + spaceObjNode.TransitionEdges['"'] = objKeyNode return startNode, stateToNodeMap, nil } diff --git a/sample/pushdown_runner.go b/sample/pushdown_runner.go index 4d27f93dc..a97b5f29a 100644 --- a/sample/pushdown_runner.go +++ b/sample/pushdown_runner.go @@ -3,6 +3,8 @@ package sample import ( "fmt" "math" + "runtime" + "time" "github.com/ollama/ollama/model" ) @@ -13,9 +15,17 @@ type PushdownSampler struct { proc model.TextProcessor stateToNodeMap map[JSONState]*PDANode braceStack []rune + stateCounter uint32 } func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { + start := time.Now() + + var m runtime.MemStats + runtime.ReadMemStats(&m) + before := m.Alloc + fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024)) + startNode, stateToNodeMap, err := BuildGraph(proc) if err != nil { panic(err) @@ -24,6 +34,11 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { if err != nil { panic(err) } + runtime.ReadMemStats(&m) + after := m.Alloc + fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024)) + fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) + fmt.Printf("Graph build time = %v\n", time.Since(start)) // for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode { // token, err := proc.Decode([]int32{int32(id)}) // if err != nil { @@ -37,6 +52,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { proc: proc, stateToNodeMap: stateToNodeMap, braceStack: []rune{}, + stateCounter: 0, } } @@ -69,7 +85,19 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { } } return logits, nil - // return logits, nil + + case StateInComma: + peek := s.braceStack[len(s.braceStack)-1] + if peek == rune('[') { + s.curNode = s.stateToNodeMap[StateInListComma] + fmt.Println("switching to list comma", s.curNode.State) + } + logits, err := s.maskLogits(logits, s.curNode) + if err != nil { + return nil, err + } + return logits, nil + case StateTerminate: for i := range logits { if s.proc.Is(uint32(i), model.SpecialEOS) { @@ -80,9 +108,6 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { } return logits, nil - // case StateInStringEnd: - - // return logits, nil default: fmt.Println("masking logits current state", s.curNode.State) logits, err := s.maskLogits(logits, s.curNode) @@ -96,7 +121,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { fmt.Println("update state", s.curNode.State) - // TODO: need to handle end states and entering object case + // TODO: need to handle end states and entering object case, and list case if s.curNode.State == StateInObjectEnd { fmt.Println("in object end") if len(s.braceStack) > 0 { @@ -111,25 +136,45 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { if err != nil { return err } + // TODO: should force closing for all braces for _, r := range mappedString { if r == rune('{') { s.braceStack = append(s.braceStack, r) } + if r == rune('[') { + s.braceStack = append(s.braceStack, r) + } if r == rune('}') { if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') { return fmt.Errorf("unmatched closing brace") } s.braceStack = s.braceStack[:len(s.braceStack)-1] + fmt.Println("popping brace stack", s.braceStack) + } + + if r == rune(']') { + if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('[') { + return fmt.Errorf("unmatched closing brace") + } + s.braceStack = s.braceStack[:len(s.braceStack)-1] + fmt.Println("popping brace stack", s.braceStack) } } for _, tokenID := range tokenSlice { // transition to the next node - nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID] + nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID] if !ok { return fmt.Errorf("invalid token: %q", mappedString) } - fmt.Println("transitioning to", nextNode) - s.curNode = s.stateToNodeMap[nextNode] + fmt.Println("transitioning to", nextNodeState) + + // TODO: add a penalty for staying in the same state too long + if nextNodeState == s.curNode.State { + s.stateCounter++ + } else { + s.stateCounter = 0 + } + s.curNode = s.stateToNodeMap[nextNodeState] } return nil }