wip with json stuff and cleanup

This commit is contained in:
ParthSareen 2025-02-11 16:40:40 -08:00
parent 25edfa6fdb
commit aa6d5151df
10 changed files with 561 additions and 330 deletions

49
sample/constrained.go Normal file
View File

@ -0,0 +1,49 @@
package sample
import (
	"errors"

	"github.com/ollama/ollama/model"
)
// ConstrainedSampler wraps a PushdownSampler and carries the extra bookkeeping
// fields for schema-aware sampling. When schema is nil, Apply forwards
// directly to pdaSampler.
type ConstrainedSampler struct {
// schema is the optional schema to constrain against; nil selects the
// plain pushdown-sampler path in Apply.
schema *Schema
// propIdx is initialised to -1 by NewConstrainedSampler.
// NOTE(review): presumably the index of the schema property currently
// being generated - confirm once the schema path is implemented.
propIdx int
// propToNodeMap maps a string key to a PDA node; NewConstrainedSampler
// leaves it nil. NOTE(review): keys look like schema property names - confirm.
propToNodeMap map[string]*PDA
// pdaSampler is the underlying sampler that Apply and UpdateState
// delegate to.
pdaSampler *PushdownSampler
// decodedToks is not populated by NewConstrainedSampler.
// NOTE(review): presumably the decoded vocabulary strings - confirm.
decodedToks []string
}
// NewConstrainedSampler builds a ConstrainedSampler backed by a newly
// constructed PushdownSampler for proc.
//
// The schema argument is accepted but not stored yet: the returned sampler
// always has a nil schema and a nil propToNodeMap, and its propIdx is -1, so
// it behaves as a plain pushdown sampler. If NewPushdownSampler fails, its
// error is returned unchanged together with a nil sampler.
func NewConstrainedSampler(proc model.TextProcessor, schema *Schema) (*ConstrainedSampler, error) {
	jsonSampler, err := NewPushdownSampler(proc)
	if err != nil {
		return nil, err
	}

	// schema and propToNodeMap stay at their zero value (nil); only the
	// fields that differ from the zero value are spelled out.
	cs := &ConstrainedSampler{
		propIdx:    -1,
		pdaSampler: jsonSampler,
	}
	return cs, nil
}
// Apply constrains logits for the next token.
//
// With no schema configured the call is forwarded to the underlying pushdown
// sampler and its result is returned as-is. Schema-constrained masking is not
// implemented; in that case an error is returned rather than a nil slice with
// a nil error, so callers never continue sampling from empty logits.
func (s *ConstrainedSampler) Apply(logits []float64) ([]float64, error) {
	if s.schema != nil {
		return nil, errors.New("schema-constrained sampling is not implemented")
	}
	return s.pdaSampler.Apply(logits)
}
// UpdateState feeds the sampled tokens to the underlying pushdown sampler and
// reports whatever error it returns (nil on success).
//
// No schema-specific state is tracked: the outcome is the same whether or not
// a schema is set.
func (s *ConstrainedSampler) UpdateState(tokenSlice []int32) error {
	return s.pdaSampler.UpdateState(tokenSlice)
}

32
sample/feedback.txt Normal file
View File

@ -0,0 +1,32 @@
// Feedback from code review:
// pushdown_automata.go:
// 1. The BuildGraph function is quite long and could be split into smaller, more focused functions
// 2. Consider using constants instead of magic runes like rune(-1) for sentinel values
// 3. The state machine transitions could be defined more declaratively, perhaps in a config
// 4. The stringInvalidRunes list needs to handle escape sequences properly
// 5. The graph building could be optimized to avoid duplicate nodes/transitions
// 6. Consider adding validation for max nesting depth of braces/brackets
// 7. The CreateMask function is doing a lot - could be split into smaller pieces
// 8. isRuneValid has a "garbage interface" per TODO - needs cleaner design
// pushdown_runner.go:
// 1. The Apply method has a lot of duplicated logic around EOS handling
// 2. The UpdateState method could use more granular error messages
// 3. The braceStack validation could be moved to a separate validator
// 4. Consider adding max length limits for strings/numbers
// 5. The stateCounter isn't being used effectively yet
// 6. Need to add penalties for staying in same state too long
// 7. The maskLogits function could be optimized to avoid allocations
// 8. Missing proper cleanup/reset functionality
// 9. Error handling could be more consistent throughout
// 10. Consider adding debug logging levels instead of raw fmt.Println
// General improvements needed:
// - More comprehensive testing, especially edge cases
// - Better documentation of state machine transitions
// - Performance optimization for large inputs
// - Memory usage optimization for the graph structure
// - Cleaner interfaces between components
// - More robust error handling and recovery

View File

@ -0,0 +1,11 @@
package sample
// type fusedMaskSampler struct{}
// func FusedMaskSampler() Sampler {
// return fusedMaskSampler{}
// }
// func (f fusedMaskSampler) Sample(logits []float64) (int, error) {
// return int(logits[0]), nil
// }

View File

@ -8,6 +8,19 @@ func Greedy() Sampler {
return greedy{} return greedy{}
} }
func (s greedy) Sample(t []float64) (int, error) { func (s greedy) Sample(logits []float32, transforms ...Transform) (int, error) {
return floats.MaxIdx(t), nil logits64 := make([]float64, len(logits))
for i, v := range logits {
logits64[i] = float64(v)
}
var err error
for _, t := range transforms {
logits64, err = t.Apply(logits64)
if err != nil {
return -1, err
}
}
return floats.MaxIdx(logits64), nil
} }

View File

