wip
This commit is contained in:
parent
198fde82aa
commit
c56a8b7749
@ -104,6 +104,8 @@ func temp() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pdaSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
|
||||||
|
var stringBuffer string
|
||||||
var offset int
|
var offset int
|
||||||
for range args.n {
|
for range args.n {
|
||||||
logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
|
logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
|
||||||
@ -118,7 +120,10 @@ func temp() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// do sampling
|
// do sampling
|
||||||
f64s, err = sample.Sample(f64s, sample.Greedy())
|
// []ints back
|
||||||
|
// ints map to sampled logits
|
||||||
|
f64s, err = sample.Sample(f64s, pdaSampler, sample.Greedy())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -129,6 +134,7 @@ func temp() error {
|
|||||||
outputIDs = append(outputIDs, int32(f64))
|
outputIDs = append(outputIDs, int32(f64))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
pdaSampler.UpdateState(outputIDs)
|
||||||
|
|
||||||
if len(outputIDs) == 0 {
|
if len(outputIDs) == 0 {
|
||||||
break
|
break
|
||||||
@ -141,8 +147,9 @@ func temp() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Print(s)
|
// fmt.Print(s)
|
||||||
|
stringBuffer += s
|
||||||
|
fmt.Println("--- stringBuffer", stringBuffer)
|
||||||
inputIDs = append(inputIDs, outputIDs...)
|
inputIDs = append(inputIDs, outputIDs...)
|
||||||
if args.cache {
|
if args.cache {
|
||||||
offset = len(inputIDs) - 1
|
offset = len(inputIDs) - 1
|
||||||
|
1
model/cmd/test.go
Normal file
1
model/cmd/test.go
Normal file
@ -0,0 +1 @@
|
|||||||
|
package main
|
@ -21,6 +21,7 @@ type TextProcessor interface {
|
|||||||
Encode(string) ([]int32, error)
|
Encode(string) ([]int32, error)
|
||||||
Decode([]int32) (string, error)
|
Decode([]int32) (string, error)
|
||||||
Is(uint32, Special) bool
|
Is(uint32, Special) bool
|
||||||
|
GetVocabulary() *Vocabulary
|
||||||
}
|
}
|
||||||
|
|
||||||
type Vocabulary struct {
|
type Vocabulary struct {
|
||||||
@ -104,6 +105,10 @@ type BytePairEncoding struct {
|
|||||||
*Vocabulary
|
*Vocabulary
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (bpe BytePairEncoding) GetVocabulary() *Vocabulary {
|
||||||
|
return bpe.Vocabulary
|
||||||
|
}
|
||||||
|
|
||||||
func (bpe BytePairEncoding) split(s string) ([]string, error) {
|
func (bpe BytePairEncoding) split(s string) ([]string, error) {
|
||||||
re, err := regexp2.Compile(bpe.Pretokenizer, regexp2.Unicode|regexp2.RE2)
|
re, err := regexp2.Compile(bpe.Pretokenizer, regexp2.Unicode|regexp2.RE2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1,11 +1,7 @@
|
|||||||
package sample
|
package sample
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/model"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type JSONState int
|
type JSONState int
|
||||||
@ -136,219 +132,3 @@ func (s JSONState) String() string {
|
|||||||
return fmt.Sprintf("Unknown state: %d", s)
|
return fmt.Sprintf("Unknown state: %d", s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type JSONSampler struct {
|
|
||||||
curNode *Node
|
|
||||||
proc model.TextProcessor
|
|
||||||
stack []*Node
|
|
||||||
bracketCounter int
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
|
|
||||||
// fmt.Println("Creating new JSON sampler")
|
|
||||||
startNode, err := buildStateMachine(proc)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
js := &JSONSampler{
|
|
||||||
curNode: startNode,
|
|
||||||
proc: proc,
|
|
||||||
stack: []*Node{},
|
|
||||||
bracketCounter: 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
return js, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isTokenSubset(subset, superset []int32) bool {
|
|
||||||
freq1 := make(map[int32]int)
|
|
||||||
freq2 := make(map[int32]int)
|
|
||||||
|
|
||||||
for _, v := range subset {
|
|
||||||
freq1[v]++
|
|
||||||
}
|
|
||||||
for _, v := range superset {
|
|
||||||
freq2[v]++
|
|
||||||
}
|
|
||||||
isSubset := true
|
|
||||||
for k, count1 := range freq1 {
|
|
||||||
count2 := freq2[k]
|
|
||||||
if count1 > count2 {
|
|
||||||
isSubset = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return isSubset
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
|
|
||||||
// fmt.Printf("Updating state with token: %v\n", tokenSlice)
|
|
||||||
// fmt.Printf("Current state: %s\n", s.curNode.State)
|
|
||||||
|
|
||||||
// fmt.Println("tokenSlice", tokenSlice)
|
|
||||||
// todo: account for strings here
|
|
||||||
|
|
||||||
objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// only move to terminate state if stack is empty
|
|
||||||
if s.curNode.State == StateInObjectEnd {
|
|
||||||
fmt.Println("debug: node.State", s.curNode.State)
|
|
||||||
if len(s.stack) > 0 {
|
|
||||||
s.stack = s.stack[:len(s.stack)-1]
|
|
||||||
fmt.Println("popped and cur state", s.curNode.State)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for node, edge := range s.curNode.TransitionEdges {
|
|
||||||
for _, validToken := range edge {
|
|
||||||
if isTokenSubset(tokenSlice, validToken) {
|
|
||||||
s.curNode = node
|
|
||||||
for _, token := range objectTokens {
|
|
||||||
if isTokenSubset(tokenSlice, token) {
|
|
||||||
fmt.Println("Appending to stack", s.curNode.State)
|
|
||||||
s.stack = append(s.stack, s.curNode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// fmt.Printf("Transitioned to state: %s\n", node.State)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for node, edge := range s.curNode.TransitionEdges {
|
|
||||||
for _, validToken := range edge {
|
|
||||||
if len(validToken) == 1 && validToken[0] == -1 || validToken[0] == -2 {
|
|
||||||
s.curNode = node
|
|
||||||
// fmt.Printf("Accepting any token, staying in state: %s\n", node.State)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fmt.Println("invalid token ", tokenSlice)
|
|
||||||
dec, err := s.proc.Decode(tokenSlice)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
fmt.Println("decoded token ", dec)
|
|
||||||
return errors.New("invalid token")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
|
|
||||||
fmt.Printf("Sampling in state: %s\n", s.curNode.State)
|
|
||||||
var err error
|
|
||||||
|
|
||||||
switch s.curNode.State {
|
|
||||||
case StateTerminate:
|
|
||||||
for i := range logits {
|
|
||||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
|
||||||
logits[i] = 1.0
|
|
||||||
} else {
|
|
||||||
logits[i] = math.NaN()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return logits, nil
|
|
||||||
|
|
||||||
case StateInInt:
|
|
||||||
validStates := []int32{}
|
|
||||||
minus, err := s.proc.Encode("-")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
digits := make([][]int32, 10)
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
digits[i], err = s.proc.Encode(fmt.Sprintf("%d", i))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Allow "-" and digits 0-9 at start
|
|
||||||
for i := range logits {
|
|
||||||
for _, d := range digits {
|
|
||||||
if len(d) == 1 && int32(i) == d[0] {
|
|
||||||
validStates = append(validStates, int32(i))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(minus) == 1 && int32(i) == minus[0] {
|
|
||||||
validStates = append(validStates, int32(i))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return logits, nil
|
|
||||||
|
|
||||||
case StateInString:
|
|
||||||
penalizeNewlineVariants := []string{"\n", " \"\n"}
|
|
||||||
penalizeNewlineToks, err := ComputeTokenVariants(penalizeNewlineVariants, s.proc)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
penalizeNewlineToks = append(penalizeNewlineToks, []int32{702})
|
|
||||||
logits, err = s.maskSpecificLogits(logits, penalizeNewlineToks)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
validStates := getValidStates(s.curNode)
|
|
||||||
logits, err = s.maskLogits(logits, validStates)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return logits, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
validStates := getValidStates(s.curNode)
|
|
||||||
logits, err = s.maskLogits(logits, validStates)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return logits, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getValidStates(node *Node) []int32 {
|
|
||||||
validStates := []int32{}
|
|
||||||
for _, edge := range node.TransitionEdges {
|
|
||||||
for _, token := range edge {
|
|
||||||
validStates = append(validStates, token...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return validStates
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
|
|
||||||
// fmt.Printf("Masking logits with valid states: %v\n", validStates)
|
|
||||||
// todo: this can prob be more efficient
|
|
||||||
for i := range logits {
|
|
||||||
isValid := false
|
|
||||||
for _, token := range validStates {
|
|
||||||
if token == -1 {
|
|
||||||
// fmt.Println("Found sentinel token, returning unmasked logits")
|
|
||||||
return logits, nil
|
|
||||||
}
|
|
||||||
if i == int(token) {
|
|
||||||
// fmt.Printf("Found valid token: %d\n", token)
|
|
||||||
isValid = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !isValid {
|
|
||||||
logits[i] = math.NaN()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return logits, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *JSONSampler) maskSpecificLogits(logits []float64, tokensToMask []token) ([]float64, error) {
|
|
||||||
// fmt.Printf("Masking specific logits: %v\n", tokensToMask)
|
|
||||||
for i := range logits {
|
|
||||||
for _, token := range tokensToMask {
|
|
||||||
for _, chunked := range token {
|
|
||||||
if int(chunked) == i {
|
|
||||||
logits[i] = math.NaN()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return logits, nil
|
|
||||||
}
|
|
||||||
|
296
sample/hid.txt
Normal file
296
sample/hid.txt
Normal file
@ -0,0 +1,296 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
|
||||||
|
|
||||||
|
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
||||||
|
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
||||||
|
|
||||||
|
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
|
||||||
|
|
||||||
|
var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
|
||||||
|
|
||||||
|
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
|
||||||
|
|
||||||
|
type PDANode struct {
|
||||||
|
State JSONState
|
||||||
|
TransitionEdges map[rune]*PDANode
|
||||||
|
MaskTokenIDToNode map[int32]JSONState
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPDANode(state JSONState) *PDANode {
|
||||||
|
return &PDANode{
|
||||||
|
State: state,
|
||||||
|
TransitionEdges: make(map[rune]*PDANode),
|
||||||
|
MaskTokenIDToNode: make(map[int32]JSONState),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
|
||||||
|
stateToNodeMap := make(map[JSONState]*PDANode)
|
||||||
|
|
||||||
|
startNode := NewPDANode(StateStart)
|
||||||
|
stateToNodeMap[StateStart] = startNode
|
||||||
|
|
||||||
|
objNode := NewPDANode(StateInObject)
|
||||||
|
stateToNodeMap[StateInObject] = objNode
|
||||||
|
|
||||||
|
objEndNode := NewPDANode(StateInObjectEnd)
|
||||||
|
stateToNodeMap[StateInObjectEnd] = objEndNode
|
||||||
|
|
||||||
|
objKeyNode := NewPDANode(StateInObjectKey)
|
||||||
|
stateToNodeMap[StateInObjectKey] = objKeyNode
|
||||||
|
|
||||||
|
objKeyEndNode := NewPDANode(StateInObjectKeyEnd)
|
||||||
|
stateToNodeMap[StateInObjectKeyEnd] = objKeyEndNode
|
||||||
|
|
||||||
|
colonNode := NewPDANode(StateInColon)
|
||||||
|
stateToNodeMap[StateInColon] = colonNode
|
||||||
|
|
||||||
|
commaNode := NewPDANode(StateInComma)
|
||||||
|
stateToNodeMap[StateInComma] = commaNode
|
||||||
|
|
||||||
|
newlineNode := NewPDANode(StateInNewline)
|
||||||
|
stateToNodeMap[StateInNewline] = newlineNode
|
||||||
|
|
||||||
|
spaceNode := NewPDANode(StateInSpace)
|
||||||
|
stateToNodeMap[StateInSpace] = spaceNode
|
||||||
|
|
||||||
|
spaceObjNode := NewPDANode(StateInObjSpace)
|
||||||
|
stateToNodeMap[StateInObjSpace] = spaceObjNode
|
||||||
|
|
||||||
|
tabNode := NewPDANode(StateInTab)
|
||||||
|
stateToNodeMap[StateInTab] = tabNode
|
||||||
|
|
||||||
|
stringNode := NewPDANode(StateInString)
|
||||||
|
stateToNodeMap[StateInString] = stringNode
|
||||||
|
|
||||||
|
stringEndNode := NewPDANode(StateInStringEnd)
|
||||||
|
stateToNodeMap[StateInStringEnd] = stringEndNode
|
||||||
|
|
||||||
|
listNode := NewPDANode(StateInList)
|
||||||
|
stateToNodeMap[StateInList] = listNode
|
||||||
|
|
||||||
|
listCommaNode := NewPDANode(StateInListComma)
|
||||||
|
stateToNodeMap[StateInListComma] = listCommaNode
|
||||||
|
|
||||||
|
listEndNode := NewPDANode(StateListEnd)
|
||||||
|
stateToNodeMap[StateListEnd] = listEndNode
|
||||||
|
|
||||||
|
numberNode := NewPDANode(StateInNumber)
|
||||||
|
stateToNodeMap[StateInNumber] = numberNode
|
||||||
|
|
||||||
|
boolNode := NewPDANode(StateInBool)
|
||||||
|
stateToNodeMap[StateInBool] = boolNode
|
||||||
|
|
||||||
|
nullNode := NewPDANode(StateInNull)
|
||||||
|
stateToNodeMap[StateInNull] = nullNode
|
||||||
|
|
||||||
|
// Defined with structured outputs only
|
||||||
|
intNode := NewPDANode(StateInInt)
|
||||||
|
stateToNodeMap[StateInInt] = intNode
|
||||||
|
|
||||||
|
// TODO:
|
||||||
|
// consider adding a node to just point to values, could be good to compute that
|
||||||
|
// mask rather than many different nodes
|
||||||
|
|
||||||
|
// Connect nodes
|
||||||
|
// TODO: if all are single tokens then this can just be connected instead of defining the token
|
||||||
|
startNode.TransitionEdges['{'] = objNode
|
||||||
|
|
||||||
|
objNode.TransitionEdges['"'] = objKeyNode
|
||||||
|
objNode.TransitionEdges['\n'] = newlineNode
|
||||||
|
// objNode.TransitionEdges['\t'] = tabNode
|
||||||
|
|
||||||
|
newlineNode.TransitionEdges['"'] = objKeyNode
|
||||||
|
newlineNode.TransitionEdges['\t'] = tabNode
|
||||||
|
|
||||||
|
tabNode.TransitionEdges['"'] = objKeyNode
|
||||||
|
// tabNode.TransitionEdges['\t'] = tabNode
|
||||||
|
|
||||||
|
objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
|
||||||
|
objKeyNode.TransitionEdges['"'] = objKeyEndNode
|
||||||
|
|
||||||
|
objKeyEndNode.TransitionEdges[':'] = colonNode
|
||||||
|
objEndNode.TransitionEdges[' '] = spaceNode
|
||||||
|
|
||||||
|
// where values should be
|
||||||
|
// this could be combined but the probs might change, we're alr doing a skip ahead
|
||||||
|
colonNode.TransitionEdges[' '] = spaceNode
|
||||||
|
|
||||||
|
// Leads to a value
|
||||||
|
spaceNode.TransitionEdges['"'] = stringNode
|
||||||
|
spaceNode.TransitionEdges['['] = listNode
|
||||||
|
spaceNode.TransitionEdges['{'] = objNode
|
||||||
|
|
||||||
|
for _, r := range validNumberRunes {
|
||||||
|
spaceNode.TransitionEdges[r] = numberNode
|
||||||
|
}
|
||||||
|
for _, r := range validBoolRunes {
|
||||||
|
spaceNode.TransitionEdges[r] = boolNode
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range validNullRunes {
|
||||||
|
spaceNode.TransitionEdges[r] = nullNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// Values
|
||||||
|
// string node
|
||||||
|
stringNode.TransitionEdges[rune(-1)] = stringNode
|
||||||
|
stringNode.TransitionEdges['"'] = stringEndNode
|
||||||
|
|
||||||
|
stringEndNode.TransitionEdges[','] = commaNode
|
||||||
|
stringEndNode.TransitionEdges['}'] = objEndNode
|
||||||
|
stringEndNode.TransitionEdges[']'] = listEndNode
|
||||||
|
|
||||||
|
// TODO: add counters for allowable number of decimals, e, E, etc
|
||||||
|
// number node
|
||||||
|
for _, r := range validNumberRunes {
|
||||||
|
numberNode.TransitionEdges[r] = numberNode
|
||||||
|
}
|
||||||
|
numberNode.TransitionEdges[','] = commaNode
|
||||||
|
numberNode.TransitionEdges['}'] = objEndNode
|
||||||
|
numberNode.TransitionEdges[']'] = listEndNode
|
||||||
|
|
||||||
|
for _, r := range validBoolRunes {
|
||||||
|
boolNode.TransitionEdges[r] = boolNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// list node
|
||||||
|
listNode.TransitionEdges[','] = commaNode
|
||||||
|
listNode.TransitionEdges['"'] = stringNode
|
||||||
|
// squash states to a value
|
||||||
|
for _, r := range validNumberRunes {
|
||||||
|
listNode.TransitionEdges[r] = numberNode
|
||||||
|
}
|
||||||
|
for _, r := range validBoolRunes {
|
||||||
|
listNode.TransitionEdges[r] = boolNode
|
||||||
|
}
|
||||||
|
for _, r := range validNullRunes {
|
||||||
|
listNode.TransitionEdges[r] = nullNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// null node
|
||||||
|
for _, r := range validNullRunes {
|
||||||
|
nullNode.TransitionEdges[r] = nullNode
|
||||||
|
}
|
||||||
|
nullNode.TransitionEdges[','] = commaNode
|
||||||
|
nullNode.TransitionEdges['}'] = objEndNode
|
||||||
|
nullNode.TransitionEdges[']'] = listEndNode
|
||||||
|
|
||||||
|
// list comma
|
||||||
|
// should point to values
|
||||||
|
listCommaNode.TransitionEdges['"'] = stringNode
|
||||||
|
listCommaNode.TransitionEdges[' '] = listCommaNode
|
||||||
|
listCommaNode.TransitionEdges['{'] = objNode
|
||||||
|
listCommaNode.TransitionEdges['\n'] = newlineNode
|
||||||
|
|
||||||
|
for _, r := range validNumberRunes {
|
||||||
|
listCommaNode.TransitionEdges[r] = numberNode
|
||||||
|
}
|
||||||
|
for _, r := range validBoolRunes {
|
||||||
|
listCommaNode.TransitionEdges[r] = boolNode
|
||||||
|
}
|
||||||
|
for _, r := range validNullRunes {
|
||||||
|
listCommaNode.TransitionEdges[r] = nullNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// bool node
|
||||||
|
for _, r := range validBoolRunes {
|
||||||
|
boolNode.TransitionEdges[r] = boolNode
|
||||||
|
}
|
||||||
|
boolNode.TransitionEdges['}'] = objEndNode
|
||||||
|
boolNode.TransitionEdges[']'] = listEndNode
|
||||||
|
boolNode.TransitionEdges[','] = commaNode
|
||||||
|
|
||||||
|
listEndNode.TransitionEdges['}'] = objEndNode
|
||||||
|
listEndNode.TransitionEdges[','] = commaNode
|
||||||
|
|
||||||
|
commaNode.TransitionEdges['{'] = objNode
|
||||||
|
commaNode.TransitionEdges['\n'] = newlineNode
|
||||||
|
commaNode.TransitionEdges['\t'] = tabNode
|
||||||
|
commaNode.TransitionEdges['"'] = objKeyNode
|
||||||
|
commaNode.TransitionEdges[' '] = spaceObjNode
|
||||||
|
|
||||||
|
spaceObjNode.TransitionEdges['"'] = objKeyNode
|
||||||
|
|
||||||
|
return startNode, stateToNodeMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
|
||||||
|
|
||||||
|
vocab := proc.GetVocabulary()
|
||||||
|
|
||||||
|
decodedToks := make([]string, len(vocab.Values))
|
||||||
|
for i := range vocab.Values {
|
||||||
|
token, err := proc.Decode([]int32{int32(i)})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
decodedToks[i] = token
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
for _, node := range stateToNodeMap {
|
||||||
|
for i := range vocab.Values {
|
||||||
|
token := decodedToks[i]
|
||||||
|
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||||
|
if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
valid := true
|
||||||
|
curNode := node
|
||||||
|
consumedSpecialRunes := make(map[rune]bool)
|
||||||
|
for _, r := range token {
|
||||||
|
valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if valid {
|
||||||
|
node.MaskTokenIDToNode[int32(i)] = curNode.State
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
|
||||||
|
if consumedSpecialRunes[r] {
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
specialRune := slices.Contains(stringInvalidRunes, r)
|
||||||
|
if specialRune {
|
||||||
|
if curNode.State == StateInString || curNode.State == StateInObjectKey {
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for specific rune transition
|
||||||
|
if nextNode, ok := curNode.TransitionEdges[r]; ok {
|
||||||
|
if specialRune {
|
||||||
|
if curNode.State == nextNode.State {
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
// fmt.Println("special rune", r, "consumed")
|
||||||
|
consumedSpecialRunes[r] = true
|
||||||
|
}
|
||||||
|
return true, nextNode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for sentinel value - if present, any rune is valid
|
||||||
|
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
|
||||||
|
return true, nextNode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
@ -1,104 +0,0 @@
|
|||||||
package sample
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
type JSONState int
|
|
||||||
|
|
||||||
const (
|
|
||||||
StateStart JSONState = iota // Initial state
|
|
||||||
StateInObject // Inside an object {}
|
|
||||||
StateInArray // Inside an array []
|
|
||||||
StateInString // Inside a string ""
|
|
||||||
StateAfterKey // After object key, expecting :
|
|
||||||
StateAfterColon // After :, expecting value
|
|
||||||
StateAfterValue // After value, expecting , or closing bracket
|
|
||||||
StateDone // JSON parsing complete
|
|
||||||
)
|
|
||||||
|
|
||||||
type JSONSampler struct {
|
|
||||||
state JSONState
|
|
||||||
stack []string
|
|
||||||
proc model.TextProcessor
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewJSONSampler(proc model.TextProcessor) *JSONSampler {
|
|
||||||
return &JSONSampler{
|
|
||||||
state: StateStart,
|
|
||||||
proc: proc,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
|
|
||||||
// Pre-decode valid tokens for current state
|
|
||||||
validTokens := make(map[uint32]bool)
|
|
||||||
|
|
||||||
// Always allow EOS token in any state
|
|
||||||
// TODO: Check for other special tokens if needed
|
|
||||||
for i := range logits {
|
|
||||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
|
||||||
validTokens[uint32(i)] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build set of valid tokens based on current state
|
|
||||||
switch s.state {
|
|
||||||
case StateStart:
|
|
||||||
// Only allow opening brace
|
|
||||||
for i := range logits {
|
|
||||||
text, err := s.proc.Decode([]int32{int32(i)})
|
|
||||||
if err == nil && text == "{" {
|
|
||||||
validTokens[uint32(i)] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case StateInObject, StateInArray:
|
|
||||||
// Allow any token
|
|
||||||
for i := range logits {
|
|
||||||
validTokens[uint32(i)] = true
|
|
||||||
}
|
|
||||||
case StateInString:
|
|
||||||
// Allow any token except closing brace
|
|
||||||
for i := range logits {
|
|
||||||
text, err := s.proc.Decode([]int32{int32(i)})
|
|
||||||
if err == nil && text != "}" {
|
|
||||||
validTokens[uint32(i)] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case StateDone:
|
|
||||||
// No tokens allowed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark invalid tokens as NaN in one pass
|
|
||||||
for i := range logits {
|
|
||||||
if !validTokens[uint32(i)] {
|
|
||||||
logits[i] = math.NaN()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return logits, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *JSONSampler) UpdateState(tokenID int) error {
|
|
||||||
text, err := s.proc.Decode([]int32{int32(tokenID)})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to decode token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch s.state {
|
|
||||||
case StateStart:
|
|
||||||
if text != "{" {
|
|
||||||
return fmt.Errorf("expected {, got %s", text)
|
|
||||||
}
|
|
||||||
s.state = StateInObject
|
|
||||||
case StateInObject:
|
|
||||||
if text == "}" {
|
|
||||||
s.state = StateDone
|
|
||||||
}
|
|
||||||
case StateDone:
|
|
||||||
return fmt.Errorf("unexpected token after closing bracket: %s", text)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -165,9 +165,10 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
|
|||||||
if len(logitsCopy) == 0 {
|
if len(logitsCopy) == 0 {
|
||||||
return nil, errors.New("no valid tokens found")
|
return nil, errors.New("no valid tokens found")
|
||||||
}
|
}
|
||||||
|
logitsCopy, err := computeSoftmax(logitsCopy)
|
||||||
// usually, a softmax is applied to sample from the logits
|
if err != nil {
|
||||||
// in this case the uv sampler normalizes the logits so that the sum of the weights is 1
|
return nil, err
|
||||||
|
}
|
||||||
w := sampleuv.NewWeighted(logitsCopy, nil)
|
w := sampleuv.NewWeighted(logitsCopy, nil)
|
||||||
if v, ok := w.Take(); ok {
|
if v, ok := w.Take(); ok {
|
||||||
// returns the token ID
|
// returns the token ID
|
||||||
@ -176,17 +177,6 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
|
|||||||
return nil, errors.New("weighed sampler failed")
|
return nil, errors.New("weighed sampler failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: remove after next PR merge
|
|
||||||
type greedy struct{}
|
|
||||||
|
|
||||||
func Greedy() Sampler {
|
|
||||||
return greedy{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (greedy) Sample(logits []float64) ([]float64, error) {
|
|
||||||
return []float64{float64(floats.MaxIdx(logits))}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func Sample(logits []float64, samplers ...Sampler) ([]float64, error) {
|
func Sample(logits []float64, samplers ...Sampler) ([]float64, error) {
|
||||||
var err error
|
var err error
|
||||||
for _, sampler := range samplers {
|
for _, sampler := range samplers {
|
||||||
|
@ -3,14 +3,9 @@ package sample
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"math/rand"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"slices"
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"runtime/trace"
|
|
||||||
|
|
||||||
"gonum.org/v1/gonum/floats"
|
"gonum.org/v1/gonum/floats"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,218 +0,0 @@
|
|||||||
package sample
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
type token []int32
|
|
||||||
|
|
||||||
type Node struct {
|
|
||||||
State JSONState
|
|
||||||
TransitionEdges map[*Node][]token
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewNode(state JSONState) *Node {
|
|
||||||
return &Node{
|
|
||||||
State: state,
|
|
||||||
TransitionEdges: make(map[*Node][]token),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
// startToken token
|
|
||||||
startTokenVariants []token
|
|
||||||
// endToken token
|
|
||||||
// stringToken token
|
|
||||||
// objectKeyToken token
|
|
||||||
tabToken token
|
|
||||||
spaceToken token
|
|
||||||
newlineToken token
|
|
||||||
newlineSpace token
|
|
||||||
// commaToken token
|
|
||||||
// commaToken2 token
|
|
||||||
// commaToken3 token
|
|
||||||
// colonToken token
|
|
||||||
// colonToken2 token
|
|
||||||
colonTokenVariants []token
|
|
||||||
commaTokenVariants []token
|
|
||||||
stringTokenVariants []token
|
|
||||||
endTokenVariants []token
|
|
||||||
objectKeyTokenVariants []token
|
|
||||||
objKeyToColonVariants []token
|
|
||||||
stringToObjectKeyVariants []token
|
|
||||||
stringToCommaVariants []token
|
|
||||||
stringToObjectVariants []token
|
|
||||||
stringEndToObjectEndVariants []token
|
|
||||||
stringEndToCommaVariants []token
|
|
||||||
)
|
|
||||||
|
|
||||||
func ComputeTokenVariants(variants []string, proc model.TextProcessor) ([]token, error) {
|
|
||||||
var allTokens token
|
|
||||||
for _, variant := range variants {
|
|
||||||
if t, err := proc.Encode(variant); err == nil {
|
|
||||||
allTokens = append(allTokens, t...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(allTokens) == 0 {
|
|
||||||
return nil, fmt.Errorf("no valid tokens found for variants")
|
|
||||||
}
|
|
||||||
return []token{allTokens}, nil
|
|
||||||
}
|
|
||||||
func initTokens(proc model.TextProcessor) error {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
s, err := proc.Decode([]int32{761})
|
|
||||||
fmt.Printf("761 decoded %q\n", s)
|
|
||||||
|
|
||||||
// Compute start token variants
|
|
||||||
startVariants := []string{"{", " {", "{\n", " {\n"}
|
|
||||||
startTokenVariants, err = ComputeTokenVariants(startVariants, proc)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// Compute end token variants
|
|
||||||
endVariants := []string{"}", " }", "}\n", " }\n"}
|
|
||||||
endTokenVariants, err = ComputeTokenVariants(endVariants, proc)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute string token variants
|
|
||||||
// TODO: removed \n
|
|
||||||
stringVariants := []string{"\"", " \""}
|
|
||||||
stringTokenVariants, err = ComputeTokenVariants(stringVariants, proc)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\",\n"}, proc)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// objectKeyTokenVariants = []token{stringTokenVariants[0], stringTokenVariants[1]}
|
|
||||||
objectKeyTokenVariants = stringTokenVariants
|
|
||||||
// Compute whitespace tokens
|
|
||||||
tabToken, err = proc.Encode("\t")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
spaceToken, err = proc.Encode(" ")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
newlineToken, err = proc.Encode("\n")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
newlineSpace, err = proc.Encode(" \n")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute colon variants
|
|
||||||
colonVariants := []string{":"}
|
|
||||||
colonTokenVariants, err = ComputeTokenVariants(colonVariants, proc)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
objKeyToColonVariants, err = ComputeTokenVariants([]string{"\":"}, proc)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute comma variants
|
|
||||||
commaVariants := []string{",", " ,", ",\n", "\",", "\", "}
|
|
||||||
commaTokenVariants, err = ComputeTokenVariants(commaVariants, proc)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
fmt.Printf("commaTokenVariants: %v\n", commaTokenVariants)
|
|
||||||
stringToCommaVariants, err = ComputeTokenVariants([]string{"\",", "\","}, proc)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
stringEndToCommaVariants, err = ComputeTokenVariants([]string{",", ",\n"}, proc)
|
|
||||||
stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\","}, proc)
|
|
||||||
stringToObjectVariants, err = ComputeTokenVariants([]string{"\",\n"}, proc)
|
|
||||||
stringEndToObjectEndVariants, err = ComputeTokenVariants([]string{"\n"}, proc)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildStateMachine(proc model.TextProcessor) (*Node, error) {
|
|
||||||
if err := initTokens(proc); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
startNode := NewNode(StateStart)
|
|
||||||
objectNode := NewNode(StateInObject)
|
|
||||||
objectKeyNode := NewNode(StateInObjectKey)
|
|
||||||
objectKeyEndNode := NewNode(StateInObjectKeyEnd)
|
|
||||||
stringNode := NewNode(StateInString)
|
|
||||||
// intNode := NewNode(StateInInt)
|
|
||||||
commaNode := NewNode(StateInComma)
|
|
||||||
colonNode := NewNode(StateInColon)
|
|
||||||
stringEndNode := NewNode(StateInStringEnd)
|
|
||||||
endNode := NewNode(StateEnd)
|
|
||||||
terminateNode := NewNode(StateTerminate)
|
|
||||||
|
|
||||||
sentinelToken := token([]int32{-1})
|
|
||||||
// intSentinelToken := token([]int32{-2})
|
|
||||||
|
|
||||||
// TODO: cleanup connections of rules
|
|
||||||
startNode.TransitionEdges[objectNode] = startTokenVariants
|
|
||||||
|
|
||||||
objectNode.TransitionEdges[objectKeyNode] = stringTokenVariants
|
|
||||||
objectNode.TransitionEdges[objectNode] = []token{newlineToken}
|
|
||||||
objectNode.TransitionEdges[objectNode] = []token{spaceToken}
|
|
||||||
|
|
||||||
// objectNode.TransitionEdges[objectNode] = []token{newlineToken}
|
|
||||||
// objectNode.TransitionEdges[objectNode] = []token{spaceToken}
|
|
||||||
|
|
||||||
objectKeyNode.TransitionEdges[objectKeyNode] = []token{sentinelToken}
|
|
||||||
// characterize end of object key
|
|
||||||
objectKeyNode.TransitionEdges[objectKeyEndNode] = stringTokenVariants
|
|
||||||
objectKeyNode.TransitionEdges[colonNode] = objKeyToColonVariants
|
|
||||||
|
|
||||||
// TODO: enable this - key -> object
|
|
||||||
// objectKeyNode.TransitionEdges[objectNode] = startTokenVariants
|
|
||||||
|
|
||||||
// objectKeyNode.TransitionEdges[intNode] = []token{sentinelToken}
|
|
||||||
|
|
||||||
// intNode.TransitionEdges[intNode] = []token{intSentinelToken}
|
|
||||||
// intNode.TransitionEdges[commaNode] = commaTokenVariants
|
|
||||||
// TODO: handle
|
|
||||||
// intNode.TransitionEdges[terminateNode] = endTokenVariants
|
|
||||||
|
|
||||||
commaNode.TransitionEdges[objectKeyNode] = stringTokenVariants
|
|
||||||
// commaNode.TransitionEdges[objectNode] = startTokenVariants
|
|
||||||
|
|
||||||
colonNode.TransitionEdges[stringNode] = stringTokenVariants
|
|
||||||
//TODO: enable
|
|
||||||
// colonNode.TransitionEdges[intNode] = []token{intSentinelToken}
|
|
||||||
colonNode.TransitionEdges[objectNode] = startTokenVariants
|
|
||||||
|
|
||||||
stringNode.TransitionEdges[stringNode] = []token{sentinelToken}
|
|
||||||
stringNode.TransitionEdges[stringEndNode] = stringTokenVariants
|
|
||||||
// TODO: "\""," Case not accounted for
|
|
||||||
stringNode.TransitionEdges[commaNode] = stringToCommaVariants
|
|
||||||
|
|
||||||
// TODO: "\"",\"" Case not accounted for
|
|
||||||
stringNode.TransitionEdges[objectNode] = stringToObjectVariants
|
|
||||||
|
|
||||||
stringEndNode.TransitionEdges[commaNode] = stringEndToCommaVariants
|
|
||||||
stringEndNode.TransitionEdges[objectNode] = stringToObjectKeyVariants
|
|
||||||
stringEndNode.TransitionEdges[endNode] = stringEndToObjectEndVariants
|
|
||||||
// stringEndNode.TransitionEdges[terminateNode] = endTokenVariants
|
|
||||||
|
|
||||||
// Should be obj end
|
|
||||||
// TODO: handle
|
|
||||||
endNode.TransitionEdges[terminateNode] = []token{}
|
|
||||||
|
|
||||||
endNode.TransitionEdges[commaNode] = commaTokenVariants
|
|
||||||
|
|
||||||
terminateNode.TransitionEdges[terminateNode] = []token{}
|
|
||||||
return startNode, nil
|
|
||||||
}
|
|
@ -7,92 +7,6 @@ type StructuredOutput struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *PDANode {
|
func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *PDANode {
|
||||||
// _, stateToNodeMap, err := BuildGraph(proc)
|
|
||||||
// if err != nil {
|
|
||||||
// panic(err)
|
|
||||||
// }
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// func constrainGraph(graph *PDANode, schema *Schema) *PDANode {
|
|
||||||
// // If no schema constraints, return original graph node
|
|
||||||
// if schema == nil {
|
|
||||||
// return graph
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Create a new node with same state
|
|
||||||
// constrainedNode := NewPDANode(graph.State)
|
|
||||||
|
|
||||||
// // Copy over existing transitions and masks
|
|
||||||
// constrainedNode.TransitionEdges = make(map[rune]*PDANode)
|
|
||||||
// for r, node := range graph.TransitionEdges {
|
|
||||||
// constrainedNode.TransitionEdges[r] = node
|
|
||||||
// }
|
|
||||||
// constrainedNode.MaskTokenIDToNode = graph.MaskTokenIDToNode
|
|
||||||
|
|
||||||
// // Apply schema constraints based on type
|
|
||||||
// switch schema.EffectiveType() {
|
|
||||||
// case "object":
|
|
||||||
// // Only allow defined property names in object keys
|
|
||||||
// if graph.State == StateInObjectKey {
|
|
||||||
// // TODO: Add property name validation
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Constrain property values based on schema
|
|
||||||
// if graph.State == StateInColon || graph.State == StateInSpace {
|
|
||||||
// // Clear transitions to only allow valid types
|
|
||||||
// constrainedNode.TransitionEdges = make(map[rune]*PDANode)
|
|
||||||
|
|
||||||
// // Add transitions based on property schemas
|
|
||||||
// for _, prop := range schema.Properties {
|
|
||||||
// switch prop.EffectiveType() {
|
|
||||||
// case "object":
|
|
||||||
// if objNode, ok := graph.TransitionEdges['{']; ok {
|
|
||||||
// constrainedNode.TransitionEdges['{'] = constrainGraph(objNode, prop)
|
|
||||||
// }
|
|
||||||
// case "array":
|
|
||||||
// if arrNode, ok := graph.TransitionEdges['[']; ok {
|
|
||||||
// constrainedNode.TransitionEdges['['] = constrainGraph(arrNode, prop)
|
|
||||||
// }
|
|
||||||
// case "string":
|
|
||||||
// if strNode, ok := graph.TransitionEdges['"']; ok {
|
|
||||||
// constrainedNode.TransitionEdges['"'] = constrainGraph(strNode, prop)
|
|
||||||
// }
|
|
||||||
// case "number":
|
|
||||||
// for _, r := range validNumberRunes {
|
|
||||||
// if numNode, ok := graph.TransitionEdges[r]; ok {
|
|
||||||
// constrainedNode.TransitionEdges[r] = constrainGraph(numNode, prop)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// case "integer":
|
|
||||||
// for _, r := range validIntRunes {
|
|
||||||
// if intNode, ok := graph.TransitionEdges[r]; ok {
|
|
||||||
// constrainedNode.TransitionEdges[r] = constrainGraph(intNode, prop)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// case "boolean":
|
|
||||||
// for _, r := range []rune{'t', 'f'} {
|
|
||||||
// if boolNode, ok := graph.TransitionEdges[r]; ok {
|
|
||||||
// constrainedNode.TransitionEdges[r] = constrainGraph(boolNode, prop)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// case "null":
|
|
||||||
// if nullNode, ok := graph.TransitionEdges['n']; ok {
|
|
||||||
// constrainedNode.TransitionEdges['n'] = constrainGraph(nullNode, prop)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// case "array":
|
|
||||||
// // Constrain array items based on schema
|
|
||||||
// if schema.Items != nil {
|
|
||||||
// for r, node := range graph.TransitionEdges {
|
|
||||||
// constrainedNode.TransitionEdges[r] = constrainGraph(node, schema.Items)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// return constrainedNode
|
|
||||||
// }
|
|
||||||
|
BIN
sample/trace.out
Normal file
BIN
sample/trace.out
Normal file
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user