wip
This commit is contained in:
parent
ffd6428758
commit
a4265c278a
@ -443,6 +443,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||||||
s.lc.Synchronize()
|
s.lc.Synchronize()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var totalSamplingTime time.Duration
|
||||||
for i, seq := range s.seqs {
|
for i, seq := range s.seqs {
|
||||||
if seq == nil {
|
if seq == nil {
|
||||||
continue
|
continue
|
||||||
@ -477,8 +478,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sample a token
|
// sample a token
|
||||||
|
samplingStart := time.Now()
|
||||||
token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
|
token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
|
||||||
seq.samplingCtx.Accept(token, true)
|
seq.samplingCtx.Accept(token, true)
|
||||||
|
samplingTime := time.Since(samplingStart)
|
||||||
|
totalSamplingTime += samplingTime
|
||||||
|
slog.Info("sampling time", "time", samplingTime)
|
||||||
piece := s.model.TokenToPiece(token)
|
piece := s.model.TokenToPiece(token)
|
||||||
|
|
||||||
seq.numPredicted++
|
seq.numPredicted++
|
||||||
@ -635,6 +640,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
samplingParams.Seed = uint32(req.Seed)
|
samplingParams.Seed = uint32(req.Seed)
|
||||||
samplingParams.Grammar = req.Grammar
|
samplingParams.Grammar = req.Grammar
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
numPredict: req.NumPredict,
|
numPredict: req.NumPredict,
|
||||||
stop: req.Stop,
|
stop: req.Stop,
|
||||||
@ -642,6 +648,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
samplingParams: &samplingParams,
|
samplingParams: &samplingParams,
|
||||||
embedding: false,
|
embedding: false,
|
||||||
})
|
})
|
||||||
|
slog.Info("new sequence created", "duration", time.Since(start))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
@ -28,7 +28,7 @@ var args struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func temp() error {
|
func temp() error {
|
||||||
start := time.Now()
|
// start := time.Now()
|
||||||
flag.IntVar(&args.n, "n", 10, "number of samples")
|
flag.IntVar(&args.n, "n", 10, "number of samples")
|
||||||
flag.BoolVar(&args.debug, "debug", false, "enable debug logging")
|
flag.BoolVar(&args.debug, "debug", false, "enable debug logging")
|
||||||
flag.StringVar(&args.image, "image", "", "path to image file")
|
flag.StringVar(&args.image, "image", "", "path to image file")
|
||||||
@ -106,10 +106,12 @@ func temp() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// simple schema
|
// Schema for a list of friends with their info
|
||||||
// This schema maps to JSON like:
|
// Maps to JSON like:
|
||||||
// {
|
// {
|
||||||
// "name": "some string value"
|
// "name": "string",
|
||||||
|
// "age": integer,
|
||||||
|
// "is_available": boolean
|
||||||
// }
|
// }
|
||||||
schema := &sample.Schema{
|
schema := &sample.Schema{
|
||||||
Name: "root",
|
Name: "root",
|
||||||
@ -117,20 +119,24 @@ func temp() error {
|
|||||||
Properties: []*sample.Schema{
|
Properties: []*sample.Schema{
|
||||||
{Name: "name", Type: "string"},
|
{Name: "name", Type: "string"},
|
||||||
{Name: "age", Type: "integer"},
|
{Name: "age", Type: "integer"},
|
||||||
{Name: "is_student", Type: "boolean"},
|
{Name: "is_available", Type: "boolean"},
|
||||||
// {Name: "is_student", Type: "boolean"},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
|
// fmt.Println("schema", schema)
|
||||||
pushdownSampler, err := sample.NewSOSampler(schema, m.(model.TextProcessor))
|
// schema = nil
|
||||||
|
jsonTransform, err := sample.NewJSONSampler(m.(model.TextProcessor), schema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
transforms := []sample.Transform{
|
||||||
|
jsonTransform,
|
||||||
|
}
|
||||||
|
|
||||||
var offset int
|
var offset int
|
||||||
var stringBuffer string
|
var stringBuffer string
|
||||||
var ttft time.Duration
|
// var ttft time.Duration
|
||||||
var totalSamplingTime time.Duration
|
var totalSamplingTime time.Duration
|
||||||
count := 0
|
count := 0
|
||||||
for range args.n {
|
for range args.n {
|
||||||
@ -139,24 +145,9 @@ func temp() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// f64s := make([]float64, len(f32s))
|
|
||||||
// for i, f32 := range f32s {
|
|
||||||
// f64s[i] = float64(f32)
|
|
||||||
// }
|
|
||||||
// samplers := []sample.Transform{
|
|
||||||
// pushdownSampler,
|
|
||||||
// sample.Weighed(),
|
|
||||||
// sample.TopP(0.9),
|
|
||||||
// sample.Weighed(),
|
|
||||||
// sample.Greedy(),
|
|
||||||
// }
|
|
||||||
transforms := []sample.Transform{
|
|
||||||
pushdownSampler,
|
|
||||||
}
|
|
||||||
|
|
||||||
samplingStart := time.Now()
|
samplingStart := time.Now()
|
||||||
sampler := sample.NewSampler(transforms, sample.Greedy())
|
sampler := sample.Greedy()
|
||||||
sampledIdx, err := sampler.Sample(logits.Floats())
|
sampledIdx, err := sampler.Sample(logits.Floats(), transforms...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -164,7 +155,7 @@ func temp() error {
|
|||||||
samplingTime := time.Since(samplingStart)
|
samplingTime := time.Since(samplingStart)
|
||||||
totalSamplingTime += samplingTime
|
totalSamplingTime += samplingTime
|
||||||
|
|
||||||
fmt.Println("sampling time", samplingTime)
|
// fmt.Println("sampling time", samplingTime)
|
||||||
// fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
|
// fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
|
||||||
|
|
||||||
var outputIDs []int32
|
var outputIDs []int32
|
||||||
@ -184,10 +175,10 @@ func temp() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if ttft == 0 {
|
// if ttft == 0 {
|
||||||
ttft = time.Since(start)
|
// ttft = time.Since(start)
|
||||||
fmt.Printf("Time to first token: %vms\n", ttft.Milliseconds())
|
// fmt.Printf("Time to first token: %vms\n", ttft.Milliseconds())
|
||||||
}
|
// }
|
||||||
|
|
||||||
// fmt.Printf("--- token: %q\n", s)
|
// fmt.Printf("--- token: %q\n", s)
|
||||||
// fmt.Printf("--- outputIDs: %v\n", outputIDs)
|
// fmt.Printf("--- outputIDs: %v\n", outputIDs)
|
||||||
@ -195,7 +186,7 @@ func temp() error {
|
|||||||
count++
|
count++
|
||||||
fmt.Println("--- stringBuffer", stringBuffer)
|
fmt.Println("--- stringBuffer", stringBuffer)
|
||||||
|
|
||||||
err = pushdownSampler.UpdateState(outputIDs)
|
outputIDs, err = jsonTransform.UpdateState(outputIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -1,49 +0,0 @@
|
|||||||
package sample
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/ollama/ollama/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ConstrainedSampler struct {
|
|
||||||
schema *Schema
|
|
||||||
propIdx int
|
|
||||||
propToNodeMap map[string]*PDA
|
|
||||||
pdaSampler *PushdownSampler
|
|
||||||
decodedToks []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConstrainedSampler(proc model.TextProcessor, schema *Schema) (*ConstrainedSampler, error) {
|
|
||||||
pdaSampler, err := NewPushdownSampler(proc)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// if schema == nil {
|
|
||||||
return &ConstrainedSampler{
|
|
||||||
schema: nil,
|
|
||||||
propIdx: -1,
|
|
||||||
propToNodeMap: nil,
|
|
||||||
pdaSampler: pdaSampler,
|
|
||||||
}, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ConstrainedSampler) Apply(logits []float64) ([]float64, error) {
|
|
||||||
if s.schema == nil {
|
|
||||||
return s.pdaSampler.Apply(logits)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ConstrainedSampler) UpdateState(tokenSlice []int32) error {
|
|
||||||
if err := s.pdaSampler.UpdateState(tokenSlice); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.schema == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -41,6 +41,7 @@ const (
|
|||||||
StateTerminate
|
StateTerminate
|
||||||
StateInObjectEnd
|
StateInObjectEnd
|
||||||
StateTransitioningToTerminate
|
StateTransitioningToTerminate
|
||||||
|
StateInListStartJSON
|
||||||
)
|
)
|
||||||
|
|
||||||
var JSONStates = []JSONState{
|
var JSONStates = []JSONState{
|
||||||
@ -48,6 +49,7 @@ var JSONStates = []JSONState{
|
|||||||
StateInObject,
|
StateInObject,
|
||||||
StateInObjectKey,
|
StateInObjectKey,
|
||||||
StateInStructuredKey,
|
StateInStructuredKey,
|
||||||
|
StateInStructuredValue,
|
||||||
StateNewline,
|
StateNewline,
|
||||||
StateTab,
|
StateTab,
|
||||||
StateSpace,
|
StateSpace,
|
||||||
@ -63,6 +65,7 @@ var JSONStates = []JSONState{
|
|||||||
StateInSpaceEndValue,
|
StateInSpaceEndValue,
|
||||||
StateInNewlineEndValue,
|
StateInNewlineEndValue,
|
||||||
StateInObjSpace,
|
StateInObjSpace,
|
||||||
|
StateInListStartJSON,
|
||||||
StateInList,
|
StateInList,
|
||||||
StateInListComma,
|
StateInListComma,
|
||||||
StateInValue,
|
StateInValue,
|
||||||
@ -89,6 +92,8 @@ func (s JSONState) String() string {
|
|||||||
return "StateInObjectKey"
|
return "StateInObjectKey"
|
||||||
case StateInStructuredKey:
|
case StateInStructuredKey:
|
||||||
return "StateInStructuredKey"
|
return "StateInStructuredKey"
|
||||||
|
case StateInStructuredValue:
|
||||||
|
return "StateInStructuredValue"
|
||||||
case StateNewline:
|
case StateNewline:
|
||||||
return "StateNewline"
|
return "StateNewline"
|
||||||
case StateTab:
|
case StateTab:
|
||||||
@ -112,21 +117,27 @@ func (s JSONState) String() string {
|
|||||||
case StateInTab:
|
case StateInTab:
|
||||||
return "StateInTab"
|
return "StateInTab"
|
||||||
case StateInSpaceToValue:
|
case StateInSpaceToValue:
|
||||||
return "StateInSpace"
|
return "StateInSpaceToValue"
|
||||||
|
case StateInSpaceEndValue:
|
||||||
|
return "StateInSpaceEndValue"
|
||||||
|
case StateInNewlineEndValue:
|
||||||
|
return "StateInNewlineEndValue"
|
||||||
case StateInObjSpace:
|
case StateInObjSpace:
|
||||||
return "StateInObjSpace"
|
return "StateInObjSpace"
|
||||||
case StateInList:
|
case StateInList:
|
||||||
return "StateInList"
|
return "StateInList"
|
||||||
case StateInListObjectEnd:
|
|
||||||
return "StateInListObjectEnd"
|
|
||||||
case StateInListComma:
|
case StateInListComma:
|
||||||
return "StateInListComma"
|
return "StateInListComma"
|
||||||
|
case StateInValue:
|
||||||
|
return "StateInValue"
|
||||||
|
case StateInValueEnd:
|
||||||
|
return "StateInValueEnd"
|
||||||
case StateInListEnd:
|
case StateInListEnd:
|
||||||
return "StateInListEnd"
|
return "StateInListEnd"
|
||||||
|
case StateInListObjectEnd:
|
||||||
|
return "StateInListObjectEnd"
|
||||||
case StateInNewline:
|
case StateInNewline:
|
||||||
return "StateInNewline"
|
return "StateInNewline"
|
||||||
case StateInNewlineEndValue:
|
|
||||||
return "StateInNewlineEndValue"
|
|
||||||
case StateInNumber:
|
case StateInNumber:
|
||||||
return "StateInNumber"
|
return "StateInNumber"
|
||||||
case StateInNumberEnd:
|
case StateInNumberEnd:
|
||||||
@ -135,12 +146,14 @@ func (s JSONState) String() string {
|
|||||||
return "StateInStringEnd"
|
return "StateInStringEnd"
|
||||||
case StateInObjectKeyEnd:
|
case StateInObjectKeyEnd:
|
||||||
return "StateInObjectKeyEnd"
|
return "StateInObjectKeyEnd"
|
||||||
case StateInSpaceEndValue:
|
|
||||||
return "StateInSpaceEndValue"
|
|
||||||
case StateTerminate:
|
case StateTerminate:
|
||||||
return "StateTerminate"
|
return "StateTerminate"
|
||||||
case StateInObjectEnd:
|
case StateInObjectEnd:
|
||||||
return "StateInObjectEnd"
|
return "StateInObjectEnd"
|
||||||
|
case StateTransitioningToTerminate:
|
||||||
|
return "StateTransitioningToTerminate"
|
||||||
|
case StateInListStartJSON:
|
||||||
|
return "StateInListStartJSON"
|
||||||
default:
|
default:
|
||||||
return fmt.Sprintf("Unknown state: %d", s)
|
return fmt.Sprintf("Unknown state: %d", s)
|
||||||
}
|
}
|
||||||
|
@ -37,8 +37,10 @@ Key JSON rules to consider:
|
|||||||
// TODO: / should be valid but an escape character
|
// TODO: / should be valid but an escape character
|
||||||
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
|
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
|
||||||
|
|
||||||
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
var (
|
||||||
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
||||||
|
validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
||||||
|
)
|
||||||
|
|
||||||
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
|
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
|
||||||
|
|
||||||
@ -61,9 +63,10 @@ func NewPDANode(state JSONState) *PDA {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PDAGraphBuilder struct {
|
type PDAGraphBuilder struct {
|
||||||
proc model.TextProcessor
|
proc model.TextProcessor
|
||||||
decodedToks []string
|
decodedToks []string
|
||||||
stateToNodeMap map[JSONState]*PDA
|
stateToNodeMap map[JSONState]*PDA
|
||||||
|
tokenToStatesMap map[int32][]JSONState
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *PDAGraphBuilder) BuildGraph() error {
|
func (b *PDAGraphBuilder) BuildGraph() error {
|
||||||
@ -73,20 +76,26 @@ func (b *PDAGraphBuilder) BuildGraph() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]
|
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
|
||||||
|
|
||||||
|
// TODO: update naming here - and revisit values
|
||||||
|
stateToNodeMap[StateInListStartJSON].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
stateToNodeMap[StateInListStartJSON].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
|
||||||
|
|
||||||
stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||||
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||||
stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||||
|
stateToNodeMap[StateInObject].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
|
||||||
// new line
|
// new line
|
||||||
stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||||
stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
|
stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
|
||||||
stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||||
|
// stateToNodeMap[StateInNewline].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
|
||||||
// new line end value
|
// new line end value
|
||||||
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
// stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
|
|
||||||
@ -108,6 +117,8 @@ func (b *PDAGraphBuilder) BuildGraph() error {
|
|||||||
// where values should be
|
// where values should be
|
||||||
// this could be combined but the probl might change, we're alr doing a skip ahead
|
// this could be combined but the probl might change, we're alr doing a skip ahead
|
||||||
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||||
|
stateToNodeMap[StateInColon].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
|
||||||
|
|
||||||
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
|
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||||
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
|
addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
|
||||||
@ -117,6 +128,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
|
|||||||
stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap)
|
addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap)
|
||||||
stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
stateToNodeMap[StateInSpaceToValue].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
|
||||||
|
|
||||||
// Values
|
// Values
|
||||||
// string node
|
// string node
|
||||||
@ -125,7 +137,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
|
|||||||
|
|
||||||
// String end node
|
// String end node
|
||||||
addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
|
addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
|
||||||
stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
// stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
// TODO: add counters for allowable number of decimals, e, E, etc
|
// TODO: add counters for allowable number of decimals, e, E, etc
|
||||||
@ -134,7 +146,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
|
|||||||
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
||||||
}
|
}
|
||||||
addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
|
addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
|
||||||
stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
// stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
// list node
|
// list node
|
||||||
@ -142,10 +154,12 @@ func (b *PDAGraphBuilder) BuildGraph() error {
|
|||||||
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
|
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
|
||||||
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
||||||
|
// early end
|
||||||
|
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
|
|
||||||
// list end node
|
// list end node
|
||||||
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
// stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||||
stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
@ -166,6 +180,9 @@ func (b *PDAGraphBuilder) BuildGraph() error {
|
|||||||
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
|
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
|
||||||
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
||||||
|
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInList]
|
||||||
|
stateToNodeMap[StateInListComma].TransitionEdges['\t'] = stateToNodeMap[StateInList]
|
||||||
|
|
||||||
addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
|
addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
|
||||||
|
|
||||||
// list object end
|
// list object end
|
||||||
@ -180,7 +197,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
|
|||||||
}
|
}
|
||||||
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||||
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
|
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
|
||||||
stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
// stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
// comma node
|
// comma node
|
||||||
@ -188,9 +205,11 @@ func (b *PDAGraphBuilder) BuildGraph() error {
|
|||||||
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||||
stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||||
stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||||
|
// todo: review this space transition
|
||||||
|
// stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||||
|
|
||||||
// space end value
|
// space end value
|
||||||
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
// stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
@ -221,6 +240,12 @@ func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) {
|
|||||||
|
|
||||||
func (b *PDAGraphBuilder) preComputeValidStates() error {
|
func (b *PDAGraphBuilder) preComputeValidStates() error {
|
||||||
for _, node := range b.stateToNodeMap {
|
for _, node := range b.stateToNodeMap {
|
||||||
|
// if node.State == StateInObjectKey {
|
||||||
|
// if len(b.stateToNodeMap[StateInString].MaskTokenIDToNode) > 0 {
|
||||||
|
// b.stateToNodeMap[StateInObjectKey].MaskTokenIDToNode = b.stateToNodeMap[StateInString].MaskTokenIDToNode
|
||||||
|
// fmt.Println("copying string mask to object key mask")
|
||||||
|
// }
|
||||||
|
// }
|
||||||
if err := b.CreateMask(node); err != nil {
|
if err := b.CreateMask(node); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -228,6 +253,20 @@ func (b *PDAGraphBuilder) preComputeValidStates() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *PDAGraphBuilder) preComputeTokenToStatesMap() error {
|
||||||
|
// TODO: make can be somewhere else too
|
||||||
|
b.tokenToStatesMap = make(map[int32][]JSONState)
|
||||||
|
for i, t := range b.decodedToks {
|
||||||
|
for _, r := range t {
|
||||||
|
if r == '"' {
|
||||||
|
b.tokenToStatesMap[int32(i)] = append(b.tokenToStatesMap[int32(i)], StateInString)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: the mask for obj key and string should be the same?
|
||||||
func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
|
func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
|
||||||
if node == nil {
|
if node == nil {
|
||||||
return fmt.Errorf("node cannot be nil")
|
return fmt.Errorf("node cannot be nil")
|
||||||
|
@ -10,9 +10,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// TODO: safety in case of invalid json
|
// TODO: safety in case of invalid json
|
||||||
|
// TODO: partial JSON matching?
|
||||||
// TODO: interfaces to cleanup with return values
|
// TODO: interfaces to cleanup with return values
|
||||||
// TODO this interface shouldn't be the sampler - should just use Sampler
|
// TODO this interface shouldn't be the sampler - should just use Sampler
|
||||||
// TODO: add penalties for string \n stuff
|
// TODO: add penalties for string \n stuff
|
||||||
|
// TODO: minimize number of fwd passes if there is only one match
|
||||||
|
// TODO: greedy sample initially and then backtrack if no match
|
||||||
|
|
||||||
type PushdownSampler struct {
|
type PushdownSampler struct {
|
||||||
PDAGraphBuilder
|
PDAGraphBuilder
|
||||||
curNode *PDA
|
curNode *PDA
|
||||||
@ -140,16 +144,24 @@ func forceFinish(s *PushdownSampler, logits []float64) ([]float64, error) {
|
|||||||
return logits, nil
|
return logits, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
func (s *PushdownSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
|
||||||
fmt.Println("current state - updating", s.curNode.State)
|
fmt.Println("current state - updating", s.curNode.State)
|
||||||
mappedString, err := s.proc.Decode(tokenSlice)
|
mappedString, err := s.proc.Decode(tokenSlice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
fmt.Printf(">>> mappedString: %q\n", mappedString)
|
fmt.Printf(">>> mappedString: %q\n", mappedString)
|
||||||
|
|
||||||
// TODO: should force closing for all braces - not doing square yet
|
// flag := -1
|
||||||
|
// endBraceRunes := []rune{'}', ']'}
|
||||||
for _, r := range mappedString {
|
for _, r := range mappedString {
|
||||||
|
// TODO: if this is enabled again, make sure to appropriately handle the state transitions
|
||||||
|
// if slices.Contains(endBraceRunes, r) && len(s.braceStack) == 0 {
|
||||||
|
// fmt.Printf("stack is empty, extra closing brace %c\n", r)
|
||||||
|
// // flag = i
|
||||||
|
// break
|
||||||
|
|
||||||
|
// }
|
||||||
if r == rune('{') {
|
if r == rune('{') {
|
||||||
s.braceStack = append(s.braceStack, r)
|
s.braceStack = append(s.braceStack, r)
|
||||||
}
|
}
|
||||||
@ -158,32 +170,36 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
|||||||
}
|
}
|
||||||
if r == rune('}') {
|
if r == rune('}') {
|
||||||
if len(s.braceStack) == 0 {
|
if len(s.braceStack) == 0 {
|
||||||
return fmt.Errorf("stack is empty, extra closing brace %c", r)
|
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
|
||||||
}
|
}
|
||||||
top := s.braceStack[len(s.braceStack)-1]
|
top := s.braceStack[len(s.braceStack)-1]
|
||||||
if top != rune('{') {
|
if top != rune('{') {
|
||||||
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
|
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
|
||||||
}
|
}
|
||||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
if r == rune(']') {
|
if r == rune(']') {
|
||||||
if len(s.braceStack) == 0 {
|
if len(s.braceStack) == 0 {
|
||||||
return fmt.Errorf("stack is empty, extra closing brace %c", r)
|
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
|
||||||
}
|
}
|
||||||
top := s.braceStack[len(s.braceStack)-1]
|
top := s.braceStack[len(s.braceStack)-1]
|
||||||
if top != rune('[') {
|
if top != rune('[') {
|
||||||
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
|
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
|
||||||
}
|
}
|
||||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if flag != -1 {
|
||||||
|
// tokenSlice = tokenSlice[:flag]
|
||||||
|
// }
|
||||||
|
// fmt.Println("flag!", flag)
|
||||||
for _, tokenID := range tokenSlice {
|
for _, tokenID := range tokenSlice {
|
||||||
// transition to the next node
|
// transition to the next node
|
||||||
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("invalid token: %q", mappedString)
|
return nil, fmt.Errorf("invalid token: %q", mappedString)
|
||||||
}
|
}
|
||||||
fmt.Println("transitioning to", nextNode.State)
|
fmt.Println("transitioning to", nextNode.State)
|
||||||
|
|
||||||
@ -196,12 +212,11 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
|||||||
s.curNode = nextNode
|
s.curNode = nextNode
|
||||||
fmt.Println("updated curNode state", s.curNode.State)
|
fmt.Println("updated curNode state", s.curNode.State)
|
||||||
}
|
}
|
||||||
return nil
|
return tokenSlice, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// greedy sample + backtrack?
|
// greedy sample + backtrack?
|
||||||
func (s *PushdownSampler) maskLogits(logits []float64, node *PDA) ([]float64, error) {
|
func (s *PushdownSampler) maskLogits(logits []float64, node *PDA) ([]float64, error) {
|
||||||
|
|
||||||
// Create a new slice with same length as logits, initialized to -Inf
|
// Create a new slice with same length as logits, initialized to -Inf
|
||||||
maskedLogits := make([]float64, len(logits))
|
maskedLogits := make([]float64, len(logits))
|
||||||
for i := range maskedLogits {
|
for i := range maskedLogits {
|
||||||
|
@ -35,7 +35,7 @@ func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, err
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("schema not nil")
|
// fmt.Println("schema not nil")
|
||||||
so := &JSONSampler{
|
so := &JSONSampler{
|
||||||
schema: schema,
|
schema: schema,
|
||||||
propIdx: -1,
|
propIdx: -1,
|
||||||
@ -87,7 +87,7 @@ func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, err
|
|||||||
for curState == StateInStructuredKey {
|
for curState == StateInStructuredKey {
|
||||||
// there is only one edge
|
// there is only one edge
|
||||||
for r, toNode := range fromNode.TransitionEdges {
|
for r, toNode := range fromNode.TransitionEdges {
|
||||||
// fmt.Println("rune", r, "edge", toNode.State)
|
fmt.Println("rune", r, "edge", toNode.State)
|
||||||
so.pdaSampler.CreateMask(toNode)
|
so.pdaSampler.CreateMask(toNode)
|
||||||
fmt.Printf("created mask for %c\n", r)
|
fmt.Printf("created mask for %c\n", r)
|
||||||
curState = toNode.State
|
curState = toNode.State
|
||||||
@ -96,13 +96,27 @@ func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, err
|
|||||||
fromNode = toNode
|
fromNode = toNode
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if curState != StateInColon {
|
||||||
|
return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState)
|
||||||
|
}
|
||||||
|
|
||||||
|
// so.pdaSampler.CreateMask(fromNode)
|
||||||
|
|
||||||
|
fromNode = fromNode.TransitionEdges[' ']
|
||||||
|
|
||||||
|
so.pdaSampler.CreateMask(fromNode)
|
||||||
|
curState = fromNode.State
|
||||||
|
for _, toNode := range fromNode.TransitionEdges {
|
||||||
|
fmt.Println("toNode", toNode.State)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
runtime.ReadMemStats(&m)
|
// runtime.ReadMemStats(&m)
|
||||||
after = m.Alloc
|
// after = m.Alloc
|
||||||
fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
// fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||||
fmt.Printf("Mask creation time = %v\n", time.Since(start))
|
// fmt.Printf("Mask creation time = %v\n", time.Since(start))
|
||||||
fmt.Println("--------------------------------")
|
// fmt.Println("--------------------------------")
|
||||||
|
|
||||||
return so, nil
|
return so, nil
|
||||||
}
|
}
|
||||||
@ -130,14 +144,66 @@ func (s *JSONSampler) schemaToGraph() {
|
|||||||
TransitionEdges: make(map[rune]*PDA),
|
TransitionEdges: make(map[rune]*PDA),
|
||||||
MaskTokenIDToNode: make(map[int32]*PDA),
|
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||||
}
|
}
|
||||||
fmt.Println("runeNode created", runeNode.State)
|
// fmt.Println("runeNode created", runeNode.State)
|
||||||
fmt.Printf("runeNode created %c\n", r)
|
// fmt.Printf("runeNode created %c\n", r)
|
||||||
|
|
||||||
// since alloc on heap connections wil still map
|
// since alloc on heap connections wil still map
|
||||||
prevNode.TransitionEdges[r] = runeNode
|
prevNode.TransitionEdges[r] = runeNode
|
||||||
prevNode = runeNode
|
prevNode = runeNode
|
||||||
}
|
}
|
||||||
|
|
||||||
// point to end of object key node after all chars are done
|
// point to end of object key node after all chars are done
|
||||||
prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
|
// prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
|
||||||
|
|
||||||
|
// link to value node
|
||||||
|
// Create a node for the end of the key (after the closing quote)
|
||||||
|
stringEndNode := &PDA{
|
||||||
|
State: StateInStructuredKey,
|
||||||
|
TransitionEdges: make(map[rune]*PDA),
|
||||||
|
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||||
|
}
|
||||||
|
prevNode.TransitionEdges['"'] = stringEndNode
|
||||||
|
prevNode = stringEndNode
|
||||||
|
|
||||||
|
// Add transition for colon after key
|
||||||
|
colonNode := &PDA{
|
||||||
|
State: StateInColon,
|
||||||
|
TransitionEdges: make(map[rune]*PDA),
|
||||||
|
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||||
|
}
|
||||||
|
prevNode.TransitionEdges[':'] = colonNode
|
||||||
|
prevNode = colonNode
|
||||||
|
|
||||||
|
// Add transition for space after colon
|
||||||
|
spaceNode := &PDA{
|
||||||
|
State: StateInSpaceToValue,
|
||||||
|
TransitionEdges: make(map[rune]*PDA),
|
||||||
|
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||||
|
}
|
||||||
|
prevNode.TransitionEdges[' '] = spaceNode
|
||||||
|
prevNode = spaceNode
|
||||||
|
|
||||||
|
value := prop.Type
|
||||||
|
switch value {
|
||||||
|
case "object":
|
||||||
|
fmt.Println("object under key: ", name)
|
||||||
|
case "array":
|
||||||
|
fmt.Println("array under key: ", name)
|
||||||
|
case "string":
|
||||||
|
fmt.Println("string under key: ", name)
|
||||||
|
prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString]
|
||||||
|
case "number":
|
||||||
|
fmt.Println("number under key: ", name)
|
||||||
|
for _, r := range validNumberRunes {
|
||||||
|
prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber]
|
||||||
|
}
|
||||||
|
case "boolean":
|
||||||
|
fmt.Println("boolean under key: ", name)
|
||||||
|
prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool]
|
||||||
|
prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool]
|
||||||
|
prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull]
|
||||||
|
}
|
||||||
|
|
||||||
// points to start of the key
|
// points to start of the key
|
||||||
s.propToNodeMap[name] = keyNode
|
s.propToNodeMap[name] = keyNode
|
||||||
fmt.Println("name", name, "keyNode", keyNode.State)
|
fmt.Println("name", name, "keyNode", keyNode.State)
|
||||||
@ -152,7 +218,7 @@ func (s *JSONSampler) Apply(logits []float64) ([]float64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch s.pdaSampler.curNode.State {
|
switch s.pdaSampler.curNode.State {
|
||||||
// doesnt account for multi rune case
|
// TODO: doesnt account for multi rune case
|
||||||
case StateInObjectKey:
|
case StateInObjectKey:
|
||||||
if s.propIdx > len(s.schema.Properties)-1 {
|
if s.propIdx > len(s.schema.Properties)-1 {
|
||||||
return nil, fmt.Errorf("propIdx out of bounds")
|
return nil, fmt.Errorf("propIdx out of bounds")
|
||||||
@ -196,18 +262,17 @@ func (s *JSONSampler) Apply(logits []float64) ([]float64, error) {
|
|||||||
}
|
}
|
||||||
return s.pdaSampler.Apply(logits)
|
return s.pdaSampler.Apply(logits)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
|
func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
|
||||||
err := s.pdaSampler.UpdateState(tokenSlice)
|
tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.schema == nil {
|
if s.schema == nil {
|
||||||
// Don't need to update state for unconstrained JSON sampling
|
// Don't need to update state for unconstrained JSON sampling
|
||||||
return nil
|
return tokenSlice, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
switch s.pdaSampler.curNode.State {
|
switch s.pdaSampler.curNode.State {
|
||||||
@ -217,14 +282,15 @@ func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
|
|||||||
prop := s.schema.Properties[s.propIdx]
|
prop := s.schema.Properties[s.propIdx]
|
||||||
fmt.Println("prop", prop.Name)
|
fmt.Println("prop", prop.Name)
|
||||||
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
|
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
|
||||||
str, err := s.pdaSampler.proc.Decode(tokenSlice)
|
// TODO: this does not work - mike
|
||||||
if err != nil {
|
// str, err := s.pdaSampler.proc.Decode(tokenSlice)
|
||||||
return err
|
// if err != nil {
|
||||||
}
|
// return nil, err
|
||||||
fmt.Println("str", str)
|
// }
|
||||||
|
// fmt.Println("str", str)
|
||||||
|
|
||||||
return nil
|
return tokenSlice, nil
|
||||||
default:
|
default:
|
||||||
return nil
|
return tokenSlice, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user