WIP simple SO working

This commit is contained in:
ParthSareen 2025-02-03 11:40:06 -08:00
parent 524029cd6d
commit d5f8670f0a
3 changed files with 32 additions and 37 deletions

View File

@ -118,6 +118,7 @@ func temp() error {
Type: "object", Type: "object",
Properties: []*sample.Schema{ Properties: []*sample.Schema{
{Name: "name", Type: "string"}, {Name: "name", Type: "string"},
{Name: "age", Type: "integer"},
}, },
} }
@ -158,7 +159,7 @@ func temp() error {
samplingTime := time.Since(samplingStart) samplingTime := time.Since(samplingStart)
totalSamplingTime += samplingTime totalSamplingTime += samplingTime
fmt.Println("sampling time", samplingTime) // fmt.Println("sampling time", samplingTime)
// fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds()) // fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
var outputIDs []int32 var outputIDs []int32

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"math" "math"
"runtime" "runtime"
"time"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
) )
@ -21,15 +22,15 @@ type PushdownSampler struct {
// graph should be built once and reused per tokenizer // graph should be built once and reused per tokenizer
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
// start := time.Now() start := time.Now()
// fmt.Println("--------------------------------") fmt.Println("--------------------------------")
// fmt.Println("PDA sampler") fmt.Println("PDA sampler")
// fmt.Println("--------------------------------") fmt.Println("--------------------------------")
var m runtime.MemStats var m runtime.MemStats
runtime.ReadMemStats(&m) runtime.ReadMemStats(&m)
// before := m.Alloc before := m.Alloc
// fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024)) fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
startNode, stateToNodeMap, err := BuildGraph(proc) startNode, stateToNodeMap, err := BuildGraph(proc)
if err != nil { if err != nil {
@ -40,10 +41,10 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
panic(err) panic(err)
} }
runtime.ReadMemStats(&m) runtime.ReadMemStats(&m)
// after := m.Alloc after := m.Alloc
// fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024)) fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
// fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
// fmt.Printf("Graph build time = %v\n", time.Since(start)) fmt.Printf("Graph build time = %v\n", time.Since(start))
return &PushdownSampler{ return &PushdownSampler{
curNode: startNode, curNode: startNode,
@ -57,13 +58,11 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
// TODO: need to add resampling logic if the first sample was not good // TODO: need to add resampling logic if the first sample was not good
// greedy sample + backtrack? // greedy sample + backtrack?
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
// fmt.Println(">>> sample:", s.curNode.State)
switch s.curNode.State { switch s.curNode.State {
case StateInString: case StateInString:
return s.maskLogits(logits, s.curNode) return s.maskLogits(logits, s.curNode)
case StateInListEnd: case StateInListEnd:
// fmt.Println("in list end", s.braceStack)
// force finish if no braces left // force finish if no braces left
if len(s.braceStack) == 0 { if len(s.braceStack) == 0 {
s.curNode = NewPDANode(StateTerminate) s.curNode = NewPDANode(StateTerminate)
@ -100,7 +99,6 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
peek := s.braceStack[len(s.braceStack)-1] peek := s.braceStack[len(s.braceStack)-1]
if peek == rune('[') { if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListObjectEnd] s.curNode = s.stateToNodeMap[StateInListObjectEnd]
// fmt.Println("switching to list object end", s.curNode.State)
} }
logits, err := s.maskLogits(logits, s.curNode) logits, err := s.maskLogits(logits, s.curNode)
@ -113,7 +111,6 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
peek := s.braceStack[len(s.braceStack)-1] peek := s.braceStack[len(s.braceStack)-1]
if peek == rune('[') { if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListComma] s.curNode = s.stateToNodeMap[StateInListComma]
// fmt.Println("switching to list comma", s.curNode.State)
} }
logits, err := s.maskLogits(logits, s.curNode) logits, err := s.maskLogits(logits, s.curNode)
if err != nil { if err != nil {
@ -132,7 +129,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
return logits, nil return logits, nil
default: default:
// fmt.Println("masking logits current state", s.curNode.State) fmt.Println("masking logits current state", s.curNode.State)
logits, err := s.maskLogits(logits, s.curNode) logits, err := s.maskLogits(logits, s.curNode)
if err != nil { if err != nil {
return nil, err return nil, err
@ -142,22 +139,20 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
} }
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
// fmt.Println("current state - updating", s.curNode.State) fmt.Println("current state - updating", s.curNode.State)
mappedString, err := s.proc.Decode(tokenSlice) mappedString, err := s.proc.Decode(tokenSlice)
if err != nil { if err != nil {
return err return err
} }
// fmt.Println("mappedString", mappedString) fmt.Println(">>> mappedString", mappedString)
// TODO: should force closing for all braces - not doing square yet // TODO: should force closing for all braces - not doing square yet
for _, r := range mappedString { for _, r := range mappedString {
if r == rune('{') { if r == rune('{') {
s.braceStack = append(s.braceStack, r) s.braceStack = append(s.braceStack, r)
// fmt.Println("pushing { brace stack", r)
} }
if r == rune('[') { if r == rune('[') {
s.braceStack = append(s.braceStack, r) s.braceStack = append(s.braceStack, r)
// fmt.Println("pushing [ brace stack", r)
} }
if r == rune('}') { if r == rune('}') {
if len(s.braceStack) == 0 { if len(s.braceStack) == 0 {
@ -168,7 +163,6 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{') return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
} }
s.braceStack = s.braceStack[:len(s.braceStack)-1] s.braceStack = s.braceStack[:len(s.braceStack)-1]
// fmt.Println("popping { brace stack", top)
} }
if r == rune(']') { if r == rune(']') {
@ -180,7 +174,6 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[') return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
} }
s.braceStack = s.braceStack[:len(s.braceStack)-1] s.braceStack = s.braceStack[:len(s.braceStack)-1]
// fmt.Println("popping [ brace stack", top)
} }
} }
@ -190,7 +183,7 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
if !ok { if !ok {
return fmt.Errorf("invalid token: %q", mappedString) return fmt.Errorf("invalid token: %q", mappedString)
} }
// fmt.Println("transitioning to", nextNodeState) fmt.Println("transitioning to", nextNode.State)
// TODO: add a penalty for staying in the same state too long // TODO: add a penalty for staying in the same state too long
if nextNode.State == s.curNode.State { if nextNode.State == s.curNode.State {
@ -199,7 +192,7 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
s.stateCounter = 0 s.stateCounter = 0
} }
s.curNode = nextNode s.curNode = nextNode
// fmt.Println("updated curNode state", s.curNode.State) fmt.Println("updated curNode state", s.curNode.State)
} }
return nil return nil
} }

View File

@ -9,20 +9,20 @@ import (
) )
type SOSampler struct { type SOSampler struct {
schema *Schema schema *Schema
propIdx int propIdx int
propStateMap map[string]*PDANode propToNodeMap map[string]*PDANode
pdaSampler *PushdownSampler pdaSampler *PushdownSampler
} }
func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) { func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) {
pdaSampler := NewPushdownSampler(proc) pdaSampler := NewPushdownSampler(proc)
so := &SOSampler{ so := &SOSampler{
schema: schema, schema: schema,
propIdx: -1, propIdx: -1,
propStateMap: make(map[string]*PDANode), propToNodeMap: make(map[string]*PDANode),
pdaSampler: pdaSampler, pdaSampler: pdaSampler,
} }
so.schemaToGraph() so.schemaToGraph()
@ -47,7 +47,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
before := m.Alloc before := m.Alloc
// TODO: still messed up // TODO: still messed up
for _, node := range so.propStateMap { for _, node := range so.propToNodeMap {
// propName -> node // propName -> node
curState := node.State curState := node.State
fromNode := node fromNode := node
@ -110,7 +110,7 @@ func (s *SOSampler) schemaToGraph() {
// point to end of object key node after all chars are done // point to end of object key node after all chars are done
prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd] prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
// points to start of the key // points to start of the key
s.propStateMap[name] = keyNode s.propToNodeMap[name] = keyNode
fmt.Println("name", name, "keyNode", keyNode.State) fmt.Println("name", name, "keyNode", keyNode.State)
} }
} }
@ -124,10 +124,11 @@ func (s *SOSampler) Sample(logits []float64) ([]float64, error) {
// TODO: this tracking should probably be coming from a stack to track nested objects // TODO: this tracking should probably be coming from a stack to track nested objects
// simple case // simple case
s.propIdx++ s.propIdx++
fmt.Println("propIdx", s.propIdx)
prop := s.schema.Properties[s.propIdx] prop := s.schema.Properties[s.propIdx]
// fmt.Println("prop", prop.Name) fmt.Println("prop", prop.Name)
s.pdaSampler.curNode = s.propStateMap[prop.Name] s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
// fmt.Println("changed curNode state to", s.pdaSampler.curNode.State) fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode) logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
if err != nil { if err != nil {
return nil, err return nil, err