Add error handling and fix a scope issue

This commit is contained in:
ParthSareen 2025-02-11 19:29:48 -08:00
parent aa6d5151df
commit ffd6428758
2 changed files with 29 additions and 20 deletions

View File

@ -1,6 +1,7 @@
package sample package sample
import ( import (
"fmt"
"slices" "slices"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
@ -34,7 +35,7 @@ Key JSON rules to consider:
*/ */
// TODO: / should be valid but an escape character // TODO: / should be valid but an escape character
var stringInvalidRunes = []rune{'\n', '\t', '{', '}', ':', ',', '/'} var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'} var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'} var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
@ -109,12 +110,12 @@ func (b *PDAGraphBuilder) BuildGraph() error {
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue] stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList] stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
b.addValueConnections(stateToNodeMap[StateInColon]) addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
// Leads to a value // Leads to a value
stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList] stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
b.addValueConnections(stateToNodeMap[StateInSpaceToValue]) addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap)
stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
// Values // Values
@ -123,7 +124,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd] stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
// String end node // String end node
b.addEnds(stateToNodeMap[StateInStringEnd]) addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
@ -132,7 +133,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
for _, r := range validNumberRunes { for _, r := range validNumberRunes {
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber] stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
} }
b.addEnds(stateToNodeMap[StateInNumber]) addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
@ -150,13 +151,13 @@ func (b *PDAGraphBuilder) BuildGraph() error {
// empty list // empty list
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
b.addValueConnections(stateToNodeMap[StateInList]) addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
// null node // null node
for _, r := range validNullRunes { for _, r := range validNullRunes {
stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull] stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
} }
b.addEnds(stateToNodeMap[StateInNull]) addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue] stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
@ -165,7 +166,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma] stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList] stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
b.addValueConnections(stateToNodeMap[StateInListComma]) addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
// list object end // list object end
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma] stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
@ -178,7 +179,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool] stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
} }
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
b.addEnds(stateToNodeMap[StateInBool]) addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
@ -201,21 +202,21 @@ func (b *PDAGraphBuilder) BuildGraph() error {
return nil return nil
} }
func (b *PDAGraphBuilder) addEnds(node *PDA) { func addEnds(node *PDA, stateToNodeMap map[JSONState]*PDA) {
node.TransitionEdges[','] = b.stateToNodeMap[StateInComma] node.TransitionEdges[','] = stateToNodeMap[StateInComma]
node.TransitionEdges['}'] = b.stateToNodeMap[StateInObjectEnd] node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
node.TransitionEdges[']'] = b.stateToNodeMap[StateInListEnd] node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
} }
func (b *PDAGraphBuilder) addValueConnections(node *PDA) { func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) {
node.TransitionEdges['"'] = b.stateToNodeMap[StateInString] node.TransitionEdges['"'] = stateToNodeMap[StateInString]
for _, r := range validNumberRunes { for _, r := range validNumberRunes {
node.TransitionEdges[r] = b.stateToNodeMap[StateInNumber] node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
} }
// TODO(parthsareen): force the output and shift similar to structured outputs // TODO(parthsareen): force the output and shift similar to structured outputs
node.TransitionEdges['t'] = b.stateToNodeMap[StateInBool] node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
node.TransitionEdges['f'] = b.stateToNodeMap[StateInBool] node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
node.TransitionEdges['n'] = b.stateToNodeMap[StateInNull] node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
} }
func (b *PDAGraphBuilder) preComputeValidStates() error { func (b *PDAGraphBuilder) preComputeValidStates() error {
@ -228,6 +229,9 @@ func (b *PDAGraphBuilder) preComputeValidStates() error {
} }
func (b *PDAGraphBuilder) CreateMask(node *PDA) error { func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
if node == nil {
return fmt.Errorf("node cannot be nil")
}
for i := range b.decodedToks { for i := range b.decodedToks {
token := b.decodedToks[i] token := b.decodedToks[i]
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
@ -264,6 +268,7 @@ func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA
// Check for specific rune transition // Check for specific rune transition
if nextNode, ok := curNode.TransitionEdges[r]; ok { if nextNode, ok := curNode.TransitionEdges[r]; ok {
// fmt.Println("next node", nextNode)
if specialRune { if specialRune {
if curNode.State == nextNode.State { if curNode.State == nextNode.State {
return nil, false return nil, false

View File

@ -17,9 +17,13 @@ type JSONSampler struct {
} }
func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, error) { func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, error) {
if proc == nil {
return nil, fmt.Errorf("TextProcessor cannot be nil")
}
pdaSampler, err := NewPushdownSampler(proc) pdaSampler, err := NewPushdownSampler(proc)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to create PushdownSampler: %w", err)
} }
if schema == nil { if schema == nil {