diff --git a/sample/constrained.go b/sample/constrained.go new file mode 100644 index 000000000..e148ec3f9 --- /dev/null +++ b/sample/constrained.go @@ -0,0 +1,49 @@ +package sample + +import ( + "github.com/ollama/ollama/model" +) + +type ConstrainedSampler struct { + schema *Schema + propIdx int + propToNodeMap map[string]*PDA + pdaSampler *PushdownSampler + decodedToks []string +} + +func NewConstrainedSampler(proc model.TextProcessor, schema *Schema) (*ConstrainedSampler, error) { + pdaSampler, err := NewPushdownSampler(proc) + if err != nil { + return nil, err + } + + // if schema == nil { + return &ConstrainedSampler{ + schema: nil, + propIdx: -1, + propToNodeMap: nil, + pdaSampler: pdaSampler, + }, nil + +} + +func (s *ConstrainedSampler) Apply(logits []float64) ([]float64, error) { + if s.schema == nil { + return s.pdaSampler.Apply(logits) + } + + return nil, nil +} + +func (s *ConstrainedSampler) UpdateState(tokenSlice []int32) error { + if err := s.pdaSampler.UpdateState(tokenSlice); err != nil { + return err + } + + if s.schema == nil { + return nil + } + + return nil +} diff --git a/sample/feedback.txt b/sample/feedback.txt new file mode 100644 index 000000000..42c168138 --- /dev/null +++ b/sample/feedback.txt @@ -0,0 +1,32 @@ +// Feedback from code review: + +// pushdown_automata.go: +// 1. The BuildGraph function is quite long and could be split into smaller, more focused functions +// 2. Consider using constants instead of magic runes like rune(-1) for sentinel values +// 3. The state machine transitions could be defined more declaratively, perhaps in a config +// 4. The stringInvalidRunes list needs to handle escape sequences properly +// 5. The graph building could be optimized to avoid duplicate nodes/transitions +// 6. Consider adding validation for max nesting depth of braces/brackets +// 7. The CreateMask function is doing a lot - could be split into smaller pieces +// 8. isRuneValid has a "garbage interface" per TODO - needs cleaner design + +// pushdown_runner.go: +// 1. The Apply method has a lot of duplicated logic around EOS handling +// 2. The UpdateState method could use more granular error messages +// 3. The braceStack validation could be moved to a separate validator +// 4. Consider adding max length limits for strings/numbers +// 5. The stateCounter isn't being used effectively yet +// 6. Need to add penalties for staying in same state too long +// 7. The maskLogits function could be optimized to avoid allocations +// 8. Missing proper cleanup/reset functionality +// 9. Error handling could be more consistent throughout +// 10. Consider adding debug logging levels instead of raw fmt.Println + +// General improvements needed: +// - More comprehensive testing, especially edge cases +// - Better documentation of state machine transitions +// - Performance optimization for large inputs +// - Memory usage optimization for the graph structure +// - Cleaner interfaces between components +// - More robust error handling and recovery + diff --git a/sample/fused_mask_sample.go b/sample/fused_mask_sample.go new file mode 100644 index 000000000..d01feac67 --- /dev/null +++ b/sample/fused_mask_sample.go @@ -0,0 +1,11 @@ +package sample + +// type fusedMaskSampler struct{} + +// func FusedMaskSampler() Sampler { +// return fusedMaskSampler{} +// } + +// func (f fusedMaskSampler) Sample(logits []float64) (int, error) { +// return int(logits[0]), nil +// } diff --git a/sample/greedy.go b/sample/greedy.go index 4d110f021..5019fffd6 100644 --- a/sample/greedy.go +++ b/sample/greedy.go @@ -8,6 +8,19 @@ func Greedy() Sampler { return greedy{} } -func (s greedy) Sample(t []float64) (int, error) { - return floats.MaxIdx(t), nil +func (s greedy) Sample(logits []float32, transforms ...Transform) (int, error) { + logits64 := make([]float64, len(logits)) + for i, v := range logits { + logits64[i] = float64(v) + } + + var err error + for _, t := range transforms { + logits64, err = t.Apply(logits64) + if err != nil { + return -1, err + } + } + + return floats.MaxIdx(logits64), nil } diff --git a/sample/fast_json.go b/sample/json_types.go similarity index 89% rename from sample/fast_json.go rename to sample/json_types.go index 925de06ae..bf52b8c9d 100644 --- a/sample/fast_json.go +++ b/sample/json_types.go @@ -23,7 +23,9 @@ const ( StateInColon StateInComma StateInTab - StateInSpace + StateInSpaceToValue + StateInSpaceEndValue + StateInNewlineEndValue StateInObjSpace StateInList StateInListComma @@ -57,7 +59,9 @@ var JSONStates = []JSONState{ StateInColon, StateInComma, StateInTab, - StateInSpace, + StateInSpaceToValue, + StateInSpaceEndValue, + StateInNewlineEndValue, StateInObjSpace, StateInList, StateInListComma, @@ -107,7 +111,7 @@ func (s JSONState) String() string { return "StateInComma" case StateInTab: return "StateInTab" - case StateInSpace: + case StateInSpaceToValue: return "StateInSpace" case StateInObjSpace: return "StateInObjSpace" @@ -121,6 +125,8 @@ func (s JSONState) String() string { return "StateInListEnd" case StateInNewline: return "StateInNewline" + case StateInNewlineEndValue: + return "StateInNewlineEndValue" case StateInNumber: return "StateInNumber" case StateInNumberEnd: @@ -129,6 +135,8 @@ func (s JSONState) String() string { return "StateInStringEnd" case StateInObjectKeyEnd: return "StateInObjectKeyEnd" + case StateInSpaceEndValue: + return "StateInSpaceEndValue" case StateTerminate: return "StateTerminate" case StateInObjectEnd: diff --git a/sample/pushdown_automata.go b/sample/pushdown_automata.go index 19b237526..d9ccc151e 100644 --- a/sample/pushdown_automata.go +++ b/sample/pushdown_automata.go @@ -6,8 +6,35 @@ import ( "github.com/ollama/ollama/model" ) +/* +Key JSON rules to consider: + +1. Whitespace handling: + - Need to handle all valid JSON whitespace characters (\r, spaces between tokens) + - Current code only handles some whitespace cases + +2. Number validation: + - Need proper validation for special number cases like -0 + - Should handle .5 style decimals + - Need limits on scientific notation (e, E) + +3. String escaping: + - Currently marks \ as invalid but should allow escaped sequences: + - \" + - \n + - \u1234 unicode escapes + +4. Empty object/array transitions: + - Direct {} and [] cases could be more explicit + - Need clear transitions for these edge cases + +5. Nested depth limits: + - No protection against excessive nesting + - Could cause stack overflow with deeply nested structures +*/ + // TODO: / should be valid but an escape character -var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'} +var stringInvalidRunes = []rune{'\n', '\t', '{', '}', ':', ',', '/'} var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'} var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'} @@ -18,31 +45,31 @@ var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'} var validNullRunes = []rune{'n', 'u', 'l', 'l'} -type PDANode struct { +type PDA struct { State JSONState - TransitionEdges map[rune]*PDANode - MaskTokenIDToNode map[int32]*PDANode + TransitionEdges map[rune]*PDA + MaskTokenIDToNode map[int32]*PDA } -func NewPDANode(state JSONState) *PDANode { - return &PDANode{ +func NewPDANode(state JSONState) *PDA { + return &PDA{ State: state, - TransitionEdges: make(map[rune]*PDANode), - MaskTokenIDToNode: make(map[int32]*PDANode), + TransitionEdges: make(map[rune]*PDA), + MaskTokenIDToNode: make(map[int32]*PDA), } } -func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) { - stateToNodeMap := make(map[JSONState]*PDANode) - - // TODO: make this a loop +type PDAGraphBuilder struct { + proc model.TextProcessor + decodedToks []string + stateToNodeMap map[JSONState]*PDA +} +func (b *PDAGraphBuilder) BuildGraph() error { + stateToNodeMap := make(map[JSONState]*PDA) for _, state := range JSONStates { stateToNodeMap[state] = NewPDANode(state) } - // TODO: - // consider adding a node to just point to values, could be good to compute that - // mask rather than many different nodes stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList] @@ -51,10 +78,21 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace] - //new line + // new line stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab] stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] + stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace] + + // new line end value + stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] + stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] + + stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] + stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] + // TODO: see if this is needed for formatting + stateToNodeMap[StateInObjSpace].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace] stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] @@ -68,16 +106,16 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err // where values should be // this could be combined but the probl might change, we're alr doing a skip ahead - stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpace] + stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue] stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList] stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject] - addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap) + b.addValueConnections(stateToNodeMap[StateInColon]) // Leads to a value - stateToNodeMap[StateInSpace].TransitionEdges['['] = stateToNodeMap[StateInList] - stateToNodeMap[StateInSpace].TransitionEdges['{'] = stateToNodeMap[StateInObject] - addValueConnections(stateToNodeMap[StateInSpace], stateToNodeMap) - stateToNodeMap[StateInSpace].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] + stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList] + stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject] + b.addValueConnections(stateToNodeMap[StateInSpaceToValue]) + stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] // Values // string node @@ -85,149 +123,142 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd] // String end node - addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap) + b.addEnds(stateToNodeMap[StateInStringEnd]) + stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] // TODO: add counters for allowable number of decimals, e, E, etc // number node for _, r := range validNumberRunes { stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber] } - addEnds(stateToNodeMap[StateInNumber], stateToNodeMap) - - // bool node - for _, r := range validBoolRunes { - stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool] - } - addEnds(stateToNodeMap[StateInBool], stateToNodeMap) - stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpace] + b.addEnds(stateToNodeMap[StateInNumber]) + stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] // list node stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma] stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList] stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList] + + // list end node + stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] + stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma] + stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] + // empty list stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] - addValueConnections(stateToNodeMap[StateInList], stateToNodeMap) + b.addValueConnections(stateToNodeMap[StateInList]) // null node for _, r := range validNullRunes { stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull] } - addEnds(stateToNodeMap[StateInNull], stateToNodeMap) + b.addEnds(stateToNodeMap[StateInNull]) + stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue] + stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] // list comma // should point to values stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma] stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList] - addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap) + b.addValueConnections(stateToNodeMap[StateInListComma]) // list object end stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma] stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] + // TODO: not sure if this is needed + stateToNodeMap[StateInListObjectEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] // bool node for _, r := range validBoolRunes { stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool] } stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] - addEnds(stateToNodeMap[StateInBool], stateToNodeMap) - - stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] - stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma] + b.addEnds(stateToNodeMap[StateInBool]) + stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] + // comma node stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject] - stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInList] - stateToNodeMap[StateInComma].TransitionEdges['\t'] = stateToNodeMap[StateInTab] + stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace] - stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] - stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] + // space end value + stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] + stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] + stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] - return stateToNodeMap[StateStart], stateToNodeMap, nil + b.stateToNodeMap = stateToNodeMap + if err := b.preComputeValidStates(); err != nil { + return err + } + return nil } -func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) { - node.TransitionEdges[','] = stateToNodeMap[StateInComma] - node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] - node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd] +func (b *PDAGraphBuilder) addEnds(node *PDA) { + node.TransitionEdges[','] = b.stateToNodeMap[StateInComma] + node.TransitionEdges['}'] = b.stateToNodeMap[StateInObjectEnd] + node.TransitionEdges[']'] = b.stateToNodeMap[StateInListEnd] } -func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) { - node.TransitionEdges['"'] = stateToNodeMap[StateInString] +func (b *PDAGraphBuilder) addValueConnections(node *PDA) { + node.TransitionEdges['"'] = b.stateToNodeMap[StateInString] for _, r := range validNumberRunes { - node.TransitionEdges[r] = stateToNodeMap[StateInNumber] + node.TransitionEdges[r] = b.stateToNodeMap[StateInNumber] } - node.TransitionEdges['t'] = stateToNodeMap[StateInBool] - node.TransitionEdges['f'] = stateToNodeMap[StateInBool] - node.TransitionEdges['n'] = stateToNodeMap[StateInNull] + // TODO(parthsareen): force the output and shift similar to structured outputs + node.TransitionEdges['t'] = b.stateToNodeMap[StateInBool] + node.TransitionEdges['f'] = b.stateToNodeMap[StateInBool] + node.TransitionEdges['n'] = b.stateToNodeMap[StateInNull] } -// TODO: tough life fr. plz fix. -func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error { - - // TODO; should come from top level - vocab := proc.GetVocabulary() - - decodedToks := make([]string, len(vocab.Values)) - for i := range vocab.Values { - token, err := proc.Decode([]int32{int32(i)}) - if err != nil { - return err - } - decodedToks[i] = token - } - - var err error - for _, node := range stateToNodeMap { - err = CreateMask(node, proc, decodedToks) - if err != nil { +func (b *PDAGraphBuilder) preComputeValidStates() error { + for _, node := range b.stateToNodeMap { + if err := b.CreateMask(node); err != nil { return err } } return nil } -func CreateMask(node *PDANode, proc model.TextProcessor, decodedToks []string) error { - for i := range decodedToks { - token := decodedToks[i] +func (b *PDAGraphBuilder) CreateMask(node *PDA) error { + for i := range b.decodedToks { + token := b.decodedToks[i] // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON - if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" { + if b.proc.Is(uint32(i), model.SpecialEOS) || b.proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" { continue } - valid := true curNode := node + valid := true consumedSpecialRunes := make(map[rune]bool) - var err error for _, r := range token { - valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes) - if err != nil { - return err - } - if !valid { + curNode, valid = isRuneValid(r, curNode, consumedSpecialRunes) + if curNode == nil || !valid { break } } if valid { - // cur node allows skipping node.MaskTokenIDToNode[int32(i)] = curNode } } return nil } -// TODO: garbage interface plz fix -func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) { +func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA, bool) { if consumedSpecialRunes[r] { - return false, nil, nil + return nil, false } specialRune := slices.Contains(stringInvalidRunes, r) if specialRune { if curNode.State == StateInString || curNode.State == StateInObjectKey { - return false, nil, nil + return nil, false } } @@ -235,17 +266,17 @@ func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) ( if nextNode, ok := curNode.TransitionEdges[r]; ok { if specialRune { if curNode.State == nextNode.State { - return false, nil, nil + return nil, false } consumedSpecialRunes[r] = true } - return true, nextNode, nil + return nextNode, true } // Check for sentinel value - if present, any rune is valid if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok { - return true, nextNode, nil + return nextNode, true } - return false, nil, nil + return nil, false } diff --git a/sample/pushdown_runner.go b/sample/pushdown_runner.go index 97c58514b..592582375 100644 --- a/sample/pushdown_runner.go +++ b/sample/pushdown_runner.go @@ -11,17 +11,17 @@ import ( // TODO: safety in case of invalid json // TODO: interfaces to cleanup with return values +// TODO this interface shouldn't be the sampler - should just use Sampler +// TODO: add penalties for string \n stuff type PushdownSampler struct { - // stateful - curNode *PDANode - proc model.TextProcessor - stateToNodeMap map[JSONState]*PDANode - braceStack []rune - stateCounter uint32 + PDAGraphBuilder + curNode *PDA + braceStack []rune + stateCounter uint32 } // graph should be built once and reused per tokenizer -func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { +func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) { start := time.Now() fmt.Println("--------------------------------") @@ -32,27 +32,38 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { before := m.Alloc fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024)) - startNode, stateToNodeMap, err := BuildGraph(proc) - if err != nil { - panic(err) + vocab := proc.GetVocabulary() + decodedToks := make([]string, len(vocab.Values)) + for i := range vocab.Values { + token, err := proc.Decode([]int32{int32(i)}) + if err != nil { + return nil, err + } + decodedToks[i] = token } - err = PreComputeValidStates(stateToNodeMap, proc) - if err != nil { - panic(err) + + gb := &PDAGraphBuilder{ + proc: proc, + decodedToks: decodedToks, } + + if err := gb.BuildGraph(); err != nil { + return nil, err + } + runtime.ReadMemStats(&m) after := m.Alloc fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024)) fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) fmt.Printf("Graph build time = %v\n", time.Since(start)) + // TODO: this can be simplified return &PushdownSampler{ - curNode: startNode, - proc: proc, - stateToNodeMap: stateToNodeMap, - braceStack: []rune{}, - stateCounter: 0, - } + curNode: gb.stateToNodeMap[StateStart], + PDAGraphBuilder: *gb, + braceStack: []rune{}, + stateCounter: 0, + }, nil } // TODO: need to add resampling logic if the first sample was not good @@ -66,14 +77,7 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) { // force finish if no braces left if len(s.braceStack) == 0 { s.curNode = NewPDANode(StateTerminate) - for i := range logits { - if s.proc.Is(uint32(i), model.SpecialEOS) { - logits[i] = 1.0 - } else { - logits[i] = math.Inf(-1) - } - } - return logits, nil + return forceFinish(s, logits) } logits, err := s.maskLogits(logits, s.curNode) @@ -82,18 +86,14 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) { } return logits, nil + case StateTerminate: + return forceFinish(s, logits) + case StateInObjectEnd: // force finish if no braces left if len(s.braceStack) == 0 { s.curNode = NewPDANode(StateTerminate) - for i := range logits { - if s.proc.Is(uint32(i), model.SpecialEOS) { - logits[i] = 1.0 - } else { - logits[i] = math.Inf(-1) - } - } - return logits, nil + return forceFinish(s, logits) } peek := s.braceStack[len(s.braceStack)-1] @@ -112,22 +112,13 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) { if peek == rune('[') { s.curNode = s.stateToNodeMap[StateInListComma] } + logits, err := s.maskLogits(logits, s.curNode) if err != nil { return nil, err } return logits, nil - case StateTerminate: - for i := range logits { - if s.proc.Is(uint32(i), model.SpecialEOS) { - logits[i] = 1.0 - } else { - logits[i] = math.Inf(-1) - } - } - return logits, nil - default: fmt.Println("masking logits current state", s.curNode.State) logits, err := s.maskLogits(logits, s.curNode) @@ -138,13 +129,24 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) { } } +func forceFinish(s *PushdownSampler, logits []float64) ([]float64, error) { + for i := range logits { + if s.proc.Is(uint32(i), model.SpecialEOS) { + logits[i] = 1.0 + } else { + logits[i] = math.Inf(-1) + } + } + return logits, nil +} + func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { fmt.Println("current state - updating", s.curNode.State) mappedString, err := s.proc.Decode(tokenSlice) if err != nil { return err } - fmt.Println(">>> mappedString", mappedString) + fmt.Printf(">>> mappedString: %q\n", mappedString) // TODO: should force closing for all braces - not doing square yet for _, r := range mappedString { @@ -198,7 +200,8 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { } // greedy sample + backtrack? -func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) { +func (s *PushdownSampler) maskLogits(logits []float64, node *PDA) ([]float64, error) { + // Create a new slice with same length as logits, initialized to -Inf maskedLogits := make([]float64, len(logits)) for i := range maskedLogits { @@ -215,4 +218,23 @@ func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64 return maskedLogits, nil } -// TODO: add penalties for string \n stuff +func (s *PushdownSampler) fastMaskLogits(logits []float64, node *PDA) ([]float64, error) { + maxLogit := math.Inf(-1) + maxIndex := -1 + + // Find the maximum logit value among valid tokens + for tokenID := range node.MaskTokenIDToNode { + if int(tokenID) < len(logits) && logits[tokenID] > maxLogit { + maxLogit = logits[tokenID] + maxIndex = int(tokenID) + } + } + + if maxIndex == -1 { + return nil, fmt.Errorf("no valid tokens found in mask") + } + + logits[0] = float64(maxIndex) + return logits, nil + // return maxIndex, nil +} diff --git a/sample/sample.go b/sample/sample.go index d873e7cee..98766b527 100644 --- a/sample/sample.go +++ b/sample/sample.go @@ -6,6 +6,8 @@ import ( "math" "slices" + pq "github.com/emirpasic/gods/v2/queues/priorityqueue" + "golang.org/x/exp/rand" "gonum.org/v1/gonum/floats" "gonum.org/v1/gonum/stat/sampleuv" ) @@ -15,33 +17,34 @@ type Transform interface { } type Sampler interface { - Sample([]float64) (int, error) + Sample([]float32, ...Transform) (int, error) } -type SamplerConfig struct { - transforms []Transform - sampler Sampler -} - -// NewSampler creates a sampler with the given transforms and sampling method -func NewSampler(transforms []Transform, sampler Sampler) *SamplerConfig { - return &SamplerConfig{ - transforms: transforms, - sampler: sampler, +// TODO(parthsareen): potentially cache softmax values +func softmax(logits []float64) []float64 { + var sum float64 + tt := make([]float64, len(logits)) + for i, v := range logits { + tt[i] = math.Exp(v) + sum += tt[i] } + floats.Scale(1/sum, tt) + return tt } type Temperature float64 func (t Temperature) Apply(logits []float64) ([]float64, error) { + if t == 0 { + return nil, errors.New("use Greedy sampler instead of Temperature(0)") + } if t < 0 || t > 2 { return nil, errors.New("temperature must be between 0 and 2") } + temp := math.Max(float64(t), 1e-7) // subtracting max logit to avoid under/overflow - maxLogit := floats.Max(logits) - - temp := math.Max(float64(t), 1e-7) + maxLogit := slices.Max(logits) for i := range logits { logits[i] = (logits[i] - maxLogit) / temp } @@ -49,52 +52,41 @@ func (t Temperature) Apply(logits []float64) ([]float64, error) { return logits, nil } -type softmax struct{} - -func Softmax() Transform { - return softmax{} +type logitMap struct { + index int + logit float64 } -func (softmax) Apply(logits []float64) ([]float64, error) { - return computeSoftmax(logits), nil -} - -// TODO: cache softmax values -func computeSoftmax(logits []float64) []float64 { - copiedLogits := make([]float64, len(logits)) - copy(copiedLogits, logits) - for i := range copiedLogits { - copiedLogits[i] = math.Exp(copiedLogits[i]) - } - - floatSum := floats.Sum(copiedLogits) - floats.Scale(1.0/floatSum, copiedLogits) - - return copiedLogits +func logitMapComparator(a, b logitMap) int { + return -cmp.Compare(a.logit, b.logit) } type TopK int +// TODO(parthsareen): avoid having to check all logits after this transform func (k TopK) Apply(logits []float64) ([]float64, error) { if k <= 0 { - return nil, errors.New("k must be positive") + return nil, errors.New("k must be greater than 0") } if int(k) >= len(logits) { return logits, nil } - indices := make([]int, len(logits)) - for i := range indices { - indices[i] = i + q := pq.NewWith(logitMapComparator) + for i, logit := range logits { + q.Enqueue(logitMap{index: i, logit: logit}) } - // sort in descending order - slices.SortFunc(indices, func(i, j int) int { - return cmp.Compare(logits[j], logits[i]) - }) + validLogits := make(map[int]float64) + for range k { + logitMap, _ := q.Dequeue() + validLogits[logitMap.index] = logitMap.logit + } - for _, idx := range indices[k:] { - logits[idx] = math.Inf(-1) + for i := range logits { + if _, ok := validLogits[i]; !ok { + logits[i] = math.Inf(-1) + } } return logits, nil @@ -107,8 +99,7 @@ func (p TopP) Apply(logits []float64) ([]float64, error) { return nil, errors.New("p must be between 0 and 1") } - probs := computeSoftmax(logits) - + probs := softmax(logits) indices := make([]int, len(probs)) for i := range indices { indices[i] = i @@ -139,17 +130,11 @@ func (p MinP) Apply(logits []float64) ([]float64, error) { return nil, errors.New("p must be between 0 and 1") } - probs := computeSoftmax(logits) - copiedProbs := make([]float64, len(probs)) - copy(copiedProbs, probs) + probs := softmax(logits) + threshold := slices.Max(probs) * float64(p) - slices.Sort(copiedProbs) - - maxProb := copiedProbs[len(copiedProbs)-1] - probThreshold := float64(p) * maxProb - - for i := range probs { - if probs[i] < probThreshold { + for i, prob := range probs { + if prob < threshold { logits[i] = math.Inf(-1) } } @@ -157,18 +142,35 @@ func (p MinP) Apply(logits []float64) ([]float64, error) { return logits, nil } -type weighed struct{} - -func Weighed() Sampler { - return weighed{} +type weighted struct { + src rand.Source } -// should return single value -func (s weighed) Sample(logits []float64) (int, error) { +func Weighted(seed *int64) Sampler { + var src rand.Source + if seed != nil { + src = rand.NewSource(uint64(*seed)) + } + return weighted{src: src} +} + +func (s weighted) Sample(logits []float32, transforms ...Transform) (int, error) { + logits64 := make([]float64, len(logits)) + for i, v := range logits { + logits64[i] = float64(v) + } + + var err error + for _, t := range transforms { + logits64, err = t.Apply(logits64) + if err != nil { + return -1, err + } + } + logitsCopy := make([]float64, 0, len(logits)) indices := make([]int, 0, len(logits)) - // the uv sampler does not support NaN values - for i, logit := range logits { + for i, logit := range logits64 { if !math.IsInf(logit, -1) { logitsCopy = append(logitsCopy, logit) indices = append(indices, i) @@ -176,38 +178,13 @@ func (s weighed) Sample(logits []float64) (int, error) { } if len(logitsCopy) == 0 { - return -1, errors.New("no valid tokens found") + return -1, errors.New("no valid logits found for weighed sampling") } - softmax := computeSoftmax(logitsCopy) - w := sampleuv.NewWeighted(softmax, nil) + probs := softmax(logitsCopy) + w := sampleuv.NewWeighted(probs, s.src) if idx, ok := w.Take(); ok { - // returns the token ID return indices[idx], nil } - return -1, errors.New("weighed sampler failed") -} - -// Sample applies transforms and samples a token ID -func (s *SamplerConfig) Sample(input []float32) (int, error) { - logits := make([]float64, len(input)) - for i, v := range input { - logits[i] = float64(v) - } - - var err error - for _, t := range s.transforms { - if t == Temperature(0) { - // early return with greedy if temperature is 0 - s.sampler = Greedy() - break - } - - logits, err = t.Apply(logits) - if err != nil { - return -1, err - } - } - - return s.sampler.Sample(logits) + return -1, errors.New("weighed sampler failed, no valid token found") } diff --git a/sample/sample_test.go b/sample/sample_test.go index 78c7209e7..635e2f878 100644 --- a/sample/sample_test.go +++ b/sample/sample_test.go @@ -3,116 +3,129 @@ package sample import ( "fmt" "math" - "slices" + "math/rand/v2" "testing" - "gonum.org/v1/gonum/floats" + "github.com/google/go-cmp/cmp" ) func TestTemperature(t *testing.T) { - logits, err := Temperature(0.5).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) + logits, err := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0}) if err != nil { - t.Fatal(err) + t.Error(err) + return } - want := []float64{-14, -12, -10, -8, -6, -4, 0} - if !floats.Equal(logits, want) { - t.Fatalf("got: %v, want: %v", logits, want) + want := []float64{-4, -10, 0, -14, -6, -12, -8} + if diff := cmp.Diff(want, logits); diff != "" { + t.Errorf("logits mismatch (-want +got):\n%s", diff) } - if _, err := Temperature(-1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil { - t.Fatalf("expected error for temperature=-1, got %v", logits) + logits, err = Temperature(-1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) + if err == nil { + t.Errorf("expected error for temperature=-1, got %v", logits) } - if _, err := Temperature(2.1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil { - t.Fatalf("expected error for temperature=2.1, got %v", logits) + logits, err = Temperature(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) + if err == nil { + t.Errorf("expected error for temperature=0, got %v", logits) + } + logits, err = Temperature(2.1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) + if err == nil { + t.Errorf("expected error for temperature=2.1, got %v", logits) } } func TestSoftmax(t *testing.T) { - probs, err := Softmax().Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) - if err != nil { - t.Fatal(err) - } + probs := softmax([]float64{-3, -2, -1, 0, 1, 2, 4}) expectedProbs := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085} - if !floats.Equal(probs, expectedProbs) { - t.Fatalf("logits: %v, expectedlogits: %v", probs, expectedProbs) + if diff := cmp.Diff(expectedProbs, probs); diff != "" { + t.Errorf("probs mismatch (-want +got):\n%s", diff) } } func TestTopK(t *testing.T) { logits, err := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) if err != nil { - t.Fatal(err) + t.Error(err) + return } expectedlogits := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4} - if !floats.Same(logits, expectedlogits) { - t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits) + if diff := cmp.Diff(expectedlogits, logits); diff != "" { + t.Errorf("logits mismatch (-want +got):\n%s", diff) } - logits, err = TopK(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) + + _, err = TopK(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) if err == nil { - t.Fatalf("expected error for k=0, got %v", logits) + t.Errorf("expected error for k=0, got %v", err) } logits, err = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) if err != nil { - t.Fatal(err) + t.Error(err) + return } expectedlogits = []float64{-3, -2, -1, 0, 1, 2, 4} - if !floats.Same(logits, expectedlogits) { - t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits) + if diff := cmp.Diff(expectedlogits, logits); diff != "" { + t.Errorf("logits mismatch (-want +got):\n%s", diff) } } func TestTopP(t *testing.T) { logits, err := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) if err != nil { - t.Fatal(err) + t.Error(err) + return } want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4} - if !floats.Same(logits, want) { - t.Fatalf("got: %v, want: %v", logits, want) + if diff := cmp.Diff(want, logits); diff != "" { + t.Errorf("logits mismatch (-want +got):\n%s", diff) } - logits, err = TopP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) + + _, err = TopP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) if err == nil { - t.Fatalf("expected error for p=1.0, got %v", logits) + t.Error("expected error for p=1.0") } - logits, err = TopP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) + _, err = TopP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) if err == nil { - t.Fatalf("expected error for p=0.0, got %v", logits) + t.Error("expected error for p=0.0") } } func TestMinP(t *testing.T) { - logits, err := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) + logits, err := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3}) if err != nil { - t.Fatal(err) + t.Error(err) + return } - want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 3, 4} - if !floats.Same(logits, want) { - t.Fatalf("got: %v, want: %v", logits, want) + want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3} + if diff := cmp.Diff(want, logits); diff != "" { + t.Errorf("logits mismatch (-want +got):\n%s", diff) } - logits, err = MinP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) + + _, err = MinP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) if err == nil { - t.Fatalf("expected error for p=1.0, got %v", logits) + t.Error("expected error for p=1.0") } - logits, err = MinP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) + _, err = MinP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) if err == nil { - t.Fatalf("expected error for p=0.0, got %v", logits) + t.Error("expected error for p=0.0") } } func TestWeighed(t *testing.T) { - idx, err := Weighed().Sample([]float64{math.Inf(-1), 2, math.Inf(-1), math.Inf(-1)}) + idx, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))}) if err != nil { - t.Fatal(err) + t.Error(err) + return } want := 1 - if idx != want { - t.Fatalf("got: %v, want: %v", idx, want) + if diff := cmp.Diff(want, idx); diff != "" { + t.Errorf("index mismatch (-want +got):\n%s", diff) } - idx, err = Weighed().Sample([]float64{math.Inf(-1), math.Inf(-1), math.Inf(-1)}) + + idx, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))}) if err == nil { - t.Fatalf("expected error for no valid tokens, got %v", idx) + t.Error("expected error for no valid tokens, got index", idx) } } @@ -132,27 +145,32 @@ func TestSample(t *testing.T) { id: 3, callOrder: &callOrder, } - sampler := NewSampler([]Transform{mock1, mock2, mock3}, Greedy()) - got, err := sampler.Sample(input) + got, err := Greedy().Sample(input, mock1, mock2, mock3) if err != nil { - t.Fatal(err) - } - - if !slices.Equal(callOrder, []int{1, 2, 3}) { - t.Errorf("got %v, want %v", callOrder, []int{1, 2, 3}) + t.Error(err) + return } want := 3 // Greedy sampler should pick highest logit - if got != want { - t.Errorf("got %v, want %v", got, want) + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("sampled index mismatch (-want +got):\n%s", diff) + } + + _, err = Weighted(nil).Sample(input, mock1, mock2, mock3) + if err != nil { + t.Error(err) + return + } + wantOrder := []int{1, 2, 3} + if diff := cmp.Diff(wantOrder, callOrder); diff != "" { + t.Errorf("call order mismatch (-want +got):\n%s", diff) } errMock := &testTransform{ returnErr: fmt.Errorf("mock error"), } - sampler = NewSampler([]Transform{mock1, errMock, mock2}, Greedy()) - _, err = sampler.Sample(input) + _, err = Weighted(nil).Sample(input, mock1, errMock, mock2) if err == nil { t.Error("Expected error from sampler") } @@ -174,14 +192,51 @@ func (ts *testTransform) Apply(logits []float64) ([]float64, error) { return logits, nil } -func TestSampleTemperatureZero(t *testing.T) { - sampler := NewSampler([]Transform{Temperature(0)}, Greedy()) - got, err := sampler.Sample([]float32{1, 2, 3, 4}) - if err != nil { - t.Fatal(err) +func BenchmarkTransform(b *testing.B) { + transforms := map[string]Transform{ + "Temperature": Temperature(0.5), + "TopK": TopK(10), + "TopP": TopP(0.9), + "MinP": MinP(0.2), } - want := 3 // Greedy sampler should pick highest logit index - if got != want { - t.Fatalf("got: %v, want: %v", got, want) + + logits := make([]float64, 1<<16) + for i := range logits { + logits[i] = rand.Float64() + } + + for name, transform := range transforms { + b.Run(name, func(b *testing.B) { + b.ResetTimer() + for range b.N { + _, err := transform.Apply(logits) + if err != nil { + b.Error(err) + } + } + }) + } +} + +func BenchmarkSample(b *testing.B) { + samplers := map[string]Sampler{ + "Greedy": Greedy(), + "Weighted": Weighted(nil), + } + + logits := make([]float32, 1<<16) + for i := range logits { + logits[i] = rand.Float32() + } + + for name, s := range samplers { + b.Run(name, func(b *testing.B) { + b.ResetTimer() + for range b.N { + if _, err := s.Sample(logits); err != nil { + b.Error(err) + } + } + }) } } diff --git a/sample/structured_outputs.go b/sample/structured_outputs.go index 91c1d82de..adf5ce996 100644 --- a/sample/structured_outputs.go +++ b/sample/structured_outputs.go @@ -8,27 +8,45 @@ import ( "github.com/ollama/ollama/model" ) -type SOSampler struct { +type JSONSampler struct { schema *Schema propIdx int - propToNodeMap map[string]*PDANode + propToNodeMap map[string]*PDA pdaSampler *PushdownSampler decodedToks []string } -func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) { - pdaSampler := NewPushdownSampler(proc) +func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, error) { + pdaSampler, err := NewPushdownSampler(proc) + if err != nil { + return nil, err + } - so := &SOSampler{ + if schema == nil { + return &JSONSampler{ + schema: nil, + propIdx: -1, + propToNodeMap: nil, + pdaSampler: pdaSampler, + }, nil + } + + fmt.Println("schema not nil") + so := &JSONSampler{ schema: schema, propIdx: -1, - propToNodeMap: make(map[string]*PDANode), + propToNodeMap: make(map[string]*PDA), pdaSampler: pdaSampler, } so.schemaToGraph() - // This is prob slow + // Benchmark token decoding + start := time.Now() + var m runtime.MemStats + runtime.ReadMemStats(&m) + before := m.Alloc + vocab := proc.GetVocabulary() decodedToks := make([]string, len(vocab.Values)) for i := range vocab.Values { @@ -40,14 +58,18 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) } so.decodedToks = decodedToks + runtime.ReadMemStats(&m) + after := m.Alloc + fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) + fmt.Printf("Token decode time = %v\n", time.Since(start)) + fmt.Println("--------------------------------") fmt.Println("SOSampler") fmt.Println("--------------------------------") // Benchmark this section - start := time.Now() - var m runtime.MemStats + start = time.Now() runtime.ReadMemStats(&m) - before := m.Alloc + before = m.Alloc // TODO: still messed up // TODO: recursion use case @@ -57,12 +79,12 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) // propName -> node curState := node.State fromNode := node - CreateMask(fromNode, proc, decodedToks) + so.pdaSampler.CreateMask(fromNode) for curState == StateInStructuredKey { // there is only one edge for r, toNode := range fromNode.TransitionEdges { // fmt.Println("rune", r, "edge", toNode.State) - CreateMask(toNode, proc, decodedToks) + so.pdaSampler.CreateMask(toNode) fmt.Printf("created mask for %c\n", r) curState = toNode.State fmt.Println("next state", curState) @@ -73,7 +95,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) } runtime.ReadMemStats(&m) - after := m.Alloc + after = m.Alloc fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) fmt.Printf("Mask creation time = %v\n", time.Since(start)) fmt.Println("--------------------------------") @@ -81,7 +103,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) return so, nil } -func (s *SOSampler) schemaToGraph() { +func (s *JSONSampler) schemaToGraph() { schemaType := s.schema.EffectiveType() switch schemaType { case "object": @@ -91,18 +113,18 @@ func (s *SOSampler) schemaToGraph() { for _, prop := range s.schema.Properties { // name of key name := prop.Name - keyNode := &PDANode{ + keyNode := &PDA{ State: StateInStructuredKey, // this is unchanging, will impact sampling - TransitionEdges: make(map[rune]*PDANode), - MaskTokenIDToNode: make(map[int32]*PDANode), + TransitionEdges: make(map[rune]*PDA), + MaskTokenIDToNode: make(map[int32]*PDA), } prevNode := keyNode for _, r := range name { - runeNode := &PDANode{ + runeNode := &PDA{ State: StateInStructuredKey, // this is unchanging, will impact sampling - TransitionEdges: make(map[rune]*PDANode), - MaskTokenIDToNode: make(map[int32]*PDANode), + TransitionEdges: make(map[rune]*PDA), + MaskTokenIDToNode: make(map[int32]*PDA), } fmt.Println("runeNode created", runeNode.State) fmt.Printf("runeNode created %c\n", r) @@ -117,9 +139,14 @@ func (s *SOSampler) schemaToGraph() { fmt.Println("name", name, "keyNode", keyNode.State) } } + // TODO: do values + recursion } -func (s *SOSampler) Apply(logits []float64) ([]float64, error) { +func (s *JSONSampler) Apply(logits []float64) ([]float64, error) { + if s.schema == nil { + return s.pdaSampler.Apply(logits) + } + switch s.pdaSampler.curNode.State { // doesnt account for multi rune case case StateInObjectKey: @@ -148,17 +175,18 @@ func (s *SOSampler) Apply(logits []float64) ([]float64, error) { // todo: if i incremenet propidx then i know im in last value as well switch s.pdaSampler.curNode.State { case StateInObjectEnd: - fmt.Println("<<<<< in obj end- generating mask for", s.pdaSampler.curNode.State) - s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDANode) + fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State) + s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA) s.pdaSampler.curNode = NewPDANode(StateTerminate) s.propIdx++ + // TODO: this needs to be optimized in some way, computing mask on the fly is expensive case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd: fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State) delete(s.pdaSampler.curNode.TransitionEdges, ',') - s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDANode) + s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA) - CreateMask(s.pdaSampler.curNode, s.pdaSampler.proc, s.decodedToks) + s.pdaSampler.CreateMask(s.pdaSampler.curNode) s.propIdx++ } } @@ -167,12 +195,17 @@ func (s *SOSampler) Apply(logits []float64) ([]float64, error) { } -func (s *SOSampler) UpdateState(tokenSlice []int32) error { +func (s *JSONSampler) UpdateState(tokenSlice []int32) error { err := s.pdaSampler.UpdateState(tokenSlice) if err != nil { return err } + if s.schema == nil { + // Don't need to update state for unconstrained JSON sampling + return nil + } + switch s.pdaSampler.curNode.State { case StateInObjectKey: s.propIdx++