This commit is contained in:
ParthSareen 2025-02-24 17:39:01 -08:00
parent ffd6428758
commit a4265c278a
7 changed files with 215 additions and 133 deletions

View File

@ -443,6 +443,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
s.lc.Synchronize() s.lc.Synchronize()
} }
var totalSamplingTime time.Duration
for i, seq := range s.seqs { for i, seq := range s.seqs {
if seq == nil { if seq == nil {
continue continue
@ -477,8 +478,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
} }
// sample a token // sample a token
samplingStart := time.Now()
token := seq.samplingCtx.Sample(s.lc, seq.iBatch) token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
seq.samplingCtx.Accept(token, true) seq.samplingCtx.Accept(token, true)
samplingTime := time.Since(samplingStart)
totalSamplingTime += samplingTime
slog.Info("sampling time", "time", samplingTime)
piece := s.model.TokenToPiece(token) piece := s.model.TokenToPiece(token)
seq.numPredicted++ seq.numPredicted++
@ -635,6 +640,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
samplingParams.Seed = uint32(req.Seed) samplingParams.Seed = uint32(req.Seed)
samplingParams.Grammar = req.Grammar samplingParams.Grammar = req.Grammar
start := time.Now()
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict, numPredict: req.NumPredict,
stop: req.Stop, stop: req.Stop,
@ -642,6 +648,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
samplingParams: &samplingParams, samplingParams: &samplingParams,
embedding: false, embedding: false,
}) })
slog.Info("new sequence created", "duration", time.Since(start))
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return return

View File

