diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 60ae88dac..04258c46a 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -443,6 +443,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) s.lc.Synchronize() } + var totalSamplingTime time.Duration for i, seq := range s.seqs { if seq == nil { continue @@ -477,8 +478,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) } // sample a token + samplingStart := time.Now() token := seq.samplingCtx.Sample(s.lc, seq.iBatch) seq.samplingCtx.Accept(token, true) + samplingTime := time.Since(samplingStart) + totalSamplingTime += samplingTime + slog.Info("sampling time", "time", samplingTime) piece := s.model.TokenToPiece(token) seq.numPredicted++ @@ -635,6 +640,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { samplingParams.Seed = uint32(req.Seed) samplingParams.Grammar = req.Grammar + start := time.Now() seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ numPredict: req.NumPredict, stop: req.Stop, @@ -642,6 +648,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { samplingParams: &samplingParams, embedding: false, }) + slog.Info("new sequence created", "duration", time.Since(start)) if err != nil { http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError) return diff --git a/model/cmd/main.go b/model/cmd/main.go index 0526ef032..258046880 100644 --- a/model/cmd/main.go +++ b/model/cmd/main.go @@ -28,7 +28,7 @@ var args struct { } func temp() error { - start := time.Now() + // start := time.Now() flag.IntVar(&args.n, "n", 10, "number of samples") flag.BoolVar(&args.debug, "debug", false, "enable debug logging") flag.StringVar(&args.image, "image", "", "path to image file") @@ -106,10 +106,12 @@ func temp() error { } } - // simple schema - // This schema maps to JSON like: + // Schema for a list of friends with their info + // Maps to JSON like: // { - // "name": "some string value" + // "name": "string", + // "age": integer, + // "is_available": boolean // } schema := &sample.Schema{ Name: "root", @@ -117,20 +119,24 @@ func temp() error { Properties: []*sample.Schema{ {Name: "name", Type: "string"}, {Name: "age", Type: "integer"}, - {Name: "is_student", Type: "boolean"}, - // {Name: "is_student", Type: "boolean"}, + {Name: "is_available", Type: "boolean"}, }, } - // pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor)) - pushdownSampler, err := sample.NewSOSampler(schema, m.(model.TextProcessor)) + // fmt.Println("schema", schema) + // schema = nil + jsonTransform, err := sample.NewJSONSampler(m.(model.TextProcessor), schema) if err != nil { return err } + transforms := []sample.Transform{ + jsonTransform, + } + var offset int var stringBuffer string - var ttft time.Duration + // var ttft time.Duration var totalSamplingTime time.Duration count := 0 for range args.n { @@ -139,24 +145,9 @@ func temp() error { return err } - // f64s := make([]float64, len(f32s)) - // for i, f32 := range f32s { - // f64s[i] = float64(f32) - // } - // samplers := []sample.Transform{ - // pushdownSampler, - // sample.Weighed(), - // sample.TopP(0.9), - // sample.Weighed(), - // sample.Greedy(), - // } - transforms := []sample.Transform{ - pushdownSampler, - } - samplingStart := time.Now() - sampler := sample.NewSampler(transforms, sample.Greedy()) - sampledIdx, err := sampler.Sample(logits.Floats()) + sampler := sample.Greedy() + sampledIdx, err := sampler.Sample(logits.Floats(), transforms...) if err != nil { return err } @@ -164,7 +155,7 @@ func temp() error { samplingTime := time.Since(samplingStart) totalSamplingTime += samplingTime - fmt.Println("sampling time", samplingTime) + // fmt.Println("sampling time", samplingTime) // fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds()) var outputIDs []int32 @@ -184,10 +175,10 @@ func temp() error { return err } - if ttft == 0 { - ttft = time.Since(start) - fmt.Printf("Time to first token: %vms\n", ttft.Milliseconds()) - } + // if ttft == 0 { + // ttft = time.Since(start) + // fmt.Printf("Time to first token: %vms\n", ttft.Milliseconds()) + // } // fmt.Printf("--- token: %q\n", s) // fmt.Printf("--- outputIDs: %v\n", outputIDs) @@ -195,7 +186,7 @@ func temp() error { count++ fmt.Println("--- stringBuffer", stringBuffer) - err = pushdownSampler.UpdateState(outputIDs) + outputIDs, err = jsonTransform.UpdateState(outputIDs) if err != nil { return err } diff --git a/sample/constrained.go b/sample/constrained.go deleted file mode 100644 index e148ec3f9..000000000 --- a/sample/constrained.go +++ /dev/null @@ -1,49 +0,0 @@ -package sample - -import ( - "github.com/ollama/ollama/model" -) - -type ConstrainedSampler struct { - schema *Schema - propIdx int - propToNodeMap map[string]*PDA - pdaSampler *PushdownSampler - decodedToks []string -} - -func NewConstrainedSampler(proc model.TextProcessor, schema *Schema) (*ConstrainedSampler, error) { - pdaSampler, err := NewPushdownSampler(proc) - if err != nil { - return nil, err - } - - // if schema == nil { - return &ConstrainedSampler{ - schema: nil, - propIdx: -1, - propToNodeMap: nil, - pdaSampler: pdaSampler, - }, nil - -} - -func (s *ConstrainedSampler) Apply(logits []float64) ([]float64, error) { - if s.schema == nil { - return s.pdaSampler.Apply(logits) - } - - return nil, nil -} - -func (s *ConstrainedSampler) UpdateState(tokenSlice []int32) error { - if err := s.pdaSampler.UpdateState(tokenSlice); err != nil { - return err - } - - if s.schema == nil { - return nil - } - - return nil -} diff --git a/sample/json_types.go b/sample/json_types.go index bf52b8c9d..7bbb7b951 100644 --- a/sample/json_types.go +++ b/sample/json_types.go @@ -41,6 +41,7 @@ const ( StateTerminate StateInObjectEnd StateTransitioningToTerminate + StateInListStartJSON ) var JSONStates = []JSONState{ @@ -48,6 +49,7 @@ var JSONStates = []JSONState{ StateInObject, StateInObjectKey, StateInStructuredKey, + StateInStructuredValue, StateNewline, StateTab, StateSpace, @@ -63,6 +65,7 @@ var JSONStates = []JSONState{ StateInSpaceEndValue, StateInNewlineEndValue, StateInObjSpace, + StateInListStartJSON, StateInList, StateInListComma, StateInValue, @@ -89,6 +92,8 @@ func (s JSONState) String() string { return "StateInObjectKey" case StateInStructuredKey: return "StateInStructuredKey" + case StateInStructuredValue: + return "StateInStructuredValue" case StateNewline: return "StateNewline" case StateTab: @@ -112,21 +117,27 @@ func (s JSONState) String() string { case StateInTab: return "StateInTab" case StateInSpaceToValue: - return "StateInSpace" + return "StateInSpaceToValue" + case StateInSpaceEndValue: + return "StateInSpaceEndValue" + case StateInNewlineEndValue: + return "StateInNewlineEndValue" case StateInObjSpace: return "StateInObjSpace" case StateInList: return "StateInList" - case StateInListObjectEnd: - return "StateInListObjectEnd" case StateInListComma: return "StateInListComma" + case StateInValue: + return "StateInValue" + case StateInValueEnd: + return "StateInValueEnd" case StateInListEnd: return "StateInListEnd" + case StateInListObjectEnd: + return "StateInListObjectEnd" case StateInNewline: return "StateInNewline" - case StateInNewlineEndValue: - return "StateInNewlineEndValue" case StateInNumber: return "StateInNumber" case StateInNumberEnd: @@ -135,12 +146,14 @@ func (s JSONState) String() string { return "StateInStringEnd" case StateInObjectKeyEnd: return "StateInObjectKeyEnd" - case StateInSpaceEndValue: - return "StateInSpaceEndValue" case StateTerminate: return "StateTerminate" case StateInObjectEnd: return "StateInObjectEnd" + case StateTransitioningToTerminate: + return "StateTransitioningToTerminate" + case StateInListStartJSON: + return "StateInListStartJSON" default: return fmt.Sprintf("Unknown state: %d", s) } diff --git a/sample/pushdown_automata.go b/sample/pushdown_automata.go index 366d830bc..be3fde1fb 100644 --- a/sample/pushdown_automata.go +++ b/sample/pushdown_automata.go @@ -37,8 +37,10 @@ Key JSON rules to consider: // TODO: / should be valid but an escape character var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'} -var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'} -var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'} +var ( + intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'} + validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'} +) var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'} @@ -61,9 +63,10 @@ func NewPDANode(state JSONState) *PDA { } type PDAGraphBuilder struct { - proc model.TextProcessor - decodedToks []string - stateToNodeMap map[JSONState]*PDA + proc model.TextProcessor + decodedToks []string + stateToNodeMap map[JSONState]*PDA + tokenToStatesMap map[int32][]JSONState } func (b *PDAGraphBuilder) BuildGraph() error { @@ -73,20 +76,26 @@ func (b *PDAGraphBuilder) BuildGraph() error { } stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject] - stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList] + stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON] + + // TODO: update naming here - and revisit values + stateToNodeMap[StateInListStartJSON].TransitionEdges['{'] = stateToNodeMap[StateInObject] + stateToNodeMap[StateInListStartJSON].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON] stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace] + stateToNodeMap[StateInObject].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] // new line stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab] stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace] + // stateToNodeMap[StateInNewline].TransitionEdges['{'] = stateToNodeMap[StateInObject] // new line end value - stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + // stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] @@ -108,6 +117,8 @@ func (b *PDAGraphBuilder) BuildGraph() error { // where values should be // this could be combined but the probl might change, we're alr doing a skip ahead stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue] + stateToNodeMap[StateInColon].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue] + stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList] stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject] addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap) @@ -117,6 +128,7 @@ func (b *PDAGraphBuilder) BuildGraph() error { stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject] addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap) stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] + stateToNodeMap[StateInSpaceToValue].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue] // Values // string node @@ -125,7 +137,7 @@ func (b *PDAGraphBuilder) BuildGraph() error { // String end node addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap) - stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + // stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] // TODO: add counters for allowable number of decimals, e, E, etc @@ -134,7 +146,7 @@ func (b *PDAGraphBuilder) BuildGraph() error { stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber] } addEnds(stateToNodeMap[StateInNumber], stateToNodeMap) - stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + // stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] // list node @@ -142,10 +154,12 @@ func (b *PDAGraphBuilder) BuildGraph() error { stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList] stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList] + // early end + stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] // list end node stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] - stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + // stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma] stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] @@ -166,6 +180,9 @@ func (b *PDAGraphBuilder) BuildGraph() error { stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma] stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList] + stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInList] + stateToNodeMap[StateInListComma].TransitionEdges['\t'] = stateToNodeMap[StateInList] + addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap) // list object end @@ -180,7 +197,7 @@ func (b *PDAGraphBuilder) BuildGraph() error { } stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] addEnds(stateToNodeMap[StateInBool], stateToNodeMap) - stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + // stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] // comma node @@ -188,9 +205,11 @@ func (b *PDAGraphBuilder) BuildGraph() error { stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace] + // todo: review this space transition + // stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue] // space end value - stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + // stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] @@ -221,6 +240,12 @@ func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) { func (b *PDAGraphBuilder) preComputeValidStates() error { for _, node := range b.stateToNodeMap { + // if node.State == StateInObjectKey { + // if len(b.stateToNodeMap[StateInString].MaskTokenIDToNode) > 0 { + // b.stateToNodeMap[StateInObjectKey].MaskTokenIDToNode = b.stateToNodeMap[StateInString].MaskTokenIDToNode + // fmt.Println("copying string mask to object key mask") + // } + // } if err := b.CreateMask(node); err != nil { return err } @@ -228,6 +253,20 @@ func (b *PDAGraphBuilder) preComputeValidStates() error { return nil } +func (b *PDAGraphBuilder) preComputeTokenToStatesMap() error { + // TODO: make can be somewhere else too + b.tokenToStatesMap = make(map[int32][]JSONState) + for i, t := range b.decodedToks { + for _, r := range t { + if r == '"' { + b.tokenToStatesMap[int32(i)] = append(b.tokenToStatesMap[int32(i)], StateInString) + } + } + } + return nil +} + +// TODO: the mask for obj key and string should be the same? func (b *PDAGraphBuilder) CreateMask(node *PDA) error { if node == nil { return fmt.Errorf("node cannot be nil") diff --git a/sample/pushdown_runner.go b/sample/pushdown_runner.go index 592582375..4fdf92596 100644 --- a/sample/pushdown_runner.go +++ b/sample/pushdown_runner.go @@ -10,9 +10,13 @@ import ( ) // TODO: safety in case of invalid json +// TODO: partial JSON matching? // TODO: interfaces to cleanup with return values // TODO this interface shouldn't be the sampler - should just use Sampler // TODO: add penalties for string \n stuff +// TODO: minimize number of fwd passes if there is only one match +// TODO: greedy sample initially and then backtrack if no match + type PushdownSampler struct { PDAGraphBuilder curNode *PDA @@ -140,16 +144,24 @@ func forceFinish(s *PushdownSampler, logits []float64) ([]float64, error) { return logits, nil } -func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { +func (s *PushdownSampler) UpdateState(tokenSlice []int32) ([]int32, error) { fmt.Println("current state - updating", s.curNode.State) mappedString, err := s.proc.Decode(tokenSlice) if err != nil { - return err + return nil, err } fmt.Printf(">>> mappedString: %q\n", mappedString) - // TODO: should force closing for all braces - not doing square yet + // flag := -1 + // endBraceRunes := []rune{'}', ']'} for _, r := range mappedString { + // TODO: if this is enabled again, make sure to appropriately handle the state transitions + // if slices.Contains(endBraceRunes, r) && len(s.braceStack) == 0 { + // fmt.Printf("stack is empty, extra closing brace %c\n", r) + // // flag = i + // break + + // } if r == rune('{') { s.braceStack = append(s.braceStack, r) } @@ -158,32 +170,36 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { } if r == rune('}') { if len(s.braceStack) == 0 { - return fmt.Errorf("stack is empty, extra closing brace %c", r) + return nil, fmt.Errorf("stack is empty, extra closing brace %c", r) } top := s.braceStack[len(s.braceStack)-1] if top != rune('{') { - return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{') + return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{') } s.braceStack = s.braceStack[:len(s.braceStack)-1] } if r == rune(']') { if len(s.braceStack) == 0 { - return fmt.Errorf("stack is empty, extra closing brace %c", r) + return nil, fmt.Errorf("stack is empty, extra closing brace %c", r) } top := s.braceStack[len(s.braceStack)-1] if top != rune('[') { - return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[') + return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[') } s.braceStack = s.braceStack[:len(s.braceStack)-1] } } + // if flag != -1 { + // tokenSlice = tokenSlice[:flag] + // } + // fmt.Println("flag!", flag) for _, tokenID := range tokenSlice { // transition to the next node nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID] if !ok { - return fmt.Errorf("invalid token: %q", mappedString) + return nil, fmt.Errorf("invalid token: %q", mappedString) } fmt.Println("transitioning to", nextNode.State) @@ -196,12 +212,11 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { s.curNode = nextNode fmt.Println("updated curNode state", s.curNode.State) } - return nil + return tokenSlice, nil } // greedy sample + backtrack? func (s *PushdownSampler) maskLogits(logits []float64, node *PDA) ([]float64, error) { - // Create a new slice with same length as logits, initialized to -Inf maskedLogits := make([]float64, len(logits)) for i := range maskedLogits { diff --git a/sample/structured_outputs.go b/sample/structured_outputs.go index 2ae16d062..14bea0a03 100644 --- a/sample/structured_outputs.go +++ b/sample/structured_outputs.go @@ -35,7 +35,7 @@ func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, err }, nil } - fmt.Println("schema not nil") + // fmt.Println("schema not nil") so := &JSONSampler{ schema: schema, propIdx: -1, @@ -87,7 +87,7 @@ func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, err for curState == StateInStructuredKey { // there is only one edge for r, toNode := range fromNode.TransitionEdges { - // fmt.Println("rune", r, "edge", toNode.State) + fmt.Println("rune", r, "edge", toNode.State) so.pdaSampler.CreateMask(toNode) fmt.Printf("created mask for %c\n", r) curState = toNode.State @@ -96,13 +96,27 @@ func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, err fromNode = toNode } } + + if curState != StateInColon { + return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState) + } + + // so.pdaSampler.CreateMask(fromNode) + + fromNode = fromNode.TransitionEdges[' '] + + so.pdaSampler.CreateMask(fromNode) + curState = fromNode.State + for _, toNode := range fromNode.TransitionEdges { + fmt.Println("toNode", toNode.State) + } } - runtime.ReadMemStats(&m) - after = m.Alloc - fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) - fmt.Printf("Mask creation time = %v\n", time.Since(start)) - fmt.Println("--------------------------------") + // runtime.ReadMemStats(&m) + // after = m.Alloc + // fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) + // fmt.Printf("Mask creation time = %v\n", time.Since(start)) + // fmt.Println("--------------------------------") return so, nil } @@ -130,14 +144,66 @@ func (s *JSONSampler) schemaToGraph() { TransitionEdges: make(map[rune]*PDA), MaskTokenIDToNode: make(map[int32]*PDA), } - fmt.Println("runeNode created", runeNode.State) - fmt.Printf("runeNode created %c\n", r) + // fmt.Println("runeNode created", runeNode.State) + // fmt.Printf("runeNode created %c\n", r) + // since alloc on heap connections wil still map prevNode.TransitionEdges[r] = runeNode prevNode = runeNode } + // point to end of object key node after all chars are done - prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd] + // prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd] + + // link to value node + // Create a node for the end of the key (after the closing quote) + stringEndNode := &PDA{ + State: StateInStructuredKey, + TransitionEdges: make(map[rune]*PDA), + MaskTokenIDToNode: make(map[int32]*PDA), + } + prevNode.TransitionEdges['"'] = stringEndNode + prevNode = stringEndNode + + // Add transition for colon after key + colonNode := &PDA{ + State: StateInColon, + TransitionEdges: make(map[rune]*PDA), + MaskTokenIDToNode: make(map[int32]*PDA), + } + prevNode.TransitionEdges[':'] = colonNode + prevNode = colonNode + + // Add transition for space after colon + spaceNode := &PDA{ + State: StateInSpaceToValue, + TransitionEdges: make(map[rune]*PDA), + MaskTokenIDToNode: make(map[int32]*PDA), + } + prevNode.TransitionEdges[' '] = spaceNode + prevNode = spaceNode + + value := prop.Type + switch value { + case "object": + fmt.Println("object under key: ", name) + case "array": + fmt.Println("array under key: ", name) + case "string": + fmt.Println("string under key: ", name) + prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString] + case "number": + fmt.Println("number under key: ", name) + for _, r := range validNumberRunes { + prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber] + } + case "boolean": + fmt.Println("boolean under key: ", name) + prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool] + prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool] + prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull] + } + // points to start of the key s.propToNodeMap[name] = keyNode fmt.Println("name", name, "keyNode", keyNode.State) @@ -152,7 +218,7 @@ func (s *JSONSampler) Apply(logits []float64) ([]float64, error) { } switch s.pdaSampler.curNode.State { - // doesnt account for multi rune case + // TODO: doesnt account for multi rune case case StateInObjectKey: if s.propIdx > len(s.schema.Properties)-1 { return nil, fmt.Errorf("propIdx out of bounds") @@ -196,18 +262,17 @@ func (s *JSONSampler) Apply(logits []float64) ([]float64, error) { } return s.pdaSampler.Apply(logits) } - } -func (s *JSONSampler) UpdateState(tokenSlice []int32) error { - err := s.pdaSampler.UpdateState(tokenSlice) +func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) { + tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice) if err != nil { - return err + return nil, err } if s.schema == nil { // Don't need to update state for unconstrained JSON sampling - return nil + return tokenSlice, nil } switch s.pdaSampler.curNode.State { @@ -217,14 +282,15 @@ func (s *JSONSampler) UpdateState(tokenSlice []int32) error { prop := s.schema.Properties[s.propIdx] fmt.Println("prop", prop.Name) s.pdaSampler.curNode = s.propToNodeMap[prop.Name] - str, err := s.pdaSampler.proc.Decode(tokenSlice) - if err != nil { - return err - } - fmt.Println("str", str) + // TODO: this does not work - mike + // str, err := s.pdaSampler.proc.Decode(tokenSlice) + // if err != nil { + // return nil, err + // } + // fmt.Println("str", str) - return nil + return tokenSlice, nil default: - return nil + return tokenSlice, nil } }