@ -23,7 +23,9 @@ const (
StateInColon StateInColon
StateInComma StateInComma
StateInTab StateInTab
StateInSpace StateInSpaceToValue
StateInSpaceEndValue
StateInNewlineEndValue
StateInObjSpace StateInObjSpace
StateInList StateInList
StateInListComma StateInListComma
@ -57,7 +59,9 @@ var JSONStates = []JSONState{
StateInColon, StateInColon,
StateInComma, StateInComma,
StateInTab, StateInTab,
StateInSpace, StateInSpaceToValue,
StateInSpaceEndValue,
StateInNewlineEndValue,
StateInObjSpace, StateInObjSpace,
StateInList, StateInList,
StateInListComma, StateInListComma,
@ -107,7 +111,7 @@ func (s JSONState) String() string {
return "StateInComma" return "StateInComma"
case StateInTab: case StateInTab:
return "StateInTab" return "StateInTab"
case StateInSpace: case StateInSpaceToValue:
return "StateInSpace" return "StateInSpace"
case StateInObjSpace: case StateInObjSpace:
return "StateInObjSpace" return "StateInObjSpace"
@ -121,6 +125,8 @@ func (s JSONState) String() string {
return "StateInListEnd" return "StateInListEnd"
case StateInNewline: case StateInNewline:
return "StateInNewline" return "StateInNewline"
case StateInNewlineEndValue:
return "StateInNewlineEndValue"
case StateInNumber: case StateInNumber:
return "StateInNumber" return "StateInNumber"
case StateInNumberEnd: case StateInNumberEnd:
@ -129,6 +135,8 @@ func (s JSONState) String() string {
return "StateInStringEnd" return "StateInStringEnd"
case StateInObjectKeyEnd: case StateInObjectKeyEnd:
return "StateInObjectKeyEnd" return "StateInObjectKeyEnd"
case StateInSpaceEndValue:
return "StateInSpaceEndValue"
case StateTerminate: case StateTerminate:
return "StateTerminate" return "StateTerminate"
case StateInObjectEnd: case StateInObjectEnd:

View File

@ -6,8 +6,35 @@ import (
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
) )
/*
Key JSON rules to consider:
1. Whitespace handling:
- Need to handle all valid JSON whitespace characters (\r, spaces between tokens)
- Current code only handles some whitespace cases
2. Number validation:
- Need proper validation for special number cases like -0
- Should reject .5 style decimals (JSON requires a digit before the decimal point, e.g. 0.5)
- Need limits on scientific notation (e, E)
3. String escaping:
- Currently marks \ as invalid but should allow escaped sequences:
- \"
- \n
- \u1234 unicode escapes
4. Empty object/array transitions:
- Direct {} and [] cases could be more explicit
- Need clear transitions for these edge cases
5. Nested depth limits:
- No protection against excessive nesting
- Could cause stack overflow with deeply nested structures
*/
// TODO: / should be valid but an escape character // TODO: / should be valid but an escape character
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'} var stringInvalidRunes = []rune{'\n', '\t', '{', '}', ':', ',', '/'}
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'} var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'} var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
@ -18,31 +45,31 @@ var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
var validNullRunes = []rune{'n', 'u', 'l', 'l'} var validNullRunes = []rune{'n', 'u', 'l', 'l'}
type PDANode struct { type PDA struct {
State JSONState State JSONState
TransitionEdges map[rune]*PDANode TransitionEdges map[rune]*PDA
MaskTokenIDToNode map[int32]*PDANode MaskTokenIDToNode map[int32]*PDA
} }
func NewPDANode(state JSONState) *PDANode { func NewPDANode(state JSONState) *PDA {
return &PDANode{ return &PDA{
State: state, State: state,
TransitionEdges: make(map[rune]*PDANode), TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDANode), MaskTokenIDToNode: make(map[int32]*PDA),
} }
} }
func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) { type PDAGraphBuilder struct {
stateToNodeMap := make(map[JSONState]*PDANode) proc model.TextProcessor
decodedToks []string
// TODO: make this a loop stateToNodeMap map[JSONState]*PDA
}
func (b *PDAGraphBuilder) BuildGraph() error {
stateToNodeMap := make(map[JSONState]*PDA)
for _, state := range JSONStates { for _, state := range JSONStates {
stateToNodeMap[state] = NewPDANode(state) stateToNodeMap[state] = NewPDANode(state)
} }
// TODO:
// consider adding a node to just point to values, could be good to compute that
// mask rather than many different nodes
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList] stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]
@ -51,10 +78,21 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace] stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
//new line // new line
stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab] stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
// new line end value
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
// TODO: see if this is needed for formatting
stateToNodeMap[StateInObjSpace].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
@ -68,16 +106,16 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
// where values should be // where values should be
// this could be combined but the probl might change, we're alr doing a skip ahead // this could be combined but the probl might change, we're alr doing a skip ahead
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpace] stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList] stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap) b.addValueConnections(stateToNodeMap[StateInColon])
// Leads to a value // Leads to a value
stateToNodeMap[StateInSpace].TransitionEdges['['] = stateToNodeMap[StateInList] stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInSpace].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
addValueConnections(stateToNodeMap[StateInSpace], stateToNodeMap) b.addValueConnections(stateToNodeMap[StateInSpaceToValue])
stateToNodeMap[StateInSpace].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
// Values // Values
// string node // string node
@ -85,149 +123,142 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd] stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
// String end node // String end node
addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap) b.addEnds(stateToNodeMap[StateInStringEnd])
stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// TODO: add counters for allowable number of decimals, e, E, etc // TODO: add counters for allowable number of decimals, e, E, etc
// number node // number node
for _, r := range validNumberRunes { for _, r := range validNumberRunes {
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber] stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
} }
addEnds(stateToNodeMap[StateInNumber], stateToNodeMap) b.addEnds(stateToNodeMap[StateInNumber])
stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
// bool node stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
for _, r := range validBoolRunes {
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
}
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpace]
// list node // list node
stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma] stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList] stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList] stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
// list end node
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// empty list // empty list
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
addValueConnections(stateToNodeMap[StateInList], stateToNodeMap) b.addValueConnections(stateToNodeMap[StateInList])
// null node // null node
for _, r := range validNullRunes { for _, r := range validNullRunes {
stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull] stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
} }
addEnds(stateToNodeMap[StateInNull], stateToNodeMap) b.addEnds(stateToNodeMap[StateInNull])
stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// list comma // list comma
// should point to values // should point to values
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma] stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList] stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap) b.addValueConnections(stateToNodeMap[StateInListComma])
// list object end // list object end
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma] stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
// TODO: not sure if this is needed
stateToNodeMap[StateInListObjectEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// bool node // bool node
for _, r := range validBoolRunes { for _, r := range validBoolRunes {
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool] stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
} }
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
addEnds(stateToNodeMap[StateInBool], stateToNodeMap) b.addEnds(stateToNodeMap[StateInBool])
stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
// comma node
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInList] stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInComma].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace] stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] // space end value
stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
return stateToNodeMap[StateStart], stateToNodeMap, nil b.stateToNodeMap = stateToNodeMap
if err := b.preComputeValidStates(); err != nil {
return err
}
return nil
} }
func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) { func (b *PDAGraphBuilder) addEnds(node *PDA) {
node.TransitionEdges[','] = stateToNodeMap[StateInComma] node.TransitionEdges[','] = b.stateToNodeMap[StateInComma]
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] node.TransitionEdges['}'] = b.stateToNodeMap[StateInObjectEnd]
node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd] node.TransitionEdges[']'] = b.stateToNodeMap[StateInListEnd]
} }
func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) { func (b *PDAGraphBuilder) addValueConnections(node *PDA) {
node.TransitionEdges['"'] = stateToNodeMap[StateInString] node.TransitionEdges['"'] = b.stateToNodeMap[StateInString]
for _, r := range validNumberRunes { for _, r := range validNumberRunes {
node.TransitionEdges[r] = stateToNodeMap[StateInNumber] node.TransitionEdges[r] = b.stateToNodeMap[StateInNumber]
} }
node.TransitionEdges['t'] = stateToNodeMap[StateInBool] // TODO(parthsareen): force the output and shift similar to structured outputs
node.TransitionEdges['f'] = stateToNodeMap[StateInBool] node.TransitionEdges['t'] = b.stateToNodeMap[StateInBool]
node.TransitionEdges['n'] = stateToNodeMap[StateInNull] node.TransitionEdges['f'] = b.stateToNodeMap[StateInBool]
node.TransitionEdges['n'] = b.stateToNodeMap[StateInNull]
} }
// TODO: tough life fr. plz fix. func (b *PDAGraphBuilder) preComputeValidStates() error {
func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error { for _, node := range b.stateToNodeMap {
if err := b.CreateMask(node); err != nil {
// TODO; should come from top level
vocab := proc.GetVocabulary()
decodedToks := make([]string, len(vocab.Values))
for i := range vocab.Values {
token, err := proc.Decode([]int32{int32(i)})
if err != nil {
return err
}
decodedToks[i] = token
}
var err error
for _, node := range stateToNodeMap {
err = CreateMask(node, proc, decodedToks)
if err != nil {
return err return err
} }
} }
return nil return nil
} }
func CreateMask(node *PDANode, proc model.TextProcessor, decodedToks []string) error { func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
for i := range decodedToks { for i := range b.decodedToks {
token := decodedToks[i] token := b.decodedToks[i]
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" { if b.proc.Is(uint32(i), model.SpecialEOS) || b.proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
continue continue
} }
valid := true
curNode := node curNode := node
valid := true
consumedSpecialRunes := make(map[rune]bool) consumedSpecialRunes := make(map[rune]bool)
var err error
for _, r := range token { for _, r := range token {
valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes) curNode, valid = isRuneValid(r, curNode, consumedSpecialRunes)
if err != nil { if curNode == nil || !valid {
return err
}
if !valid {
break break
} }
} }
if valid { if valid {
// cur node allows skipping
node.MaskTokenIDToNode[int32(i)] = curNode node.MaskTokenIDToNode[int32(i)] = curNode
} }
} }
return nil return nil
} }
// TODO: garbage interface plz fix func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA, bool) {
func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
if consumedSpecialRunes[r] { if consumedSpecialRunes[r] {
return false, nil, nil return nil, false
} }
specialRune := slices.Contains(stringInvalidRunes, r) specialRune := slices.Contains(stringInvalidRunes, r)
if specialRune { if specialRune {
if curNode.State == StateInString || curNode.State == StateInObjectKey { if curNode.State == StateInString || curNode.State == StateInObjectKey {
return false, nil, nil return nil, false
} }
} }
@ -235,17 +266,17 @@ func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (
if nextNode, ok := curNode.TransitionEdges[r]; ok { if nextNode, ok := curNode.TransitionEdges[r]; ok {
if specialRune { if specialRune {
if curNode.State == nextNode.State { if curNode.State == nextNode.State {
return false, nil, nil return nil, false
} }
consumedSpecialRunes[r] = true consumedSpecialRunes[r] = true
} }
return true, nextNode, nil return nextNode, true
} }
// Check for sentinel value - if present, any rune is valid // Check for sentinel value - if present, any rune is valid
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok { if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
return true, nextNode, nil return nextNode, true
} }
return false, nil, nil return nil, false
} }