@ -28,7 +28,7 @@ var args struct {
} }
func temp() error { func temp() error {
start := time.Now() // start := time.Now()
flag.IntVar(&args.n, "n", 10, "number of samples") flag.IntVar(&args.n, "n", 10, "number of samples")
flag.BoolVar(&args.debug, "debug", false, "enable debug logging") flag.BoolVar(&args.debug, "debug", false, "enable debug logging")
flag.StringVar(&args.image, "image", "", "path to image file") flag.StringVar(&args.image, "image", "", "path to image file")
@ -106,10 +106,12 @@ func temp() error {
} }
} }
// simple schema // Schema for a list of friends with their info
// This schema maps to JSON like: // Maps to JSON like:
// { // {
// "name": "some string value" // "name": "string",
// "age": integer,
// "is_available": boolean
// } // }
schema := &sample.Schema{ schema := &sample.Schema{
Name: "root", Name: "root",
@ -117,20 +119,24 @@ func temp() error {
Properties: []*sample.Schema{ Properties: []*sample.Schema{
{Name: "name", Type: "string"}, {Name: "name", Type: "string"},
{Name: "age", Type: "integer"}, {Name: "age", Type: "integer"},
{Name: "is_student", Type: "boolean"}, {Name: "is_available", Type: "boolean"},
// {Name: "is_student", Type: "boolean"},
}, },
} }
// pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor)) // fmt.Println("schema", schema)
pushdownSampler, err := sample.NewSOSampler(schema, m.(model.TextProcessor)) // schema = nil
jsonTransform, err := sample.NewJSONSampler(m.(model.TextProcessor), schema)
if err != nil { if err != nil {
return err return err
} }
transforms := []sample.Transform{
jsonTransform,
}
var offset int var offset int
var stringBuffer string var stringBuffer string
var ttft time.Duration // var ttft time.Duration
var totalSamplingTime time.Duration var totalSamplingTime time.Duration
count := 0 count := 0
for range args.n { for range args.n {
@ -139,24 +145,9 @@ func temp() error {
return err return err
} }
// f64s := make([]float64, len(f32s))
// for i, f32 := range f32s {
// f64s[i] = float64(f32)
// }
// samplers := []sample.Transform{
// pushdownSampler,
// sample.Weighed(),
// sample.TopP(0.9),
// sample.Weighed(),
// sample.Greedy(),
// }
transforms := []sample.Transform{
pushdownSampler,
}
samplingStart := time.Now() samplingStart := time.Now()
sampler := sample.NewSampler(transforms, sample.Greedy()) sampler := sample.Greedy()
sampledIdx, err := sampler.Sample(logits.Floats()) sampledIdx, err := sampler.Sample(logits.Floats(), transforms...)
if err != nil { if err != nil {
return err return err
} }
@ -164,7 +155,7 @@ func temp() error {
samplingTime := time.Since(samplingStart) samplingTime := time.Since(samplingStart)
totalSamplingTime += samplingTime totalSamplingTime += samplingTime
fmt.Println("sampling time", samplingTime) // fmt.Println("sampling time", samplingTime)
// fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds()) // fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
var outputIDs []int32 var outputIDs []int32
@ -184,10 +175,10 @@ func temp() error {
return err return err
} }
if ttft == 0 { // if ttft == 0 {
ttft = time.Since(start) // ttft = time.Since(start)
fmt.Printf("Time to first token: %vms\n", ttft.Milliseconds()) // fmt.Printf("Time to first token: %vms\n", ttft.Milliseconds())
} // }
// fmt.Printf("--- token: %q\n", s) // fmt.Printf("--- token: %q\n", s)
// fmt.Printf("--- outputIDs: %v\n", outputIDs) // fmt.Printf("--- outputIDs: %v\n", outputIDs)
@ -195,7 +186,7 @@ func temp() error {
count++ count++
fmt.Println("--- stringBuffer", stringBuffer) fmt.Println("--- stringBuffer", stringBuffer)
err = pushdownSampler.UpdateState(outputIDs) outputIDs, err = jsonTransform.UpdateState(outputIDs)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,49 +0,0 @@
package sample
import (
"github.com/ollama/ollama/model"
)
type ConstrainedSampler struct {
schema *Schema
propIdx int
propToNodeMap map[string]*PDA
pdaSampler *PushdownSampler
decodedToks []string
}
func NewConstrainedSampler(proc model.TextProcessor, schema *Schema) (*ConstrainedSampler, error) {
pdaSampler, err := NewPushdownSampler(proc)
if err != nil {
return nil, err
}
// if schema == nil {
return &ConstrainedSampler{
schema: nil,
propIdx: -1,
propToNodeMap: nil,
pdaSampler: pdaSampler,
}, nil
}
func (s *ConstrainedSampler) Apply(logits []float64) ([]float64, error) {
if s.schema == nil {
return s.pdaSampler.Apply(logits)
}
return nil, nil
}
func (s *ConstrainedSampler) UpdateState(tokenSlice []int32) error {
if err := s.pdaSampler.UpdateState(tokenSlice); err != nil {
return err
}
if s.schema == nil {
return nil
}
return nil
}

View File

@ -41,6 +41,7 @@ const (
StateTerminate StateTerminate
StateInObjectEnd StateInObjectEnd
StateTransitioningToTerminate StateTransitioningToTerminate
StateInListStartJSON
) )
var JSONStates = []JSONState{ var JSONStates = []JSONState{
@ -48,6 +49,7 @@ var JSONStates = []JSONState{
StateInObject, StateInObject,
StateInObjectKey, StateInObjectKey,
StateInStructuredKey, StateInStructuredKey,
StateInStructuredValue,
StateNewline, StateNewline,
StateTab, StateTab,
StateSpace, StateSpace,
@ -63,6 +65,7 @@ var JSONStates = []JSONState{
StateInSpaceEndValue, StateInSpaceEndValue,
StateInNewlineEndValue, StateInNewlineEndValue,
StateInObjSpace, StateInObjSpace,
StateInListStartJSON,
StateInList, StateInList,
StateInListComma, StateInListComma,
StateInValue, StateInValue,
@ -89,6 +92,8 @@ func (s JSONState) String() string {
return "StateInObjectKey" return "StateInObjectKey"
case StateInStructuredKey: case StateInStructuredKey:
return "StateInStructuredKey" return "StateInStructuredKey"
case StateInStructuredValue:
return "StateInStructuredValue"
case StateNewline: case StateNewline:
return "StateNewline" return "StateNewline"
case StateTab: case StateTab:
@ -112,21 +117,27 @@ func (s JSONState) String() string {
case StateInTab: case StateInTab:
return "StateInTab" return "StateInTab"
case StateInSpaceToValue: case StateInSpaceToValue:
return "StateInSpace" return "StateInSpaceToValue"
case StateInSpaceEndValue:
return "StateInSpaceEndValue"
case StateInNewlineEndValue:
return "StateInNewlineEndValue"
case StateInObjSpace: case StateInObjSpace:
return "StateInObjSpace" return "StateInObjSpace"
case StateInList: case StateInList:
return "StateInList" return "StateInList"
case StateInListObjectEnd:
return "StateInListObjectEnd"
case StateInListComma: case StateInListComma:
return "StateInListComma" return "StateInListComma"
case StateInValue:
return "StateInValue"
case StateInValueEnd:
return "StateInValueEnd"
case StateInListEnd: case StateInListEnd:
return "StateInListEnd" return "StateInListEnd"
case StateInListObjectEnd:
return "StateInListObjectEnd"
case StateInNewline: case StateInNewline:
return "StateInNewline" return "StateInNewline"
case StateInNewlineEndValue:
return "StateInNewlineEndValue"
case StateInNumber: case StateInNumber:
return "StateInNumber" return "StateInNumber"
case StateInNumberEnd: case StateInNumberEnd:
@ -135,12 +146,14 @@ func (s JSONState) String() string {
return "StateInStringEnd" return "StateInStringEnd"
case StateInObjectKeyEnd: case StateInObjectKeyEnd:
return "StateInObjectKeyEnd" return "StateInObjectKeyEnd"
case StateInSpaceEndValue:
return "StateInSpaceEndValue"
case StateTerminate: case StateTerminate:
return "StateTerminate" return "StateTerminate"
case StateInObjectEnd: case StateInObjectEnd:
return "StateInObjectEnd" return "StateInObjectEnd"
case StateTransitioningToTerminate:
return "StateTransitioningToTerminate"
case StateInListStartJSON:
return "StateInListStartJSON"
default: default:
return fmt.Sprintf("Unknown state: %d", s) return fmt.Sprintf("Unknown state: %d", s)
} }

View File

@ -37,8 +37,10 @@ Key JSON rules to consider:
// TODO: / should be valid but an escape character // TODO: / should be valid but an escape character
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'} var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'} var (
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'} intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
)
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'} var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
@ -61,9 +63,10 @@ func NewPDANode(state JSONState) *PDA {
} }
type PDAGraphBuilder struct { type PDAGraphBuilder struct {
proc model.TextProcessor proc model.TextProcessor
decodedToks []string decodedToks []string
stateToNodeMap map[JSONState]*PDA stateToNodeMap map[JSONState]*PDA
tokenToStatesMap map[int32][]JSONState
} }
func (b *PDAGraphBuilder) BuildGraph() error { func (b *PDAGraphBuilder) BuildGraph() error {
@ -73,20 +76,26 @@ func (b *PDAGraphBuilder) BuildGraph() error {
} }
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList] stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
// TODO: update naming here - and revisit values
stateToNodeMap[StateInListStartJSON].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInListStartJSON].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace] stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
stateToNodeMap[StateInObject].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
// new line // new line
stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab] stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace] stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
// stateToNodeMap[StateInNewline].TransitionEdges['{'] = stateToNodeMap[StateInObject]
// new line end value // new line end value
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] // stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
@ -108,6 +117,8 @@ func (b *PDAGraphBuilder) BuildGraph() error {
// where values should be // where values should be
// this could be combined but the probl might change, we're alr doing a skip ahead // this could be combined but the probl might change, we're alr doing a skip ahead
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue] stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInColon].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList] stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap) addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
@ -117,6 +128,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap) addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap)
stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInSpaceToValue].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
// Values // Values
// string node // string node
@ -125,7 +137,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
// String end node // String end node
addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap) addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] // stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// TODO: add counters for allowable number of decimals, e, E, etc // TODO: add counters for allowable number of decimals, e, E, etc
@ -134,7 +146,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber] stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
} }
addEnds(stateToNodeMap[StateInNumber], stateToNodeMap) addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] // stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// list node // list node
@ -142,10 +154,12 @@ func (b *PDAGraphBuilder) BuildGraph() error {
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList] stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList] stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
// early end
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
// list end node // list end node
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] // stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma] stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
@ -166,6 +180,9 @@ func (b *PDAGraphBuilder) BuildGraph() error {
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma] stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList] stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInList]
stateToNodeMap[StateInListComma].TransitionEdges['\t'] = stateToNodeMap[StateInList]
addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap) addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
// list object end // list object end
@ -180,7 +197,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
} }
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
addEnds(stateToNodeMap[StateInBool], stateToNodeMap) addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] // stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// comma node // comma node
@ -188,9 +205,11 @@ func (b *PDAGraphBuilder) BuildGraph() error {
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace] stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
// todo: review this space transition
// stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
// space end value // space end value
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] // stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
@ -221,6 +240,12 @@ func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) {
func (b *PDAGraphBuilder) preComputeValidStates() error { func (b *PDAGraphBuilder) preComputeValidStates() error {
for _, node := range b.stateToNodeMap { for _, node := range b.stateToNodeMap {
// if node.State == StateInObjectKey {
// if len(b.stateToNodeMap[StateInString].MaskTokenIDToNode) > 0 {
// b.stateToNodeMap[StateInObjectKey].MaskTokenIDToNode = b.stateToNodeMap[StateInString].MaskTokenIDToNode
// fmt.Println("copying string mask to object key mask")
// }
// }
if err := b.CreateMask(node); err != nil { if err := b.CreateMask(node); err != nil {
return err return err
} }
@ -228,6 +253,20 @@ func (b *PDAGraphBuilder) preComputeValidStates() error {
return nil return nil
} }
func (b *PDAGraphBuilder) preComputeTokenToStatesMap() error {
// TODO: make can be somewhere else too
b.tokenToStatesMap = make(map[int32][]JSONState)
for i, t := range b.decodedToks {
for _, r := range t {
if r == '"' {
b.tokenToStatesMap[int32(i)] = append(b.tokenToStatesMap[int32(i)], StateInString)
}
}
}
return nil
}
// TODO: the mask for obj key and string should be the same?
func (b *PDAGraphBuilder) CreateMask(node *PDA) error { func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
if node == nil { if node == nil {
return fmt.Errorf("node cannot be nil") return fmt.Errorf("node cannot be nil")

View File

@ -10,9 +10,13 @@ import (
) )
// TODO: safety in case of invalid json // TODO: safety in case of invalid json
// TODO: partial JSON matching?
// TODO: interfaces to cleanup with return values // TODO: interfaces to cleanup with return values
// TODO this interface shouldn't be the sampler - should just use Sampler // TODO this interface shouldn't be the sampler - should just use Sampler
// TODO: add penalties for string \n stuff // TODO: add penalties for string \n stuff
// TODO: minimize number of fwd passes if there is only one match
// TODO: greedy sample initially and then backtrack if no match
type PushdownSampler struct { type PushdownSampler struct {
PDAGraphBuilder PDAGraphBuilder
curNode *PDA curNode *PDA
@ -140,16 +144,24 @@ func forceFinish(s *PushdownSampler, logits []float64) ([]float64, error) {
return logits, nil return logits, nil
} }
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { func (s *PushdownSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
fmt.Println("current state - updating", s.curNode.State) fmt.Println("current state - updating", s.curNode.State)
mappedString, err := s.proc.Decode(tokenSlice) mappedString, err := s.proc.Decode(tokenSlice)
if err != nil { if err != nil {
return err return nil, err
} }
fmt.Printf(">>> mappedString: %q\n", mappedString) fmt.Printf(">>> mappedString: %q\n", mappedString)
// TODO: should force closing for all braces - not doing square yet // flag := -1
// endBraceRunes := []rune{'}', ']'}
for _, r := range mappedString { for _, r := range mappedString {
// TODO: if this is enabled again, make sure to appropriately handle the state transitions
// if slices.Contains(endBraceRunes, r) && len(s.braceStack) == 0 {
// fmt.Printf("stack is empty, extra closing brace %c\n", r)
// // flag = i
// break
// }
if r == rune('{') { if r == rune('{') {
s.braceStack = append(s.braceStack, r) s.braceStack = append(s.braceStack, r)
} }
@ -158,32 +170,36 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
} }
if r == rune('}') { if r == rune('}') {
if len(s.braceStack) == 0 { if len(s.braceStack) == 0 {
return fmt.Errorf("stack is empty, extra closing brace %c", r) return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
} }
top := s.braceStack[len(s.braceStack)-1] top := s.braceStack[len(s.braceStack)-1]
if top != rune('{') { if top != rune('{') {
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{') return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
} }
s.braceStack = s.braceStack[:len(s.braceStack)-1] s.braceStack = s.braceStack[:len(s.braceStack)-1]
} }
if r == rune(']') { if r == rune(']') {
if len(s.braceStack) == 0 { if len(s.braceStack) == 0 {
return fmt.Errorf("stack is empty, extra closing brace %c", r) return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
} }
top := s.braceStack[len(s.braceStack)-1] top := s.braceStack[len(s.braceStack)-1]
if top != rune('[') { if top != rune('[') {
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[') return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
} }
s.braceStack = s.braceStack[:len(s.braceStack)-1] s.braceStack = s.braceStack[:len(s.braceStack)-1]
} }
} }
// if flag != -1 {
// tokenSlice = tokenSlice[:flag]
// }
// fmt.Println("flag!", flag)
for _, tokenID := range tokenSlice { for _, tokenID := range tokenSlice {
// transition to the next node // transition to the next node
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID] nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
if !ok { if !ok {
return fmt.Errorf("invalid token: %q", mappedString) return nil, fmt.Errorf("invalid token: %q", mappedString)
} }
fmt.Println("transitioning to", nextNode.State) fmt.Println("transitioning to", nextNode.State)
@ -196,12 +212,11 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
s.curNode = nextNode s.curNode = nextNode
fmt.Println("updated curNode state", s.curNode.State) fmt.Println("updated curNode state", s.curNode.State)
} }
return nil return tokenSlice, nil
} }
// greedy sample + backtrack? // greedy sample + backtrack?
func (s *PushdownSampler) maskLogits(logits []float64, node *PDA) ([]float64, error) { func (s *PushdownSampler) maskLogits(logits []float64, node *PDA) ([]float64, error) {
// Create a new slice with same length as logits, initialized to -Inf // Create a new slice with same length as logits, initialized to -Inf
maskedLogits := make([]float64, len(logits)) maskedLogits := make([]float64, len(logits))
for i := range maskedLogits { for i := range maskedLogits {

View File

@ -35,7 +35,7 @@ func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, err
}, nil }, nil
} }
fmt.Println("schema not nil") // fmt.Println("schema not nil")
so := &JSONSampler{ so := &JSONSampler{
schema: schema, schema: schema,
propIdx: -1, propIdx: -1,
@ -87,7 +87,7 @@ func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, err
for curState == StateInStructuredKey { for curState == StateInStructuredKey {
// there is only one edge // there is only one edge
for r, toNode := range fromNode.TransitionEdges { for r, toNode := range fromNode.TransitionEdges {
// fmt.Println("rune", r, "edge", toNode.State) fmt.Println("rune", r, "edge", toNode.State)
so.pdaSampler.CreateMask(toNode) so.pdaSampler.CreateMask(toNode)
fmt.Printf("created mask for %c\n", r) fmt.Printf("created mask for %c\n", r)
curState = toNode.State curState = toNode.State
@ -96,13 +96,27 @@ func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, err
fromNode = toNode fromNode = toNode
} }
} }
if curState != StateInColon {
return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState)
}
// so.pdaSampler.CreateMask(fromNode)
fromNode = fromNode.TransitionEdges[' ']
so.pdaSampler.CreateMask(fromNode)
curState = fromNode.State
for _, toNode := range fromNode.TransitionEdges {
fmt.Println("toNode", toNode.State)
}
} }
runtime.ReadMemStats(&m) // runtime.ReadMemStats(&m)
after = m.Alloc // after = m.Alloc
fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) // fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Mask creation time = %v\n", time.Since(start)) // fmt.Printf("Mask creation time = %v\n", time.Since(start))
fmt.Println("--------------------------------") // fmt.Println("--------------------------------")
return so, nil return so, nil
} }
@ -130,14 +144,66 @@ func (s *JSONSampler) schemaToGraph() {
TransitionEdges: make(map[rune]*PDA), TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA), MaskTokenIDToNode: make(map[int32]*PDA),
} }
fmt.Println("runeNode created", runeNode.State) // fmt.Println("runeNode created", runeNode.State)
fmt.Printf("runeNode created %c\n", r) // fmt.Printf("runeNode created %c\n", r)
// since alloc on heap connections wil still map // since alloc on heap connections wil still map
prevNode.TransitionEdges[r] = runeNode prevNode.TransitionEdges[r] = runeNode
prevNode = runeNode prevNode = runeNode
} }
// point to end of object key node after all chars are done // point to end of object key node after all chars are done
prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd] // prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
// link to value node
// Create a node for the end of the key (after the closing quote)
stringEndNode := &PDA{
State: StateInStructuredKey,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode.TransitionEdges['"'] = stringEndNode
prevNode = stringEndNode
// Add transition for colon after key
colonNode := &PDA{
State: StateInColon,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode.TransitionEdges[':'] = colonNode
prevNode = colonNode
// Add transition for space after colon
spaceNode := &PDA{
State: StateInSpaceToValue,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode.TransitionEdges[' '] = spaceNode
prevNode = spaceNode
value := prop.Type
switch value {
case "object":
fmt.Println("object under key: ", name)
case "array":
fmt.Println("array under key: ", name)
case "string":
fmt.Println("string under key: ", name)
prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString]
case "number":
fmt.Println("number under key: ", name)
for _, r := range validNumberRunes {
prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber]
}
case "boolean":
fmt.Println("boolean under key: ", name)
prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool]
prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool]
prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull]
}
// points to start of the key // points to start of the key
s.propToNodeMap[name] = keyNode s.propToNodeMap[name] = keyNode
fmt.Println("name", name, "keyNode", keyNode.State) fmt.Println("name", name, "keyNode", keyNode.State)
@ -152,7 +218,7 @@ func (s *JSONSampler) Apply(logits []float64) ([]float64, error) {
} }
switch s.pdaSampler.curNode.State { switch s.pdaSampler.curNode.State {
// doesnt account for multi rune case // TODO: doesnt account for multi rune case
case StateInObjectKey: case StateInObjectKey:
if s.propIdx > len(s.schema.Properties)-1 { if s.propIdx > len(s.schema.Properties)-1 {
return nil, fmt.Errorf("propIdx out of bounds") return nil, fmt.Errorf("propIdx out of bounds")
@ -196,18 +262,17 @@ func (s *JSONSampler) Apply(logits []float64) ([]float64, error) {
} }
return s.pdaSampler.Apply(logits) return s.pdaSampler.Apply(logits)
} }
} }
func (s *JSONSampler) UpdateState(tokenSlice []int32) error { func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
err := s.pdaSampler.UpdateState(tokenSlice) tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice)
if err != nil { if err != nil {
return err return nil, err
} }
if s.schema == nil { if s.schema == nil {
// Don't need to update state for unconstrained JSON sampling // Don't need to update state for unconstrained JSON sampling
return nil return tokenSlice, nil
} }
switch s.pdaSampler.curNode.State { switch s.pdaSampler.curNode.State {
@ -217,14 +282,15 @@ func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
prop := s.schema.Properties[s.propIdx] prop := s.schema.Properties[s.propIdx]
fmt.Println("prop", prop.Name) fmt.Println("prop", prop.Name)
s.pdaSampler.curNode = s.propToNodeMap[prop.Name] s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
str, err := s.pdaSampler.proc.Decode(tokenSlice) // TODO: this does not work - mike
if err != nil { // str, err := s.pdaSampler.proc.Decode(tokenSlice)
return err // if err != nil {
} // return nil, err
fmt.Println("str", str) // }
// fmt.Println("str", str)
return nil return tokenSlice, nil
default: default:
return nil return tokenSlice, nil
} }
} }