WIP simple SO working
This commit is contained in:
parent
524029cd6d
commit
d5f8670f0a
@ -118,6 +118,7 @@ func temp() error {
|
|||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: []*sample.Schema{
|
Properties: []*sample.Schema{
|
||||||
{Name: "name", Type: "string"},
|
{Name: "name", Type: "string"},
|
||||||
|
{Name: "age", Type: "integer"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -158,7 +159,7 @@ func temp() error {
|
|||||||
samplingTime := time.Since(samplingStart)
|
samplingTime := time.Since(samplingStart)
|
||||||
totalSamplingTime += samplingTime
|
totalSamplingTime += samplingTime
|
||||||
|
|
||||||
fmt.Println("sampling time", samplingTime)
|
// fmt.Println("sampling time", samplingTime)
|
||||||
// fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
|
// fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
|
||||||
|
|
||||||
var outputIDs []int32
|
var outputIDs []int32
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
)
|
)
|
||||||
@ -21,15 +22,15 @@ type PushdownSampler struct {
|
|||||||
|
|
||||||
// graph should be built once and reused per tokenizer
|
// graph should be built once and reused per tokenizer
|
||||||
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||||
// start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
// fmt.Println("--------------------------------")
|
fmt.Println("--------------------------------")
|
||||||
// fmt.Println("PDA sampler")
|
fmt.Println("PDA sampler")
|
||||||
// fmt.Println("--------------------------------")
|
fmt.Println("--------------------------------")
|
||||||
var m runtime.MemStats
|
var m runtime.MemStats
|
||||||
runtime.ReadMemStats(&m)
|
runtime.ReadMemStats(&m)
|
||||||
// before := m.Alloc
|
before := m.Alloc
|
||||||
// fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
|
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
|
||||||
|
|
||||||
startNode, stateToNodeMap, err := BuildGraph(proc)
|
startNode, stateToNodeMap, err := BuildGraph(proc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -40,10 +41,10 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
runtime.ReadMemStats(&m)
|
runtime.ReadMemStats(&m)
|
||||||
// after := m.Alloc
|
after := m.Alloc
|
||||||
// fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
|
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
|
||||||
// fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||||
// fmt.Printf("Graph build time = %v\n", time.Since(start))
|
fmt.Printf("Graph build time = %v\n", time.Since(start))
|
||||||
|
|
||||||
return &PushdownSampler{
|
return &PushdownSampler{
|
||||||
curNode: startNode,
|
curNode: startNode,
|
||||||
@ -57,13 +58,11 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
|||||||
// TODO: need to add resampling logic if the first sample was not good
|
// TODO: need to add resampling logic if the first sample was not good
|
||||||
// greedy sample + backtrack?
|
// greedy sample + backtrack?
|
||||||
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||||
// fmt.Println(">>> sample:", s.curNode.State)
|
|
||||||
switch s.curNode.State {
|
switch s.curNode.State {
|
||||||
case StateInString:
|
case StateInString:
|
||||||
return s.maskLogits(logits, s.curNode)
|
return s.maskLogits(logits, s.curNode)
|
||||||
|
|
||||||
case StateInListEnd:
|
case StateInListEnd:
|
||||||
// fmt.Println("in list end", s.braceStack)
|
|
||||||
// force finish if no braces left
|
// force finish if no braces left
|
||||||
if len(s.braceStack) == 0 {
|
if len(s.braceStack) == 0 {
|
||||||
s.curNode = NewPDANode(StateTerminate)
|
s.curNode = NewPDANode(StateTerminate)
|
||||||
@ -100,7 +99,6 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
|||||||
peek := s.braceStack[len(s.braceStack)-1]
|
peek := s.braceStack[len(s.braceStack)-1]
|
||||||
if peek == rune('[') {
|
if peek == rune('[') {
|
||||||
s.curNode = s.stateToNodeMap[StateInListObjectEnd]
|
s.curNode = s.stateToNodeMap[StateInListObjectEnd]
|
||||||
// fmt.Println("switching to list object end", s.curNode.State)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logits, err := s.maskLogits(logits, s.curNode)
|
logits, err := s.maskLogits(logits, s.curNode)
|
||||||
@ -113,7 +111,6 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
|||||||
peek := s.braceStack[len(s.braceStack)-1]
|
peek := s.braceStack[len(s.braceStack)-1]
|
||||||
if peek == rune('[') {
|
if peek == rune('[') {
|
||||||
s.curNode = s.stateToNodeMap[StateInListComma]
|
s.curNode = s.stateToNodeMap[StateInListComma]
|
||||||
// fmt.Println("switching to list comma", s.curNode.State)
|
|
||||||
}
|
}
|
||||||
logits, err := s.maskLogits(logits, s.curNode)
|
logits, err := s.maskLogits(logits, s.curNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -132,7 +129,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
|||||||
return logits, nil
|
return logits, nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
// fmt.Println("masking logits current state", s.curNode.State)
|
fmt.Println("masking logits current state", s.curNode.State)
|
||||||
logits, err := s.maskLogits(logits, s.curNode)
|
logits, err := s.maskLogits(logits, s.curNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -142,22 +139,20 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||||
// fmt.Println("current state - updating", s.curNode.State)
|
fmt.Println("current state - updating", s.curNode.State)
|
||||||
mappedString, err := s.proc.Decode(tokenSlice)
|
mappedString, err := s.proc.Decode(tokenSlice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// fmt.Println("mappedString", mappedString)
|
fmt.Println(">>> mappedString", mappedString)
|
||||||
|
|
||||||
// TODO: should force closing for all braces - not doing square yet
|
// TODO: should force closing for all braces - not doing square yet
|
||||||
for _, r := range mappedString {
|
for _, r := range mappedString {
|
||||||
if r == rune('{') {
|
if r == rune('{') {
|
||||||
s.braceStack = append(s.braceStack, r)
|
s.braceStack = append(s.braceStack, r)
|
||||||
// fmt.Println("pushing { brace stack", r)
|
|
||||||
}
|
}
|
||||||
if r == rune('[') {
|
if r == rune('[') {
|
||||||
s.braceStack = append(s.braceStack, r)
|
s.braceStack = append(s.braceStack, r)
|
||||||
// fmt.Println("pushing [ brace stack", r)
|
|
||||||
}
|
}
|
||||||
if r == rune('}') {
|
if r == rune('}') {
|
||||||
if len(s.braceStack) == 0 {
|
if len(s.braceStack) == 0 {
|
||||||
@ -168,7 +163,6 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
|||||||
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
|
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
|
||||||
}
|
}
|
||||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||||
// fmt.Println("popping { brace stack", top)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if r == rune(']') {
|
if r == rune(']') {
|
||||||
@ -180,7 +174,6 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
|||||||
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
|
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
|
||||||
}
|
}
|
||||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||||
// fmt.Println("popping [ brace stack", top)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -190,7 +183,7 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("invalid token: %q", mappedString)
|
return fmt.Errorf("invalid token: %q", mappedString)
|
||||||
}
|
}
|
||||||
// fmt.Println("transitioning to", nextNodeState)
|
fmt.Println("transitioning to", nextNode.State)
|
||||||
|
|
||||||
// TODO: add a penalty for staying in the same state too long
|
// TODO: add a penalty for staying in the same state too long
|
||||||
if nextNode.State == s.curNode.State {
|
if nextNode.State == s.curNode.State {
|
||||||
@ -199,7 +192,7 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
|||||||
s.stateCounter = 0
|
s.stateCounter = 0
|
||||||
}
|
}
|
||||||
s.curNode = nextNode
|
s.curNode = nextNode
|
||||||
// fmt.Println("updated curNode state", s.curNode.State)
|
fmt.Println("updated curNode state", s.curNode.State)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -11,7 +11,7 @@ import (
|
|||||||
type SOSampler struct {
|
type SOSampler struct {
|
||||||
schema *Schema
|
schema *Schema
|
||||||
propIdx int
|
propIdx int
|
||||||
propStateMap map[string]*PDANode
|
propToNodeMap map[string]*PDANode
|
||||||
pdaSampler *PushdownSampler
|
pdaSampler *PushdownSampler
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
|
|||||||
so := &SOSampler{
|
so := &SOSampler{
|
||||||
schema: schema,
|
schema: schema,
|
||||||
propIdx: -1,
|
propIdx: -1,
|
||||||
propStateMap: make(map[string]*PDANode),
|
propToNodeMap: make(map[string]*PDANode),
|
||||||
pdaSampler: pdaSampler,
|
pdaSampler: pdaSampler,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -47,7 +47,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
|
|||||||
before := m.Alloc
|
before := m.Alloc
|
||||||
|
|
||||||
// TODO: still messed up
|
// TODO: still messed up
|
||||||
for _, node := range so.propStateMap {
|
for _, node := range so.propToNodeMap {
|
||||||
// propName -> node
|
// propName -> node
|
||||||
curState := node.State
|
curState := node.State
|
||||||
fromNode := node
|
fromNode := node
|
||||||
@ -110,7 +110,7 @@ func (s *SOSampler) schemaToGraph() {
|
|||||||
// point to end of object key node after all chars are done
|
// point to end of object key node after all chars are done
|
||||||
prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
|
prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
|
||||||
// points to start of the key
|
// points to start of the key
|
||||||
s.propStateMap[name] = keyNode
|
s.propToNodeMap[name] = keyNode
|
||||||
fmt.Println("name", name, "keyNode", keyNode.State)
|
fmt.Println("name", name, "keyNode", keyNode.State)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -124,10 +124,11 @@ func (s *SOSampler) Sample(logits []float64) ([]float64, error) {
|
|||||||
// TODO: this tracking should probably be coming from a stack to track nested objects
|
// TODO: this tracking should probably be coming from a stack to track nested objects
|
||||||
// simple case
|
// simple case
|
||||||
s.propIdx++
|
s.propIdx++
|
||||||
|
fmt.Println("propIdx", s.propIdx)
|
||||||
prop := s.schema.Properties[s.propIdx]
|
prop := s.schema.Properties[s.propIdx]
|
||||||
// fmt.Println("prop", prop.Name)
|
fmt.Println("prop", prop.Name)
|
||||||
s.pdaSampler.curNode = s.propStateMap[prop.Name]
|
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
|
||||||
// fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
|
fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
|
||||||
logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
|
logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
Loading…
x
Reference in New Issue
Block a user