View File

@ -11,17 +11,17 @@ import (
// TODO: safety in case of invalid json // TODO: safety in case of invalid json
// TODO: interfaces to cleanup with return values // TODO: interfaces to cleanup with return values
// TODO this interface shouldn't be the sampler - should just use Sampler
// TODO: add penalties for string \n stuff
type PushdownSampler struct { type PushdownSampler struct {
// stateful PDAGraphBuilder
curNode *PDANode curNode *PDA
proc model.TextProcessor braceStack []rune
stateToNodeMap map[JSONState]*PDANode stateCounter uint32
braceStack []rune
stateCounter uint32
} }
// graph should be built once and reused per tokenizer // graph should be built once and reused per tokenizer
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) {
start := time.Now() start := time.Now()
fmt.Println("--------------------------------") fmt.Println("--------------------------------")
@ -32,27 +32,38 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
before := m.Alloc before := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024)) fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
startNode, stateToNodeMap, err := BuildGraph(proc) vocab := proc.GetVocabulary()
if err != nil { decodedToks := make([]string, len(vocab.Values))
panic(err) for i := range vocab.Values {
token, err := proc.Decode([]int32{int32(i)})
if err != nil {
return nil, err
}
decodedToks[i] = token
} }
err = PreComputeValidStates(stateToNodeMap, proc)
if err != nil { gb := &PDAGraphBuilder{
panic(err) proc: proc,
decodedToks: decodedToks,
} }
if err := gb.BuildGraph(); err != nil {
return nil, err
}
runtime.ReadMemStats(&m) runtime.ReadMemStats(&m)
after := m.Alloc after := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024)) fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Graph build time = %v\n", time.Since(start)) fmt.Printf("Graph build time = %v\n", time.Since(start))
// TODO: this can be simplified
return &PushdownSampler{ return &PushdownSampler{
curNode: startNode, curNode: gb.stateToNodeMap[StateStart],
proc: proc, PDAGraphBuilder: *gb,
stateToNodeMap: stateToNodeMap, braceStack: []rune{},
braceStack: []rune{}, stateCounter: 0,
stateCounter: 0, }, nil
}
} }
// TODO: need to add resampling logic if the first sample was not good // TODO: need to add resampling logic if the first sample was not good
@ -66,14 +77,7 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
// force finish if no braces left // force finish if no braces left
if len(s.braceStack) == 0 { if len(s.braceStack) == 0 {
s.curNode = NewPDANode(StateTerminate) s.curNode = NewPDANode(StateTerminate)
for i := range logits { return forceFinish(s, logits)
if s.proc.Is(uint32(i), model.SpecialEOS) {
logits[i] = 1.0
} else {
logits[i] = math.Inf(-1)
}
}
return logits, nil
} }
logits, err := s.maskLogits(logits, s.curNode) logits, err := s.maskLogits(logits, s.curNode)
@ -82,18 +86,14 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
} }
return logits, nil return logits, nil
case StateTerminate:
return forceFinish(s, logits)
case StateInObjectEnd: case StateInObjectEnd:
// force finish if no braces left // force finish if no braces left
if len(s.braceStack) == 0 { if len(s.braceStack) == 0 {
s.curNode = NewPDANode(StateTerminate) s.curNode = NewPDANode(StateTerminate)
for i := range logits { return forceFinish(s, logits)
if s.proc.Is(uint32(i), model.SpecialEOS) {
logits[i] = 1.0
} else {
logits[i] = math.Inf(-1)
}
}
return logits, nil
} }
peek := s.braceStack[len(s.braceStack)-1] peek := s.braceStack[len(s.braceStack)-1]
@ -112,22 +112,13 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
if peek == rune('[') { if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListComma] s.curNode = s.stateToNodeMap[StateInListComma]
} }
logits, err := s.maskLogits(logits, s.curNode) logits, err := s.maskLogits(logits, s.curNode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return logits, nil return logits, nil
case StateTerminate:
for i := range logits {
if s.proc.Is(uint32(i), model.SpecialEOS) {
logits[i] = 1.0
} else {
logits[i] = math.Inf(-1)
}
}
return logits, nil
default: default:
fmt.Println("masking logits current state", s.curNode.State) fmt.Println("masking logits current state", s.curNode.State)
logits, err := s.maskLogits(logits, s.curNode) logits, err := s.maskLogits(logits, s.curNode)
@ -138,13 +129,24 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
} }
} }
// forceFinish overwrites logits in place so that only end-of-sequence tokens
// remain selectable: every index the text processor reports as
// model.SpecialEOS is set to 1.0 and every other index to negative infinity.
// It returns the same slice it was given and always a nil error.
func forceFinish(s *PushdownSampler, logits []float64) ([]float64, error) {
	negInf := math.Inf(-1)
	for id := range logits {
		if !s.proc.Is(uint32(id), model.SpecialEOS) {
			logits[id] = negInf
			continue
		}
		logits[id] = 1.0
	}
	return logits, nil
}
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
fmt.Println("current state - updating", s.curNode.State) fmt.Println("current state - updating", s.curNode.State)
mappedString, err := s.proc.Decode(tokenSlice) mappedString, err := s.proc.Decode(tokenSlice)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(">>> mappedString", mappedString) fmt.Printf(">>> mappedString: %q\n", mappedString)
// TODO: should force closing for all braces - not doing square yet // TODO: should force closing for all braces - not doing square yet
for _, r := range mappedString { for _, r := range mappedString {
@ -198,7 +200,8 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
} }
// greedy sample + backtrack? // greedy sample + backtrack?
func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) { func (s *PushdownSampler) maskLogits(logits []float64, node *PDA) ([]float64, error) {
// Create a new slice with same length as logits, initialized to -Inf // Create a new slice with same length as logits, initialized to -Inf
maskedLogits := make([]float64, len(logits)) maskedLogits := make([]float64, len(logits))
for i := range maskedLogits { for i := range maskedLogits {
@ -215,4 +218,23 @@ func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64
return maskedLogits, nil return maskedLogits, nil
} }
// TODO: add penalties for string \n stuff func (s *PushdownSampler) fastMaskLogits(logits []float64, node *PDA) ([]float64, error) {
maxLogit := math.Inf(-1)
maxIndex := -1
// Find the maximum logit value among valid tokens
for tokenID := range node.MaskTokenIDToNode {
if int(tokenID) < len(logits) && logits[tokenID] > maxLogit {
maxLogit = logits[tokenID]
maxIndex = int(tokenID)
}
}
if maxIndex == -1 {
return nil, fmt.Errorf("no valid tokens found in mask")
}
logits[0] = float64(maxIndex)
return logits, nil
// return maxIndex, nil
}

View File

@ -6,6 +6,8 @@ import (
"math" "math"
"slices" "slices"
pq "github.com/emirpasic/gods/v2/queues/priorityqueue"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/floats" "gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/stat/sampleuv" "gonum.org/v1/gonum/stat/sampleuv"
) )
@ -15,33 +17,34 @@ type Transform interface {
} }
type Sampler interface { type Sampler interface {
Sample([]float64) (int, error) Sample([]float32, ...Transform) (int, error)
} }
type SamplerConfig struct { // TODO(parthsareen): potentially cache softmax values
transforms []Transform func softmax(logits []float64) []float64 {
sampler Sampler var sum float64
} tt := make([]float64, len(logits))
for i, v := range logits {
// NewSampler creates a sampler with the given transforms and sampling method tt[i] = math.Exp(v)
func NewSampler(transforms []Transform, sampler Sampler) *SamplerConfig { sum += tt[i]
return &SamplerConfig{
transforms: transforms,
sampler: sampler,
} }
floats.Scale(1/sum, tt)
return tt
} }
type Temperature float64 type Temperature float64
func (t Temperature) Apply(logits []float64) ([]float64, error) { func (t Temperature) Apply(logits []float64) ([]float64, error) {
if t == 0 {
return nil, errors.New("use Greedy sampler instead of Temperature(0)")
}
if t < 0 || t > 2 { if t < 0 || t > 2 {
return nil, errors.New("temperature must be between 0 and 2") return nil, errors.New("temperature must be between 0 and 2")
} }
temp := math.Max(float64(t), 1e-7)
// subtracting max logit to avoid under/overflow // subtracting max logit to avoid under/overflow
maxLogit := floats.Max(logits) maxLogit := slices.Max(logits)
temp := math.Max(float64(t), 1e-7)
for i := range logits { for i := range logits {
logits[i] = (logits[i] - maxLogit) / temp logits[i] = (logits[i] - maxLogit) / temp
} }
@ -49,52 +52,41 @@ func (t Temperature) Apply(logits []float64) ([]float64, error) {
return logits, nil return logits, nil
} }
// logitMap pairs a logit value with the position it came from.
type logitMap struct {
	index int
	logit float64
}

// logitMapComparator orders logitMap values by logit, largest first: it
// returns a negative number when a has the larger logit, a positive number
// when b does, and zero when they compare equal.
func logitMapComparator(a, b logitMap) int {
	return cmp.Compare(b.logit, a.logit)
}
type TopK int type TopK int
// TODO(parthsareen): avoid having to check all logits after this transform
func (k TopK) Apply(logits []float64) ([]float64, error) { func (k TopK) Apply(logits []float64) ([]float64, error) {
if k <= 0 { if k <= 0 {
return nil, errors.New("k must be positive") return nil, errors.New("k must be greater than 0")
} }
if int(k) >= len(logits) { if int(k) >= len(logits) {
return logits, nil return logits, nil
} }
indices := make([]int, len(logits)) q := pq.NewWith(logitMapComparator)
for i := range indices { for i, logit := range logits {
indices[i] = i q.Enqueue(logitMap{index: i, logit: logit})
} }
// sort in descending order validLogits := make(map[int]float64)
slices.SortFunc(indices, func(i, j int) int { for range k {
return cmp.Compare(logits[j], logits[i]) logitMap, _ := q.Dequeue()
}) validLogits[logitMap.index] = logitMap.logit
}
for _, idx := range indices[k:] { for i := range logits {
logits[idx] = math.Inf(-1) if _, ok := validLogits[i]; !ok {
logits[i] = math.Inf(-1)
}
} }
return logits, nil return logits, nil
@ -107,8 +99,7 @@ func (p TopP) Apply(logits []float64) ([]float64, error) {
return nil, errors.New("p must be between 0 and 1") return nil, errors.New("p must be between 0 and 1")
} }
probs := computeSoftmax(logits) probs := softmax(logits)
indices := make([]int, len(probs)) indices := make([]int, len(probs))
for i := range indices { for i := range indices {
indices[i] = i indices[i] = i
@ -139,17 +130,11 @@ func (p MinP) Apply(logits []float64) ([]float64, error) {
return nil, errors.New("p must be between 0 and 1") return nil, errors.New("p must be between 0 and 1")
} }
probs := computeSoftmax(logits) probs := softmax(logits)
copiedProbs := make([]float64, len(probs)) threshold := slices.Max(probs) * float64(p)
copy(copiedProbs, probs)
slices.Sort(copiedProbs) for i, prob := range probs {
if prob < threshold {
maxProb := copiedProbs[len(copiedProbs)-1]
probThreshold := float64(p) * maxProb
for i := range probs {
if probs[i] < probThreshold {
logits[i] = math.Inf(-1) logits[i] = math.Inf(-1)
} }
} }
@ -157,18 +142,35 @@ func (p MinP) Apply(logits []float64) ([]float64, error) {
return logits, nil return logits, nil
} }
type weighed struct{} type weighted struct {
src rand.Source
func Weighed() Sampler {
return weighed{}
} }
// should return single value func Weighted(seed *int64) Sampler {
func (s weighed) Sample(logits []float64) (int, error) { var src rand.Source
if seed != nil {
src = rand.NewSource(uint64(*seed))
}
return weighted{src: src}
}
func (s weighted) Sample(logits []float32, transforms ...Transform) (int, error) {
logits64 := make([]float64, len(logits))
for i, v := range logits {
logits64[i] = float64(v)
}
var err error
for _, t := range transforms {
logits64, err = t.Apply(logits64)
if err != nil {
return -1, err
}
}
logitsCopy := make([]float64, 0, len(logits)) logitsCopy := make([]float64, 0, len(logits))
indices := make([]int, 0, len(logits)) indices := make([]int, 0, len(logits))
// the uv sampler does not support NaN values for i, logit := range logits64 {
for i, logit := range logits {
if !math.IsInf(logit, -1) { if !math.IsInf(logit, -1) {
logitsCopy = append(logitsCopy, logit) logitsCopy = append(logitsCopy, logit)
indices = append(indices, i) indices = append(indices, i)
@ -176,38 +178,13 @@ func (s weighed) Sample(logits []float64) (int, error) {
} }
if len(logitsCopy) == 0 { if len(logitsCopy) == 0 {
return -1, errors.New("no valid tokens found") return -1, errors.New("no valid logits found for weighed sampling")
} }
softmax := computeSoftmax(logitsCopy) probs := softmax(logitsCopy)
w := sampleuv.NewWeighted(softmax, nil) w := sampleuv.NewWeighted(probs, s.src)
if idx, ok := w.Take(); ok { if idx, ok := w.Take(); ok {
// returns the token ID
return indices[idx], nil return indices[idx], nil
} }
return -1, errors.New("weighed sampler failed") return -1, errors.New("weighed sampler failed, no valid token found")
}
// Sample applies transforms and samples a token ID
func (s *SamplerConfig) Sample(input []float32) (int, error) {
logits := make([]float64, len(input))
for i, v := range input {
logits[i] = float64(v)
}
var err error
for _, t := range s.transforms {
if t == Temperature(0) {
// early return with greedy if temperature is 0
s.sampler = Greedy()
break
}
logits, err = t.Apply(logits)
if err != nil {
return -1, err
}
}
return s.sampler.Sample(logits)
} }

View File

@ -3,116 +3,129 @@ package sample
import ( import (
"fmt" "fmt"
"math" "math"
"slices" "math/rand/v2"
"testing" "testing"
"gonum.org/v1/gonum/floats" "github.com/google/go-cmp/cmp"
) )
func TestTemperature(t *testing.T) { func TestTemperature(t *testing.T) {
logits, err := Temperature(0.5).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) logits, err := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
return
} }
want := []float64{-14, -12, -10, -8, -6, -4, 0} want := []float64{-4, -10, 0, -14, -6, -12, -8}
if !floats.Equal(logits, want) { if diff := cmp.Diff(want, logits); diff != "" {
t.Fatalf("got: %v, want: %v", logits, want) t.Errorf("logits mismatch (-want +got):\n%s", diff)
} }
if _, err := Temperature(-1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil { logits, err = Temperature(-1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
t.Fatalf("expected error for temperature=-1, got %v", logits) if err == nil {
t.Errorf("expected error for temperature=-1, got %v", logits)
} }
if _, err := Temperature(2.1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil { logits, err = Temperature(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
t.Fatalf("expected error for temperature=2.1, got %v", logits) if err == nil {
t.Errorf("expected error for temperature=0, got %v", logits)
}
logits, err = Temperature(2.1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err == nil {
t.Errorf("expected error for temperature=2.1, got %v", logits)
} }
} }
func TestSoftmax(t *testing.T) { func TestSoftmax(t *testing.T) {
probs, err := Softmax().Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) probs := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil {
t.Fatal(err)
}
expectedProbs := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085} expectedProbs := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
if !floats.Equal(probs, expectedProbs) { if diff := cmp.Diff(expectedProbs, probs); diff != "" {
t.Fatalf("logits: %v, expectedlogits: %v", probs, expectedProbs) t.Errorf("probs mismatch (-want +got):\n%s", diff)
} }
} }
func TestTopK(t *testing.T) { func TestTopK(t *testing.T) {
logits, err := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) logits, err := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
return
} }
expectedlogits := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4} expectedlogits := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
if !floats.Same(logits, expectedlogits) { if diff := cmp.Diff(expectedlogits, logits); diff != "" {
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits) t.Errorf("logits mismatch (-want +got):\n%s", diff)
} }
logits, err = TopK(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
_, err = TopK(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err == nil { if err == nil {
t.Fatalf("expected error for k=0, got %v", logits) t.Errorf("expected error for k=0, got %v", err)
} }
logits, err = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) logits, err = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
return
} }
expectedlogits = []float64{-3, -2, -1, 0, 1, 2, 4} expectedlogits = []float64{-3, -2, -1, 0, 1, 2, 4}
if !floats.Same(logits, expectedlogits) { if diff := cmp.Diff(expectedlogits, logits); diff != "" {
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits) t.Errorf("logits mismatch (-want +got):\n%s", diff)
} }
} }
func TestTopP(t *testing.T) { func TestTopP(t *testing.T) {
logits, err := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) logits, err := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
return
} }
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4} want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
if !floats.Same(logits, want) { if diff := cmp.Diff(want, logits); diff != "" {
t.Fatalf("got: %v, want: %v", logits, want) t.Errorf("logits mismatch (-want +got):\n%s", diff)
} }
logits, err = TopP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
_, err = TopP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err == nil { if err == nil {
t.Fatalf("expected error for p=1.0, got %v", logits) t.Error("expected error for p=1.0")
} }
logits, err = TopP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) _, err = TopP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err == nil { if err == nil {
t.Fatalf("expected error for p=0.0, got %v", logits) t.Error("expected error for p=0.0")
} }
} }
func TestMinP(t *testing.T) { func TestMinP(t *testing.T) {
logits, err := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) logits, err := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
return
} }
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 3, 4} want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
if !floats.Same(logits, want) { if diff := cmp.Diff(want, logits); diff != "" {
t.Fatalf("got: %v, want: %v", logits, want) t.Errorf("logits mismatch (-want +got):\n%s", diff)
} }
logits, err = MinP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
_, err = MinP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
if err == nil { if err == nil {
t.Fatalf("expected error for p=1.0, got %v", logits) t.Error("expected error for p=1.0")
} }
logits, err = MinP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) _, err = MinP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
if err == nil { if err == nil {
t.Fatalf("expected error for p=0.0, got %v", logits) t.Error("expected error for p=0.0")
} }
} }
func TestWeighed(t *testing.T) { func TestWeighed(t *testing.T) {
idx, err := Weighed().Sample([]float64{math.Inf(-1), 2, math.Inf(-1), math.Inf(-1)}) idx, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
return
} }
want := 1 want := 1
if idx != want { if diff := cmp.Diff(want, idx); diff != "" {
t.Fatalf("got: %v, want: %v", idx, want) t.Errorf("index mismatch (-want +got):\n%s", diff)
} }
idx, err = Weighed().Sample([]float64{math.Inf(-1), math.Inf(-1), math.Inf(-1)})
idx, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
if err == nil { if err == nil {
t.Fatalf("expected error for no valid tokens, got %v", idx) t.Error("expected error for no valid tokens, got index", idx)
} }
} }
@ -132,27 +145,32 @@ func TestSample(t *testing.T) {
id: 3, id: 3,
callOrder: &callOrder, callOrder: &callOrder,
} }
sampler := NewSampler([]Transform{mock1, mock2, mock3}, Greedy())
got, err := sampler.Sample(input) got, err := Greedy().Sample(input, mock1, mock2, mock3)
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
} return
if !slices.Equal(callOrder, []int{1, 2, 3}) {
t.Errorf("got %v, want %v", callOrder, []int{1, 2, 3})
} }
want := 3 // Greedy sampler should pick highest logit want := 3 // Greedy sampler should pick highest logit
if got != want { if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("got %v, want %v", got, want) t.Errorf("sampled index mismatch (-want +got):\n%s", diff)
}
_, err = Weighted(nil).Sample(input, mock1, mock2, mock3)
if err != nil {
t.Error(err)
return
}
wantOrder := []int{1, 2, 3}
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
t.Errorf("call order mismatch (-want +got):\n%s", diff)
} }
errMock := &testTransform{ errMock := &testTransform{
returnErr: fmt.Errorf("mock error"), returnErr: fmt.Errorf("mock error"),
} }
sampler = NewSampler([]Transform{mock1, errMock, mock2}, Greedy()) _, err = Weighted(nil).Sample(input, mock1, errMock, mock2)
_, err = sampler.Sample(input)
if err == nil { if err == nil {
t.Error("Expected error from sampler") t.Error("Expected error from sampler")
} }
@ -174,14 +192,51 @@ func (ts *testTransform) Apply(logits []float64) ([]float64, error) {
return logits, nil return logits, nil
} }
func TestSampleTemperatureZero(t *testing.T) { func BenchmarkTransform(b *testing.B) {
sampler := NewSampler([]Transform{Temperature(0)}, Greedy()) transforms := map[string]Transform{
got, err := sampler.Sample([]float32{1, 2, 3, 4}) "Temperature": Temperature(0.5),
if err != nil { "TopK": TopK(10),
t.Fatal(err) "TopP": TopP(0.9),
"MinP": MinP(0.2),
} }
want := 3 // Greedy sampler should pick highest logit index
if got != want { logits := make([]float64, 1<<16)
t.Fatalf("got: %v, want: %v", got, want) for i := range logits {
logits[i] = rand.Float64()
}
for name, transform := range transforms {
b.Run(name, func(b *testing.B) {
b.ResetTimer()
for range b.N {
_, err := transform.Apply(logits)
if err != nil {
b.Error(err)
}
}
})
}
}
func BenchmarkSample(b *testing.B) {
samplers := map[string]Sampler{
"Greedy": Greedy(),
"Weighted": Weighted(nil),
}
logits := make([]float32, 1<<16)
for i := range logits {
logits[i] = rand.Float32()
}
for name, s := range samplers {
b.Run(name, func(b *testing.B) {
b.ResetTimer()
for range b.N {
if _, err := s.Sample(logits); err != nil {
b.Error(err)
}
}
})
} }
} }

