diff --git a/model/cmd/main.go b/model/cmd/main.go index 9d90b5f8e..756a5238a 100644 --- a/model/cmd/main.go +++ b/model/cmd/main.go @@ -118,6 +118,7 @@ func temp() error { Type: "object", Properties: []*sample.Schema{ {Name: "name", Type: "string"}, + {Name: "age", Type: "integer"}, }, } @@ -158,7 +159,7 @@ func temp() error { samplingTime := time.Since(samplingStart) totalSamplingTime += samplingTime - fmt.Println("sampling time", samplingTime) + // fmt.Println("sampling time", samplingTime) // fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds()) var outputIDs []int32 diff --git a/sample/pushdown_runner.go b/sample/pushdown_runner.go index fe75eb864..ea568525b 100644 --- a/sample/pushdown_runner.go +++ b/sample/pushdown_runner.go @@ -4,6 +4,7 @@ import ( "fmt" "math" "runtime" + "time" "github.com/ollama/ollama/model" ) @@ -21,15 +22,15 @@ type PushdownSampler struct { // graph should be built once and reused per tokenizer func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { - // start := time.Now() + start := time.Now() - // fmt.Println("--------------------------------") - // fmt.Println("PDA sampler") - // fmt.Println("--------------------------------") + fmt.Println("--------------------------------") + fmt.Println("PDA sampler") + fmt.Println("--------------------------------") var m runtime.MemStats runtime.ReadMemStats(&m) - // before := m.Alloc - // fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024)) + before := m.Alloc + fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024)) startNode, stateToNodeMap, err := BuildGraph(proc) if err != nil { @@ -40,10 +41,10 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { panic(err) } runtime.ReadMemStats(&m) - // after := m.Alloc - // fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024)) - // fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) - // fmt.Printf("Graph build time = %v\n", time.Since(start)) + after := m.Alloc + fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024)) + fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) + fmt.Printf("Graph build time = %v\n", time.Since(start)) return &PushdownSampler{ curNode: startNode, @@ -57,13 +58,11 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { // TODO: need to add resampling logic if the first sample was not good // greedy sample + backtrack? func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { - // fmt.Println(">>> sample:", s.curNode.State) switch s.curNode.State { case StateInString: return s.maskLogits(logits, s.curNode) case StateInListEnd: - // fmt.Println("in list end", s.braceStack) // force finish if no braces left if len(s.braceStack) == 0 { s.curNode = NewPDANode(StateTerminate) @@ -100,7 +99,6 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { peek := s.braceStack[len(s.braceStack)-1] if peek == rune('[') { s.curNode = s.stateToNodeMap[StateInListObjectEnd] - // fmt.Println("switching to list object end", s.curNode.State) } logits, err := s.maskLogits(logits, s.curNode) @@ -113,7 +111,6 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { peek := s.braceStack[len(s.braceStack)-1] if peek == rune('[') { s.curNode = s.stateToNodeMap[StateInListComma] - // fmt.Println("switching to list comma", s.curNode.State) } logits, err := s.maskLogits(logits, s.curNode) if err != nil { @@ -132,7 +129,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { return logits, nil default: - // fmt.Println("masking logits current state", s.curNode.State) + fmt.Println("masking logits current state", s.curNode.State) logits, err := s.maskLogits(logits, s.curNode) if err != nil { return nil, err @@ -142,22 +139,20 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { } func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { - // fmt.Println("current state - updating", s.curNode.State) + fmt.Println("current state - updating", s.curNode.State) mappedString, err := s.proc.Decode(tokenSlice) if err != nil { return err } - // fmt.Println("mappedString", mappedString) + fmt.Println(">>> mappedString", mappedString) // TODO: should force closing for all braces - not doing square yet for _, r := range mappedString { if r == rune('{') { s.braceStack = append(s.braceStack, r) - // fmt.Println("pushing { brace stack", r) } if r == rune('[') { s.braceStack = append(s.braceStack, r) - // fmt.Println("pushing [ brace stack", r) } if r == rune('}') { if len(s.braceStack) == 0 { @@ -168,7 +163,6 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{') } s.braceStack = s.braceStack[:len(s.braceStack)-1] - // fmt.Println("popping { brace stack", top) } if r == rune(']') { @@ -180,7 +174,6 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[') } s.braceStack = s.braceStack[:len(s.braceStack)-1] - // fmt.Println("popping [ brace stack", top) } } @@ -190,7 +183,7 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { if !ok { return fmt.Errorf("invalid token: %q", mappedString) } - // fmt.Println("transitioning to", nextNodeState) + fmt.Println("transitioning to", nextNode.State) // TODO: add a penalty for staying in the same state too long if nextNode.State == s.curNode.State { @@ -199,7 +192,7 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { s.stateCounter = 0 } s.curNode = nextNode - // fmt.Println("updated curNode state", s.curNode.State) + fmt.Println("updated curNode state", s.curNode.State) } return nil } diff --git a/sample/structured_outputs.go b/sample/structured_outputs.go index 309ead4aa..7e540ae9b 100644 --- a/sample/structured_outputs.go +++ b/sample/structured_outputs.go @@ -9,20 +9,20 @@ import ( ) type SOSampler struct { - schema *Schema - propIdx int - propStateMap map[string]*PDANode - pdaSampler *PushdownSampler + schema *Schema + propIdx int + propToNodeMap map[string]*PDANode + pdaSampler *PushdownSampler } func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) { pdaSampler := NewPushdownSampler(proc) so := &SOSampler{ - schema: schema, - propIdx: -1, - propStateMap: make(map[string]*PDANode), - pdaSampler: pdaSampler, + schema: schema, + propIdx: -1, + propToNodeMap: make(map[string]*PDANode), + pdaSampler: pdaSampler, } so.schemaToGraph() @@ -47,7 +47,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) before := m.Alloc // TODO: still messed up - for _, node := range so.propStateMap { + for _, node := range so.propToNodeMap { // propName -> node curState := node.State fromNode := node @@ -110,7 +110,7 @@ func (s *SOSampler) schemaToGraph() { // point to end of object key node after all chars are done prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd] // points to start of the key - s.propStateMap[name] = keyNode + s.propToNodeMap[name] = keyNode fmt.Println("name", name, "keyNode", keyNode.State) } } @@ -124,10 +124,11 @@ func (s *SOSampler) Sample(logits []float64) ([]float64, error) { // TODO: this tracking should probably be coming from a stack to track nested objects // simple case s.propIdx++ + fmt.Println("propIdx", s.propIdx) prop := s.schema.Properties[s.propIdx] - // fmt.Println("prop", prop.Name) - s.pdaSampler.curNode = s.propStateMap[prop.Name] - // fmt.Println("changed curNode state to", s.pdaSampler.curNode.State) + fmt.Println("prop", prop.Name) + s.pdaSampler.curNode = s.propToNodeMap[prop.Name] + fmt.Println("changed curNode state to", s.pdaSampler.curNode.State) logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode) if err != nil { return nil, err