err handling + fixing scope issue
This commit is contained in:
parent
aa6d5151df
commit
ffd6428758
@ -1,6 +1,7 @@
|
|||||||
package sample
|
package sample
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
@ -34,7 +35,7 @@ Key JSON rules to consider:
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
// TODO: / should be valid but an escape character
|
// TODO: / should be valid but an escape character
|
||||||
var stringInvalidRunes = []rune{'\n', '\t', '{', '}', ':', ',', '/'}
|
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
|
||||||
|
|
||||||
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
||||||
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
||||||
@ -109,12 +110,12 @@ func (b *PDAGraphBuilder) BuildGraph() error {
|
|||||||
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||||
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
|
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||||
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
b.addValueConnections(stateToNodeMap[StateInColon])
|
addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
|
||||||
|
|
||||||
// Leads to a value
|
// Leads to a value
|
||||||
stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
|
stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||||
stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
b.addValueConnections(stateToNodeMap[StateInSpaceToValue])
|
addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap)
|
||||||
stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
|
||||||
// Values
|
// Values
|
||||||
@ -123,7 +124,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
|
|||||||
stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
|
stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
|
||||||
|
|
||||||
// String end node
|
// String end node
|
||||||
b.addEnds(stateToNodeMap[StateInStringEnd])
|
addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
|
||||||
stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
@ -132,7 +133,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
|
|||||||
for _, r := range validNumberRunes {
|
for _, r := range validNumberRunes {
|
||||||
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
||||||
}
|
}
|
||||||
b.addEnds(stateToNodeMap[StateInNumber])
|
addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
|
||||||
stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
@ -150,13 +151,13 @@ func (b *PDAGraphBuilder) BuildGraph() error {
|
|||||||
|
|
||||||
// empty list
|
// empty list
|
||||||
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
b.addValueConnections(stateToNodeMap[StateInList])
|
addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
|
||||||
|
|
||||||
// null node
|
// null node
|
||||||
for _, r := range validNullRunes {
|
for _, r := range validNullRunes {
|
||||||
stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
|
stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
|
||||||
}
|
}
|
||||||
b.addEnds(stateToNodeMap[StateInNull])
|
addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
|
||||||
stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||||
stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
@ -165,7 +166,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
|
|||||||
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
|
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
|
||||||
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
||||||
b.addValueConnections(stateToNodeMap[StateInListComma])
|
addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
|
||||||
|
|
||||||
// list object end
|
// list object end
|
||||||
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
|
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
|
||||||
@ -178,7 +179,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
|
|||||||
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
|
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
|
||||||
}
|
}
|
||||||
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||||
b.addEnds(stateToNodeMap[StateInBool])
|
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
|
||||||
stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
@ -201,21 +202,21 @@ func (b *PDAGraphBuilder) BuildGraph() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *PDAGraphBuilder) addEnds(node *PDA) {
|
func addEnds(node *PDA, stateToNodeMap map[JSONState]*PDA) {
|
||||||
node.TransitionEdges[','] = b.stateToNodeMap[StateInComma]
|
node.TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||||
node.TransitionEdges['}'] = b.stateToNodeMap[StateInObjectEnd]
|
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
node.TransitionEdges[']'] = b.stateToNodeMap[StateInListEnd]
|
node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *PDAGraphBuilder) addValueConnections(node *PDA) {
|
func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) {
|
||||||
node.TransitionEdges['"'] = b.stateToNodeMap[StateInString]
|
node.TransitionEdges['"'] = stateToNodeMap[StateInString]
|
||||||
for _, r := range validNumberRunes {
|
for _, r := range validNumberRunes {
|
||||||
node.TransitionEdges[r] = b.stateToNodeMap[StateInNumber]
|
node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
||||||
}
|
}
|
||||||
// TODO(parthsareen): force the output and shift similar to structured outputs
|
// TODO(parthsareen): force the output and shift similar to structured outputs
|
||||||
node.TransitionEdges['t'] = b.stateToNodeMap[StateInBool]
|
node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
|
||||||
node.TransitionEdges['f'] = b.stateToNodeMap[StateInBool]
|
node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
|
||||||
node.TransitionEdges['n'] = b.stateToNodeMap[StateInNull]
|
node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *PDAGraphBuilder) preComputeValidStates() error {
|
func (b *PDAGraphBuilder) preComputeValidStates() error {
|
||||||
@ -228,6 +229,9 @@ func (b *PDAGraphBuilder) preComputeValidStates() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
|
func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
|
||||||
|
if node == nil {
|
||||||
|
return fmt.Errorf("node cannot be nil")
|
||||||
|
}
|
||||||
for i := range b.decodedToks {
|
for i := range b.decodedToks {
|
||||||
token := b.decodedToks[i]
|
token := b.decodedToks[i]
|
||||||
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||||
@ -264,6 +268,7 @@ func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA
|
|||||||
|
|
||||||
// Check for specific rune transition
|
// Check for specific rune transition
|
||||||
if nextNode, ok := curNode.TransitionEdges[r]; ok {
|
if nextNode, ok := curNode.TransitionEdges[r]; ok {
|
||||||
|
// fmt.Println("next node", nextNode)
|
||||||
if specialRune {
|
if specialRune {
|
||||||
if curNode.State == nextNode.State {
|
if curNode.State == nextNode.State {
|
||||||
return nil, false
|
return nil, false
|
||||||
|
@ -17,9 +17,13 @@ type JSONSampler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, error) {
|
func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, error) {
|
||||||
|
if proc == nil {
|
||||||
|
return nil, fmt.Errorf("TextProcessor cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
pdaSampler, err := NewPushdownSampler(proc)
|
pdaSampler, err := NewPushdownSampler(proc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to create PushdownSampler: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if schema == nil {
|
if schema == nil {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user