View File

@ -8,27 +8,45 @@ import (
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
) )
type SOSampler struct { type JSONSampler struct {
schema *Schema schema *Schema
propIdx int propIdx int
propToNodeMap map[string]*PDANode propToNodeMap map[string]*PDA
pdaSampler *PushdownSampler pdaSampler *PushdownSampler
decodedToks []string decodedToks []string
} }
func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) { func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, error) {
pdaSampler := NewPushdownSampler(proc) pdaSampler, err := NewPushdownSampler(proc)
if err != nil {
return nil, err
}
so := &SOSampler{ if schema == nil {
return &JSONSampler{
schema: nil,
propIdx: -1,
propToNodeMap: nil,
pdaSampler: pdaSampler,
}, nil
}
fmt.Println("schema not nil")
so := &JSONSampler{
schema: schema, schema: schema,
propIdx: -1, propIdx: -1,
propToNodeMap: make(map[string]*PDANode), propToNodeMap: make(map[string]*PDA),
pdaSampler: pdaSampler, pdaSampler: pdaSampler,
} }
so.schemaToGraph() so.schemaToGraph()
// This is prob slow // Benchmark token decoding
start := time.Now()
var m runtime.MemStats
runtime.ReadMemStats(&m)
before := m.Alloc
vocab := proc.GetVocabulary() vocab := proc.GetVocabulary()
decodedToks := make([]string, len(vocab.Values)) decodedToks := make([]string, len(vocab.Values))
for i := range vocab.Values { for i := range vocab.Values {
@ -40,14 +58,18 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
} }
so.decodedToks = decodedToks so.decodedToks = decodedToks
runtime.ReadMemStats(&m)
after := m.Alloc
fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Token decode time = %v\n", time.Since(start))
fmt.Println("--------------------------------") fmt.Println("--------------------------------")
fmt.Println("SOSampler") fmt.Println("SOSampler")
fmt.Println("--------------------------------") fmt.Println("--------------------------------")
// Benchmark this section // Benchmark this section
start := time.Now() start = time.Now()
var m runtime.MemStats
runtime.ReadMemStats(&m) runtime.ReadMemStats(&m)
before := m.Alloc before = m.Alloc
// TODO: still messed up // TODO: still messed up
// TODO: recursion use case // TODO: recursion use case
@ -57,12 +79,12 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
// propName -> node // propName -> node
curState := node.State curState := node.State
fromNode := node fromNode := node
CreateMask(fromNode, proc, decodedToks) so.pdaSampler.CreateMask(fromNode)
for curState == StateInStructuredKey { for curState == StateInStructuredKey {
// there is only one edge // there is only one edge
for r, toNode := range fromNode.TransitionEdges { for r, toNode := range fromNode.TransitionEdges {
// fmt.Println("rune", r, "edge", toNode.State) // fmt.Println("rune", r, "edge", toNode.State)
CreateMask(toNode, proc, decodedToks) so.pdaSampler.CreateMask(toNode)
fmt.Printf("created mask for %c\n", r) fmt.Printf("created mask for %c\n", r)
curState = toNode.State curState = toNode.State
fmt.Println("next state", curState) fmt.Println("next state", curState)
@ -73,7 +95,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
} }
runtime.ReadMemStats(&m) runtime.ReadMemStats(&m)
after := m.Alloc after = m.Alloc
fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Mask creation time = %v\n", time.Since(start)) fmt.Printf("Mask creation time = %v\n", time.Since(start))
fmt.Println("--------------------------------") fmt.Println("--------------------------------")
@ -81,7 +103,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
return so, nil return so, nil
} }
func (s *SOSampler) schemaToGraph() { func (s *JSONSampler) schemaToGraph() {
schemaType := s.schema.EffectiveType() schemaType := s.schema.EffectiveType()
switch schemaType { switch schemaType {
case "object": case "object":
@ -91,18 +113,18 @@ func (s *SOSampler) schemaToGraph() {
for _, prop := range s.schema.Properties { for _, prop := range s.schema.Properties {
// name of key // name of key
name := prop.Name name := prop.Name
keyNode := &PDANode{ keyNode := &PDA{
State: StateInStructuredKey, // this is unchanging, will impact sampling State: StateInStructuredKey, // this is unchanging, will impact sampling
TransitionEdges: make(map[rune]*PDANode), TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDANode), MaskTokenIDToNode: make(map[int32]*PDA),
} }
prevNode := keyNode prevNode := keyNode
for _, r := range name { for _, r := range name {
runeNode := &PDANode{ runeNode := &PDA{
State: StateInStructuredKey, // this is unchanging, will impact sampling State: StateInStructuredKey, // this is unchanging, will impact sampling
TransitionEdges: make(map[rune]*PDANode), TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDANode), MaskTokenIDToNode: make(map[int32]*PDA),
} }
fmt.Println("runeNode created", runeNode.State) fmt.Println("runeNode created", runeNode.State)
fmt.Printf("runeNode created %c\n", r) fmt.Printf("runeNode created %c\n", r)
@ -117,9 +139,14 @@ func (s *SOSampler) schemaToGraph() {
fmt.Println("name", name, "keyNode", keyNode.State) fmt.Println("name", name, "keyNode", keyNode.State)
} }
} }
// TODO: do values + recursion
} }
func (s *SOSampler) Apply(logits []float64) ([]float64, error) { func (s *JSONSampler) Apply(logits []float64) ([]float64, error) {
if s.schema == nil {
return s.pdaSampler.Apply(logits)
}
switch s.pdaSampler.curNode.State { switch s.pdaSampler.curNode.State {
// doesnt account for multi rune case // doesnt account for multi rune case
case StateInObjectKey: case StateInObjectKey:
@ -148,17 +175,18 @@ func (s *SOSampler) Apply(logits []float64) ([]float64, error) {
// todo: if i incremenet propidx then i know im in last value as well // todo: if i incremenet propidx then i know im in last value as well
switch s.pdaSampler.curNode.State { switch s.pdaSampler.curNode.State {
case StateInObjectEnd: case StateInObjectEnd:
fmt.Println("<<<<< in obj end- generating mask for", s.pdaSampler.curNode.State) fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDANode) s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
s.pdaSampler.curNode = NewPDANode(StateTerminate) s.pdaSampler.curNode = NewPDANode(StateTerminate)
s.propIdx++ s.propIdx++
// TODO: this needs to be optimized in some way, computing mask on the fly is expensive
case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd: case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State) fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
delete(s.pdaSampler.curNode.TransitionEdges, ',') delete(s.pdaSampler.curNode.TransitionEdges, ',')
s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDANode) s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
CreateMask(s.pdaSampler.curNode, s.pdaSampler.proc, s.decodedToks) s.pdaSampler.CreateMask(s.pdaSampler.curNode)
s.propIdx++ s.propIdx++
} }
} }
@ -167,12 +195,17 @@ func (s *SOSampler) Apply(logits []float64) ([]float64, error) {
} }
func (s *SOSampler) UpdateState(tokenSlice []int32) error { func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
err := s.pdaSampler.UpdateState(tokenSlice) err := s.pdaSampler.UpdateState(tokenSlice)
if err != nil { if err != nil {
return err return err
} }
if s.schema == nil {
// Don't need to update state for unconstrained JSON sampling
return nil
}
switch s.pdaSampler.curNode.State { switch s.pdaSampler.curNode.State {
case StateInObjectKey: case StateInObjectKey:
s.propIdx++ s.propIdx++