prototyping
This commit is contained in:
parent
1fd9967558
commit
5ec6bb52a0
@ -32,6 +32,7 @@ type TextProcessor interface {
|
|||||||
Encode(s string, addSpecial bool) ([]int32, error)
|
Encode(s string, addSpecial bool) ([]int32, error)
|
||||||
Decode([]int32) (string, error)
|
Decode([]int32) (string, error)
|
||||||
Is(int32, Special) bool
|
Is(int32, Special) bool
|
||||||
|
Vocab() *Vocabulary
|
||||||
}
|
}
|
||||||
|
|
||||||
type Vocabulary struct {
|
type Vocabulary struct {
|
||||||
|
@ -53,6 +53,10 @@ func (spm SentencePieceModel) Is(id int32, special Special) bool {
|
|||||||
return spm.vocab.Is(id, special)
|
return spm.vocab.Is(id, special)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (spm SentencePieceModel) Vocab() *Vocabulary {
|
||||||
|
return spm.vocab
|
||||||
|
}
|
||||||
|
|
||||||
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
|
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
|
||||||
return func(yield func(string) bool) {
|
return func(yield func(string) bool) {
|
||||||
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
|
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
|
||||||
|
@ -468,6 +468,20 @@ func (s *Server) processBatch() error {
|
|||||||
return fmt.Errorf("failed to sample token: %w", err)
|
return fmt.Errorf("failed to sample token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if seq.sampler.JSONSampler != nil {
|
||||||
|
_, err = seq.sampler.JSONSampler.UpdateState([]int32{token})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update state: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if seq.sampler.PythonSampler != nil {
|
||||||
|
err = seq.sampler.PythonSampler.UpdateState(token)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update state: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// if it's an end of sequence token, break
|
// if it's an end of sequence token, break
|
||||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||||
// TODO (jmorganca): we should send this back
|
// TODO (jmorganca): we should send this back
|
||||||
@ -562,6 +576,22 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// jsonSampler, err := sample.NewJSONSampler(s.model.(model.TextProcessor), nil)
|
||||||
|
// if err != nil {
|
||||||
|
// http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// jsonSampler = nil
|
||||||
|
// pythonSampler := sample.NewPythonSampler(s.model.(model.TextProcessor), nil)
|
||||||
|
// pythonSampler := &sample.PythonSampler{}
|
||||||
|
// functions := []sample.PythonFunction{
|
||||||
|
// {
|
||||||
|
// Name: "add_two_strings",
|
||||||
|
// Args: []string{"s1", "s2"},
|
||||||
|
// Types: []string{"string", "string"},
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
// pythonSampler.Init(functions, s.model.(model.TextProcessor))
|
||||||
sampler := sample.NewSampler(
|
sampler := sample.NewSampler(
|
||||||
req.Options.Temperature,
|
req.Options.Temperature,
|
||||||
req.Options.TopK,
|
req.Options.TopK,
|
||||||
@ -569,6 +599,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
req.Options.MinP,
|
req.Options.MinP,
|
||||||
req.Options.Seed,
|
req.Options.Seed,
|
||||||
grammar,
|
grammar,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
|
53
sample/gtf.go
Normal file
53
sample/gtf.go
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
var DefaultGrammar = map[string]string{
|
||||||
|
"unicode": `\x{hex}{2} | \u{hex}{4} | \U{hex}{8}`,
|
||||||
|
"null": `"null"`,
|
||||||
|
"object": `"{" (kv ("," kv)*)? "}"`,
|
||||||
|
"array": `"[" (value ("," value)*)? "]"`,
|
||||||
|
"kv": `string ":" value`,
|
||||||
|
"integer": `"0" | [1-9] [0-9]*`,
|
||||||
|
"number": `"-"? integer frac? exp?`,
|
||||||
|
"frac": `"." [0-9]+`,
|
||||||
|
"exp": `("e" | "E") ("+" | "-") [0-9]+`,
|
||||||
|
"string": `"\"" char* "\""`,
|
||||||
|
"escape": `["/" | "b" | "f" | "n" | "r" | "t" | unicode]`,
|
||||||
|
"char": `[^"\\] | escape`,
|
||||||
|
"space": `(" " | "\t" | "\n" | "\r")*`,
|
||||||
|
"hex": `[0-9] | [a-f] | [A-F]`,
|
||||||
|
"boolean": `"true" | "false"`,
|
||||||
|
"value": `object | array | string | number | boolean | "null"`,
|
||||||
|
}
|
||||||
|
|
||||||
|
const jsonString = `object | array`
|
||||||
|
|
||||||
|
type StateMachine struct {
|
||||||
|
states map[rune]State
|
||||||
|
}
|
||||||
|
|
||||||
|
type State struct {
|
||||||
|
NextStates []string
|
||||||
|
// bitmask?
|
||||||
|
Mask []bool
|
||||||
|
IsTerminal bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStateMachine(grammar map[string]string, startRule string) *StateMachine {
|
||||||
|
states := make(map[rune]State)
|
||||||
|
|
||||||
|
var cumu string
|
||||||
|
flag := false
|
||||||
|
for _, r := range startRule {
|
||||||
|
if r == '"' {
|
||||||
|
flag = !flag
|
||||||
|
}
|
||||||
|
if flag {
|
||||||
|
cumu += string(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sm := &StateMachine{
|
||||||
|
states: states,
|
||||||
|
}
|
||||||
|
return sm
|
||||||
|
}
|
138
sample/gtf_test.go
Normal file
138
sample/gtf_test.go
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGrammarParsing(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
grammar map[string]string
|
||||||
|
startRule string
|
||||||
|
input string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple object",
|
||||||
|
grammar: map[string]string{
|
||||||
|
"object": `"{" "}"`,
|
||||||
|
},
|
||||||
|
startRule: "object",
|
||||||
|
input: "{}",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple array",
|
||||||
|
grammar: map[string]string{
|
||||||
|
"array": `"[" "]"`,
|
||||||
|
},
|
||||||
|
startRule: "array",
|
||||||
|
input: "[]",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "character class",
|
||||||
|
grammar: map[string]string{
|
||||||
|
"digit": `[0-9]`,
|
||||||
|
},
|
||||||
|
startRule: "digit",
|
||||||
|
input: "5",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "alternation",
|
||||||
|
grammar: map[string]string{
|
||||||
|
"bool": `"true" | "false"`,
|
||||||
|
},
|
||||||
|
startRule: "bool",
|
||||||
|
input: "true",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "repetition",
|
||||||
|
grammar: map[string]string{
|
||||||
|
"digits": `[0-9]+`,
|
||||||
|
},
|
||||||
|
startRule: "digits",
|
||||||
|
input: "123",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested rules",
|
||||||
|
grammar: map[string]string{
|
||||||
|
"value": `object | array`,
|
||||||
|
"object": `"{" "}"`,
|
||||||
|
"array": `"[" "]"`,
|
||||||
|
},
|
||||||
|
startRule: "value",
|
||||||
|
input: "{}",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
parser := NewParser(tt.grammar)
|
||||||
|
machine, err := parser.Parse(tt.startRule)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Parse() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
matcher := NewMatcher(machine)
|
||||||
|
got, err := matcher.Match(tt.input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Match() error = %v", err)
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("Match() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONGrammar(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"empty object", "{}", true},
|
||||||
|
{"empty array", "[]", true},
|
||||||
|
{"simple string", `"hello"`, true},
|
||||||
|
{"simple number", "123", true},
|
||||||
|
{"simple boolean", "true", true},
|
||||||
|
{"simple null", "null", true},
|
||||||
|
{"object with string", `{"key": "value"}`, true},
|
||||||
|
{"array with numbers", "[1, 2, 3]", true},
|
||||||
|
{"nested object", `{"obj": {"key": "value"}}`, true},
|
||||||
|
{"nested array", `[1, [2, 3], 4]`, true},
|
||||||
|
{"invalid object", "{", false},
|
||||||
|
{"invalid array", "[1, 2", false},
|
||||||
|
{"invalid string", `"hello`, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
parser := NewParser(DefaultGrammar)
|
||||||
|
machine, err := parser.Parse("value")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Parse() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
matcher := NewMatcher(machine)
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := matcher.Match(tt.input)
|
||||||
|
if tt.want {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Match() error = %v", err)
|
||||||
|
}
|
||||||
|
if !got {
|
||||||
|
t.Errorf("Match() = false, want true")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err == nil && got {
|
||||||
|
t.Errorf("Match() = true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
160
sample/json_types.go
Normal file
160
sample/json_types.go
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type JSONState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StateStart JSONState = iota
|
||||||
|
StateInObject
|
||||||
|
StateInObjectKey
|
||||||
|
StateInStructuredKey
|
||||||
|
StateInStructuredValue
|
||||||
|
StateNewline
|
||||||
|
StateTab
|
||||||
|
StateSpace
|
||||||
|
StateInString
|
||||||
|
StateInInt
|
||||||
|
StateInFloat
|
||||||
|
StateInBool
|
||||||
|
StateInNull
|
||||||
|
StateInColon
|
||||||
|
StateInComma
|
||||||
|
StateInTab
|
||||||
|
StateInSpaceToValue
|
||||||
|
StateInSpaceEndValue
|
||||||
|
StateInNewlineEndValue
|
||||||
|
StateInObjSpace
|
||||||
|
StateInList
|
||||||
|
StateInListComma
|
||||||
|
StateInValue
|
||||||
|
StateInValueEnd
|
||||||
|
StateInListEnd
|
||||||
|
StateInListObjectEnd
|
||||||
|
StateInNewline
|
||||||
|
StateInNumber
|
||||||
|
StateInNumberEnd
|
||||||
|
StateInStringEnd
|
||||||
|
StateInObjectKeyEnd
|
||||||
|
StateTerminate
|
||||||
|
StateInObjectEnd
|
||||||
|
StateTransitioningToTerminate
|
||||||
|
StateInListStartJSON
|
||||||
|
)
|
||||||
|
|
||||||
|
var JSONStates = []JSONState{
|
||||||
|
StateStart,
|
||||||
|
StateInObject,
|
||||||
|
StateInObjectKey,
|
||||||
|
StateInStructuredKey,
|
||||||
|
StateInStructuredValue,
|
||||||
|
StateNewline,
|
||||||
|
StateTab,
|
||||||
|
StateSpace,
|
||||||
|
StateInString,
|
||||||
|
StateInInt,
|
||||||
|
StateInFloat,
|
||||||
|
StateInBool,
|
||||||
|
StateInNull,
|
||||||
|
StateInColon,
|
||||||
|
StateInComma,
|
||||||
|
StateInTab,
|
||||||
|
StateInSpaceToValue,
|
||||||
|
StateInSpaceEndValue,
|
||||||
|
StateInNewlineEndValue,
|
||||||
|
StateInObjSpace,
|
||||||
|
StateInListStartJSON,
|
||||||
|
StateInList,
|
||||||
|
StateInListComma,
|
||||||
|
StateInValue,
|
||||||
|
StateInValueEnd,
|
||||||
|
StateInListEnd,
|
||||||
|
StateInListObjectEnd,
|
||||||
|
StateInNewline,
|
||||||
|
StateInNumber,
|
||||||
|
StateInNumberEnd,
|
||||||
|
StateInStringEnd,
|
||||||
|
StateInObjectKeyEnd,
|
||||||
|
StateTerminate,
|
||||||
|
StateInObjectEnd,
|
||||||
|
StateTransitioningToTerminate,
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s JSONState) String() string {
|
||||||
|
switch s {
|
||||||
|
case StateStart:
|
||||||
|
return "StateStart"
|
||||||
|
case StateInObject:
|
||||||
|
return "StateInObject"
|
||||||
|
case StateInObjectKey:
|
||||||
|
return "StateInObjectKey"
|
||||||
|
case StateInStructuredKey:
|
||||||
|
return "StateInStructuredKey"
|
||||||
|
case StateInStructuredValue:
|
||||||
|
return "StateInStructuredValue"
|
||||||
|
case StateNewline:
|
||||||
|
return "StateNewline"
|
||||||
|
case StateTab:
|
||||||
|
return "StateTab"
|
||||||
|
case StateSpace:
|
||||||
|
return "StateSpace"
|
||||||
|
case StateInString:
|
||||||
|
return "StateInString"
|
||||||
|
case StateInInt:
|
||||||
|
return "StateInInt"
|
||||||
|
case StateInFloat:
|
||||||
|
return "StateInFloat"
|
||||||
|
case StateInBool:
|
||||||
|
return "StateInBool"
|
||||||
|
case StateInNull:
|
||||||
|
return "StateInNull"
|
||||||
|
case StateInColon:
|
||||||
|
return "StateInColon"
|
||||||
|
case StateInComma:
|
||||||
|
return "StateInComma"
|
||||||
|
case StateInTab:
|
||||||
|
return "StateInTab"
|
||||||
|
case StateInSpaceToValue:
|
||||||
|
return "StateInSpaceToValue"
|
||||||
|
case StateInSpaceEndValue:
|
||||||
|
return "StateInSpaceEndValue"
|
||||||
|
case StateInNewlineEndValue:
|
||||||
|
return "StateInNewlineEndValue"
|
||||||
|
case StateInObjSpace:
|
||||||
|
return "StateInObjSpace"
|
||||||
|
case StateInList:
|
||||||
|
return "StateInList"
|
||||||
|
case StateInListComma:
|
||||||
|
return "StateInListComma"
|
||||||
|
case StateInValue:
|
||||||
|
return "StateInValue"
|
||||||
|
case StateInValueEnd:
|
||||||
|
return "StateInValueEnd"
|
||||||
|
case StateInListEnd:
|
||||||
|
return "StateInListEnd"
|
||||||
|
case StateInListObjectEnd:
|
||||||
|
return "StateInListObjectEnd"
|
||||||
|
case StateInNewline:
|
||||||
|
return "StateInNewline"
|
||||||
|
case StateInNumber:
|
||||||
|
return "StateInNumber"
|
||||||
|
case StateInNumberEnd:
|
||||||
|
return "StateInNumberEnd"
|
||||||
|
case StateInStringEnd:
|
||||||
|
return "StateInStringEnd"
|
||||||
|
case StateInObjectKeyEnd:
|
||||||
|
return "StateInObjectKeyEnd"
|
||||||
|
case StateTerminate:
|
||||||
|
return "StateTerminate"
|
||||||
|
case StateInObjectEnd:
|
||||||
|
return "StateInObjectEnd"
|
||||||
|
case StateTransitioningToTerminate:
|
||||||
|
return "StateTransitioningToTerminate"
|
||||||
|
case StateInListStartJSON:
|
||||||
|
return "StateInListStartJSON"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("Unknown state: %d", s)
|
||||||
|
}
|
||||||
|
}
|
327
sample/pushdown_automata.go
Normal file
327
sample/pushdown_automata.go
Normal file
@ -0,0 +1,327 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
Key JSON rules to consider:
|
||||||
|
|
||||||
|
1. Whitespace handling:
|
||||||
|
- Need to handle all valid JSON whitespace characters (\r, spaces between tokens)
|
||||||
|
- Current code only handles some whitespace cases
|
||||||
|
|
||||||
|
2. Number validation:
|
||||||
|
- Need proper validation for special number cases like -0
|
||||||
|
- Should handle .5 style decimals
|
||||||
|
- Need limits on scientific notation (e, E)
|
||||||
|
|
||||||
|
3. String escaping:
|
||||||
|
- Currently marks \ as invalid but should allow escaped sequences:
|
||||||
|
- \"
|
||||||
|
- \n
|
||||||
|
- \u1234 unicode escapes
|
||||||
|
|
||||||
|
4. Empty object/array transitions:
|
||||||
|
- Direct {} and [] cases could be more explicit
|
||||||
|
- Need clear transitions for these edge cases
|
||||||
|
|
||||||
|
5. Nested depth limits:
|
||||||
|
- No protection against excessive nesting
|
||||||
|
- Could cause stack overflow with deeply nested structures
|
||||||
|
*/
|
||||||
|
|
||||||
|
// TODO: / should be valid but an escape character
|
||||||
|
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
|
||||||
|
|
||||||
|
var (
|
||||||
|
intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
||||||
|
validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
||||||
|
)
|
||||||
|
|
||||||
|
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
|
||||||
|
|
||||||
|
var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
|
||||||
|
|
||||||
|
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
|
||||||
|
|
||||||
|
type PDA struct {
|
||||||
|
State JSONState
|
||||||
|
TransitionEdges map[rune]*PDA
|
||||||
|
MaskTokenIDToNode map[int32]*PDA
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPDANode(state JSONState) *PDA {
|
||||||
|
return &PDA{
|
||||||
|
State: state,
|
||||||
|
TransitionEdges: make(map[rune]*PDA),
|
||||||
|
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type PDAGraphBuilder struct {
|
||||||
|
proc model.TextProcessor
|
||||||
|
decodedToks []string
|
||||||
|
stateToNodeMap map[JSONState]*PDA
|
||||||
|
tokenToStatesMap map[int32][]JSONState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *PDAGraphBuilder) BuildGraph() error {
|
||||||
|
stateToNodeMap := make(map[JSONState]*PDA)
|
||||||
|
for _, state := range JSONStates {
|
||||||
|
stateToNodeMap[state] = NewPDANode(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
|
||||||
|
|
||||||
|
// TODO: update naming here - and revisit values
|
||||||
|
stateToNodeMap[StateInListStartJSON].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
stateToNodeMap[StateInListStartJSON].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
|
||||||
|
|
||||||
|
stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||||
|
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||||
|
stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||||
|
stateToNodeMap[StateInObject].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
|
||||||
|
// new line
|
||||||
|
stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||||
|
stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
|
||||||
|
stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||||
|
// stateToNodeMap[StateInNewline].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
|
||||||
|
// new line end value
|
||||||
|
// stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
|
stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
|
|
||||||
|
stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||||
|
stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||||
|
// TODO: see if this is needed for formatting
|
||||||
|
stateToNodeMap[StateInObjSpace].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||||
|
|
||||||
|
stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||||
|
stateToNodeMap[StateInTab].TransitionEdges['\t'] = stateToNodeMap[StateInNewline]
|
||||||
|
|
||||||
|
stateToNodeMap[StateInObjectKey].TransitionEdges[rune(-1)] = stateToNodeMap[StateInObjectKey]
|
||||||
|
stateToNodeMap[StateInObjectKey].TransitionEdges['"'] = stateToNodeMap[StateInObjectKeyEnd]
|
||||||
|
|
||||||
|
stateToNodeMap[StateInObjectKeyEnd].TransitionEdges[':'] = stateToNodeMap[StateInColon]
|
||||||
|
|
||||||
|
stateToNodeMap[StateInObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||||
|
stateToNodeMap[StateInObjectEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
|
||||||
|
// where values should be
|
||||||
|
// this could be combined but the probl might change, we're alr doing a skip ahead
|
||||||
|
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||||
|
stateToNodeMap[StateInColon].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
|
||||||
|
|
||||||
|
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||||
|
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
|
||||||
|
|
||||||
|
// Leads to a value
|
||||||
|
stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||||
|
stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap)
|
||||||
|
stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
stateToNodeMap[StateInSpaceToValue].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
|
||||||
|
|
||||||
|
// Values
|
||||||
|
// string node
|
||||||
|
stateToNodeMap[StateInString].TransitionEdges[rune(-1)] = stateToNodeMap[StateInString]
|
||||||
|
stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
|
||||||
|
|
||||||
|
// String end node
|
||||||
|
addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
|
||||||
|
// stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
|
stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
|
// TODO: add counters for allowable number of decimals, e, E, etc
|
||||||
|
// number node
|
||||||
|
for _, r := range validNumberRunes {
|
||||||
|
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
||||||
|
}
|
||||||
|
addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
|
||||||
|
// stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
|
stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
|
// list node
|
||||||
|
stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||||
|
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
|
||||||
|
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
||||||
|
// early end
|
||||||
|
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
|
|
||||||
|
// list end node
|
||||||
|
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
// stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
|
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||||
|
stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
|
// empty list
|
||||||
|
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
|
addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
|
||||||
|
|
||||||
|
// null node
|
||||||
|
for _, r := range validNullRunes {
|
||||||
|
stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
|
||||||
|
}
|
||||||
|
addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
|
||||||
|
stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||||
|
stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
|
// list comma
|
||||||
|
// should point to values
|
||||||
|
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
|
||||||
|
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
||||||
|
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInList]
|
||||||
|
stateToNodeMap[StateInListComma].TransitionEdges['\t'] = stateToNodeMap[StateInList]
|
||||||
|
|
||||||
|
addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
|
||||||
|
|
||||||
|
// list object end
|
||||||
|
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
|
||||||
|
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
|
// TODO: not sure if this is needed
|
||||||
|
stateToNodeMap[StateInListObjectEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
|
// bool node
|
||||||
|
for _, r := range validBoolRunes {
|
||||||
|
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
|
||||||
|
}
|
||||||
|
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||||
|
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
|
||||||
|
// stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
|
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
|
// comma node
|
||||||
|
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||||
|
stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||||
|
stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||||
|
// todo: review this space transition
|
||||||
|
// stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||||
|
|
||||||
|
// space end value
|
||||||
|
// stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
|
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
|
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
|
b.stateToNodeMap = stateToNodeMap
|
||||||
|
if err := b.preComputeValidStates(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func addEnds(node *PDA, stateToNodeMap map[JSONState]*PDA) {
|
||||||
|
node.TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||||
|
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
|
}
|
||||||
|
|
||||||
|
func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) {
|
||||||
|
node.TransitionEdges['"'] = stateToNodeMap[StateInString]
|
||||||
|
for _, r := range validNumberRunes {
|
||||||
|
node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
||||||
|
}
|
||||||
|
// TODO(parthsareen): force the output and shift similar to structured outputs
|
||||||
|
node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
|
||||||
|
node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
|
||||||
|
node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *PDAGraphBuilder) preComputeValidStates() error {
|
||||||
|
for _, node := range b.stateToNodeMap {
|
||||||
|
// if node.State == StateInObjectKey {
|
||||||
|
// if len(b.stateToNodeMap[StateInString].MaskTokenIDToNode) > 0 {
|
||||||
|
// b.stateToNodeMap[StateInObjectKey].MaskTokenIDToNode = b.stateToNodeMap[StateInString].MaskTokenIDToNode
|
||||||
|
// fmt.Println("copying string mask to object key mask")
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
if err := b.CreateMask(node); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *PDAGraphBuilder) preComputeTokenToStatesMap() error {
|
||||||
|
// TODO: make can be somewhere else too
|
||||||
|
b.tokenToStatesMap = make(map[int32][]JSONState)
|
||||||
|
for i, t := range b.decodedToks {
|
||||||
|
for _, r := range t {
|
||||||
|
if r == '"' {
|
||||||
|
b.tokenToStatesMap[int32(i)] = append(b.tokenToStatesMap[int32(i)], StateInString)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: the mask for obj key and string should be the same?
|
||||||
|
func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
|
||||||
|
if node == nil {
|
||||||
|
return fmt.Errorf("node cannot be nil")
|
||||||
|
}
|
||||||
|
for i := range b.decodedToks {
|
||||||
|
token := b.decodedToks[i]
|
||||||
|
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||||
|
if b.proc.Is(int32(i), model.SpecialEOS) || b.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
curNode := node
|
||||||
|
valid := true
|
||||||
|
consumedSpecialRunes := make(map[rune]bool)
|
||||||
|
for _, r := range token {
|
||||||
|
curNode, valid = isRuneValid(r, curNode, consumedSpecialRunes)
|
||||||
|
if curNode == nil || !valid {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if valid {
|
||||||
|
node.MaskTokenIDToNode[int32(i)] = curNode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA, bool) {
|
||||||
|
if consumedSpecialRunes[r] {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
specialRune := slices.Contains(stringInvalidRunes, r)
|
||||||
|
if specialRune {
|
||||||
|
if curNode.State == StateInString || curNode.State == StateInObjectKey {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for specific rune transition
|
||||||
|
if nextNode, ok := curNode.TransitionEdges[r]; ok {
|
||||||
|
// fmt.Println("next node", nextNode)
|
||||||
|
if specialRune {
|
||||||
|
if curNode.State == nextNode.State {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
consumedSpecialRunes[r] = true
|
||||||
|
}
|
||||||
|
return nextNode, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for sentinel value - if present, any rune is valid
|
||||||
|
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
|
||||||
|
return nextNode, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
264
sample/pushdown_runner.go
Normal file
264
sample/pushdown_runner.go
Normal file
@ -0,0 +1,264 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"runtime"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: safety in case of invalid json
|
||||||
|
// TODO: partial JSON matching?
|
||||||
|
// TODO: interfaces to cleanup with return values
|
||||||
|
// TODO this interface shouldn't be the sampler - should just use Sampler
|
||||||
|
// TODO: add penalties for string \n stuff
|
||||||
|
// TODO: minimize number of fwd passes if there is only one match
|
||||||
|
// TODO: greedy sample initially and then backtrack if no match
|
||||||
|
|
||||||
|
type PushdownSampler struct {
|
||||||
|
PDAGraphBuilder
|
||||||
|
curNode *PDA
|
||||||
|
braceStack []rune
|
||||||
|
stateCounter uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// graph should be built once and reused per tokenizer
|
||||||
|
func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
fmt.Println("--------------------------------")
|
||||||
|
fmt.Println("PDA sampler")
|
||||||
|
fmt.Println("--------------------------------")
|
||||||
|
var m runtime.MemStats
|
||||||
|
runtime.ReadMemStats(&m)
|
||||||
|
before := m.Alloc
|
||||||
|
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
|
||||||
|
|
||||||
|
vocab := proc.Vocab()
|
||||||
|
decodedToks := make([]string, len(vocab.Values))
|
||||||
|
for i := range vocab.Values {
|
||||||
|
token, err := proc.Decode([]int32{int32(i)})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
decodedToks[i] = token
|
||||||
|
}
|
||||||
|
|
||||||
|
gb := &PDAGraphBuilder{
|
||||||
|
proc: proc,
|
||||||
|
decodedToks: decodedToks,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := gb.BuildGraph(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
runtime.ReadMemStats(&m)
|
||||||
|
after := m.Alloc
|
||||||
|
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
|
||||||
|
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||||
|
fmt.Printf("Graph build time = %v\n", time.Since(start))
|
||||||
|
|
||||||
|
// TODO: this can be simplified
|
||||||
|
return &PushdownSampler{
|
||||||
|
curNode: gb.stateToNodeMap[StateStart],
|
||||||
|
PDAGraphBuilder: *gb,
|
||||||
|
braceStack: []rune{},
|
||||||
|
stateCounter: 0,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: need to add resampling logic if the first sample was not good
|
||||||
|
// greedy sample + backtrack?
|
||||||
|
func (s *PushdownSampler) Apply(logits []float32) ([]float32, error) {
|
||||||
|
switch s.curNode.State {
|
||||||
|
case StateInString:
|
||||||
|
return s.maskLogits(logits, s.curNode)
|
||||||
|
|
||||||
|
case StateInListEnd:
|
||||||
|
// force finish if no braces left
|
||||||
|
if len(s.braceStack) == 0 {
|
||||||
|
s.curNode = NewPDANode(StateTerminate)
|
||||||
|
return forceFinish(s, logits)
|
||||||
|
}
|
||||||
|
|
||||||
|
logits, err := s.maskLogits(logits, s.curNode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
|
||||||
|
case StateTerminate:
|
||||||
|
return forceFinish(s, logits)
|
||||||
|
|
||||||
|
case StateInObjectEnd:
|
||||||
|
// force finish if no braces left
|
||||||
|
if len(s.braceStack) == 0 {
|
||||||
|
s.curNode = NewPDANode(StateTerminate)
|
||||||
|
return forceFinish(s, logits)
|
||||||
|
}
|
||||||
|
|
||||||
|
peek := s.braceStack[len(s.braceStack)-1]
|
||||||
|
if peek == rune('[') {
|
||||||
|
s.curNode = s.stateToNodeMap[StateInListObjectEnd]
|
||||||
|
}
|
||||||
|
|
||||||
|
logits, err := s.maskLogits(logits, s.curNode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
|
||||||
|
case StateInComma:
|
||||||
|
peek := s.braceStack[len(s.braceStack)-1]
|
||||||
|
if peek == rune('[') {
|
||||||
|
s.curNode = s.stateToNodeMap[StateInListComma]
|
||||||
|
}
|
||||||
|
|
||||||
|
logits, err := s.maskLogits(logits, s.curNode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
fmt.Println("masking logits current state", s.curNode.State)
|
||||||
|
logits, err := s.maskLogits(logits, s.curNode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func forceFinish(s *PushdownSampler, logits []float32) ([]float32, error) {
|
||||||
|
for i := range logits {
|
||||||
|
if s.proc.Is(int32(i), model.SpecialEOS) {
|
||||||
|
logits[i] = 1.0
|
||||||
|
} else {
|
||||||
|
logits[i] = float32(math.Inf(-1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PushdownSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
|
||||||
|
fmt.Println("current state - updating", s.curNode.State)
|
||||||
|
mappedString, err := s.proc.Decode(tokenSlice)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
fmt.Printf(">>> mappedString: %q\n", mappedString)
|
||||||
|
|
||||||
|
// Special handling for EOS token in terminate state
|
||||||
|
if s.curNode.State == StateTerminate {
|
||||||
|
for _, tokenID := range tokenSlice {
|
||||||
|
if s.proc.Is(tokenID, model.SpecialEOS) {
|
||||||
|
return tokenSlice, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// flag := -1
|
||||||
|
// endBraceRunes := []rune{'}', ']'}
|
||||||
|
for _, r := range mappedString {
|
||||||
|
// TODO: if this is enabled again, make sure to appropriately handle the state transitions
|
||||||
|
// if slices.Contains(endBraceRunes, r) && len(s.braceStack) == 0 {
|
||||||
|
// fmt.Printf("stack is empty, extra closing brace %c\n", r)
|
||||||
|
// // flag = i
|
||||||
|
// break
|
||||||
|
|
||||||
|
// }
|
||||||
|
if r == rune('{') {
|
||||||
|
s.braceStack = append(s.braceStack, r)
|
||||||
|
}
|
||||||
|
if r == rune('[') {
|
||||||
|
s.braceStack = append(s.braceStack, r)
|
||||||
|
}
|
||||||
|
if r == rune('}') {
|
||||||
|
if len(s.braceStack) == 0 {
|
||||||
|
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
|
||||||
|
}
|
||||||
|
top := s.braceStack[len(s.braceStack)-1]
|
||||||
|
if top != rune('{') {
|
||||||
|
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
|
||||||
|
}
|
||||||
|
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if r == rune(']') {
|
||||||
|
if len(s.braceStack) == 0 {
|
||||||
|
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
|
||||||
|
}
|
||||||
|
top := s.braceStack[len(s.braceStack)-1]
|
||||||
|
if top != rune('[') {
|
||||||
|
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
|
||||||
|
}
|
||||||
|
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if flag != -1 {
|
||||||
|
// tokenSlice = tokenSlice[:flag]
|
||||||
|
// }
|
||||||
|
// fmt.Println("flag!", flag)
|
||||||
|
for _, tokenID := range tokenSlice {
|
||||||
|
// transition to the next node
|
||||||
|
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid token: %q", mappedString)
|
||||||
|
}
|
||||||
|
fmt.Println("transitioning to", nextNode.State)
|
||||||
|
|
||||||
|
// TODO: add a penalty for staying in the same state too long
|
||||||
|
if nextNode.State == s.curNode.State {
|
||||||
|
s.stateCounter++
|
||||||
|
} else {
|
||||||
|
s.stateCounter = 0
|
||||||
|
}
|
||||||
|
s.curNode = nextNode
|
||||||
|
fmt.Println("updated curNode state", s.curNode.State)
|
||||||
|
}
|
||||||
|
return tokenSlice, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// greedy sample + backtrack?
|
||||||
|
func (s *PushdownSampler) maskLogits(logits []float32, node *PDA) ([]float32, error) {
|
||||||
|
// Create a new slice with same length as logits, initialized to -Inf
|
||||||
|
maskedLogits := make([]float32, len(logits))
|
||||||
|
for i := range maskedLogits {
|
||||||
|
maskedLogits[i] = float32(math.Inf(-1))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only update values for valid token IDs from the mask map
|
||||||
|
for tokenID := range node.MaskTokenIDToNode {
|
||||||
|
if int(tokenID) < len(logits) {
|
||||||
|
maskedLogits[tokenID] = logits[tokenID]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return maskedLogits, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PushdownSampler) fastMaskLogits(logits []float32, node *PDA) ([]float32, error) {
|
||||||
|
maxLogit := float32(math.Inf(-1))
|
||||||
|
maxIndex := -1
|
||||||
|
|
||||||
|
// Find the maximum logit value among valid tokens
|
||||||
|
for tokenID := range node.MaskTokenIDToNode {
|
||||||
|
if int(tokenID) < len(logits) && logits[tokenID] > maxLogit {
|
||||||
|
maxLogit = logits[tokenID]
|
||||||
|
maxIndex = int(tokenID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if maxIndex == -1 {
|
||||||
|
return nil, fmt.Errorf("no valid tokens found in mask")
|
||||||
|
}
|
||||||
|
|
||||||
|
logits[0] = float32(maxIndex)
|
||||||
|
return logits, nil
|
||||||
|
// return maxIndex, nil
|
||||||
|
}
|
@ -23,6 +23,8 @@ type Sampler struct {
|
|||||||
minP float32
|
minP float32
|
||||||
temperature float32
|
temperature float32
|
||||||
grammar *Grammar
|
grammar *Grammar
|
||||||
|
JSONSampler *JSONSampler
|
||||||
|
PythonSampler *PythonSampler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||||
@ -30,6 +32,19 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
|
|||||||
return -1, errors.New("sample: no logits provided to sample")
|
return -1, errors.New("sample: no logits provided to sample")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if s.JSONSampler != nil {
|
||||||
|
logits, err = s.JSONSampler.Apply(logits)
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.PythonSampler != nil {
|
||||||
|
logits, err = s.PythonSampler.ApplyMask(logits)
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
}
|
||||||
tokens := make([]token, len(logits))
|
tokens := make([]token, len(logits))
|
||||||
for i := range logits {
|
for i := range logits {
|
||||||
tokens[i].id = int32(i)
|
tokens[i].id = int32(i)
|
||||||
@ -127,7 +142,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
|
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar, jsonSampler *JSONSampler, pythonSampler *PythonSampler) Sampler {
|
||||||
var rng *rand.Rand
|
var rng *rand.Rand
|
||||||
if seed != -1 {
|
if seed != -1 {
|
||||||
// PCG requires two parameters: sequence and stream
|
// PCG requires two parameters: sequence and stream
|
||||||
@ -161,6 +176,8 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
|
|||||||
minP: minP,
|
minP: minP,
|
||||||
temperature: temperature,
|
temperature: temperature,
|
||||||
grammar: grammar,
|
grammar: grammar,
|
||||||
|
JSONSampler: jsonSampler,
|
||||||
|
PythonSampler: pythonSampler,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
299
sample/structured_outputs.go
Normal file
299
sample/structured_outputs.go
Normal file
@ -0,0 +1,299 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"runtime"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/grammar/jsonschema"
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
type JSONSampler struct {
|
||||||
|
schema *jsonschema.Schema
|
||||||
|
propIdx int
|
||||||
|
propToNodeMap map[string]*PDA
|
||||||
|
pdaSampler *PushdownSampler
|
||||||
|
decodedToks []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewJSONSampler(proc model.TextProcessor, schema *jsonschema.Schema) (*JSONSampler, error) {
|
||||||
|
slog.Info("NewJSONSampler", "schema", schema)
|
||||||
|
if proc == nil {
|
||||||
|
return nil, fmt.Errorf("TextProcessor cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
pdaSampler, err := NewPushdownSampler(proc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create PushdownSampler: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if schema == nil {
|
||||||
|
return &JSONSampler{
|
||||||
|
schema: nil,
|
||||||
|
propIdx: -1,
|
||||||
|
propToNodeMap: nil,
|
||||||
|
pdaSampler: pdaSampler,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fmt.Println("schema not nil")
|
||||||
|
so := &JSONSampler{
|
||||||
|
schema: schema,
|
||||||
|
propIdx: -1,
|
||||||
|
propToNodeMap: make(map[string]*PDA),
|
||||||
|
pdaSampler: pdaSampler,
|
||||||
|
}
|
||||||
|
|
||||||
|
so.schemaToGraph()
|
||||||
|
|
||||||
|
// Benchmark token decoding
|
||||||
|
start := time.Now()
|
||||||
|
var m runtime.MemStats
|
||||||
|
runtime.ReadMemStats(&m)
|
||||||
|
before := m.Alloc
|
||||||
|
|
||||||
|
vocab := proc.Vocab()
|
||||||
|
decodedToks := make([]string, len(vocab.Values))
|
||||||
|
for i := range vocab.Values {
|
||||||
|
token, err := proc.Decode([]int32{int32(i)})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
decodedToks[i] = token
|
||||||
|
}
|
||||||
|
so.decodedToks = decodedToks
|
||||||
|
|
||||||
|
runtime.ReadMemStats(&m)
|
||||||
|
after := m.Alloc
|
||||||
|
fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||||
|
fmt.Printf("Token decode time = %v\n", time.Since(start))
|
||||||
|
|
||||||
|
fmt.Println("--------------------------------")
|
||||||
|
fmt.Println("SOSampler")
|
||||||
|
fmt.Println("--------------------------------")
|
||||||
|
// Benchmark this section
|
||||||
|
start = time.Now()
|
||||||
|
runtime.ReadMemStats(&m)
|
||||||
|
before = m.Alloc
|
||||||
|
|
||||||
|
// TODO: still messed up
|
||||||
|
// TODO: recursion use case
|
||||||
|
// key masks
|
||||||
|
for _, prop := range so.schema.Properties {
|
||||||
|
node := so.propToNodeMap[prop.Name]
|
||||||
|
// propName -> node
|
||||||
|
curState := node.State
|
||||||
|
fromNode := node
|
||||||
|
so.pdaSampler.CreateMask(fromNode)
|
||||||
|
for curState == StateInStructuredKey {
|
||||||
|
// there is only one edge
|
||||||
|
for r, toNode := range fromNode.TransitionEdges {
|
||||||
|
fmt.Println("rune", r, "edge", toNode.State)
|
||||||
|
so.pdaSampler.CreateMask(toNode)
|
||||||
|
fmt.Printf("created mask for %c\n", r)
|
||||||
|
curState = toNode.State
|
||||||
|
fmt.Println("next state", curState)
|
||||||
|
// TODO: theres an extra gen for " right now
|
||||||
|
fromNode = toNode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if curState != StateInColon {
|
||||||
|
return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState)
|
||||||
|
}
|
||||||
|
|
||||||
|
// so.pdaSampler.CreateMask(fromNode)
|
||||||
|
|
||||||
|
fromNode = fromNode.TransitionEdges[' ']
|
||||||
|
|
||||||
|
so.pdaSampler.CreateMask(fromNode)
|
||||||
|
curState = fromNode.State
|
||||||
|
for _, toNode := range fromNode.TransitionEdges {
|
||||||
|
fmt.Println("toNode", toNode.State)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// runtime.ReadMemStats(&m)
|
||||||
|
// after = m.Alloc
|
||||||
|
// fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||||
|
// fmt.Printf("Mask creation time = %v\n", time.Since(start))
|
||||||
|
// fmt.Println("--------------------------------")
|
||||||
|
|
||||||
|
return so, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *JSONSampler) schemaToGraph() {
|
||||||
|
schemaType := s.schema.EffectiveType()
|
||||||
|
switch schemaType {
|
||||||
|
case "object":
|
||||||
|
// TODO: see if we need to connect these to the JSON graph
|
||||||
|
|
||||||
|
// each prop is a key
|
||||||
|
for _, prop := range s.schema.Properties {
|
||||||
|
// name of key
|
||||||
|
name := prop.Name
|
||||||
|
keyNode := &PDA{
|
||||||
|
State: StateInStructuredKey, // this is unchanging, will impact sampling
|
||||||
|
TransitionEdges: make(map[rune]*PDA),
|
||||||
|
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||||
|
}
|
||||||
|
|
||||||
|
prevNode := keyNode
|
||||||
|
for _, r := range name {
|
||||||
|
runeNode := &PDA{
|
||||||
|
State: StateInStructuredKey, // this is unchanging, will impact sampling
|
||||||
|
TransitionEdges: make(map[rune]*PDA),
|
||||||
|
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||||
|
}
|
||||||
|
// fmt.Println("runeNode created", runeNode.State)
|
||||||
|
// fmt.Printf("runeNode created %c\n", r)
|
||||||
|
|
||||||
|
// since alloc on heap connections wil still map
|
||||||
|
prevNode.TransitionEdges[r] = runeNode
|
||||||
|
prevNode = runeNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// point to end of object key node after all chars are done
|
||||||
|
// prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
|
||||||
|
|
||||||
|
// link to value node
|
||||||
|
// Create a node for the end of the key (after the closing quote)
|
||||||
|
stringEndNode := &PDA{
|
||||||
|
State: StateInStructuredKey,
|
||||||
|
TransitionEdges: make(map[rune]*PDA),
|
||||||
|
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||||
|
}
|
||||||
|
prevNode.TransitionEdges['"'] = stringEndNode
|
||||||
|
prevNode = stringEndNode
|
||||||
|
|
||||||
|
// Add transition for colon after key
|
||||||
|
colonNode := &PDA{
|
||||||
|
State: StateInColon,
|
||||||
|
TransitionEdges: make(map[rune]*PDA),
|
||||||
|
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||||
|
}
|
||||||
|
prevNode.TransitionEdges[':'] = colonNode
|
||||||
|
prevNode = colonNode
|
||||||
|
|
||||||
|
// Add transition for space after colon
|
||||||
|
spaceNode := &PDA{
|
||||||
|
State: StateInSpaceToValue,
|
||||||
|
TransitionEdges: make(map[rune]*PDA),
|
||||||
|
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||||
|
}
|
||||||
|
prevNode.TransitionEdges[' '] = spaceNode
|
||||||
|
prevNode = spaceNode
|
||||||
|
|
||||||
|
value := prop.Type
|
||||||
|
switch value {
|
||||||
|
case "object":
|
||||||
|
fmt.Println("object under key: ", name)
|
||||||
|
case "array":
|
||||||
|
fmt.Println("array under key: ", name)
|
||||||
|
case "string":
|
||||||
|
fmt.Println("string under key: ", name)
|
||||||
|
prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString]
|
||||||
|
case "number":
|
||||||
|
fmt.Println("number under key: ", name)
|
||||||
|
for _, r := range validNumberRunes {
|
||||||
|
prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber]
|
||||||
|
}
|
||||||
|
case "boolean":
|
||||||
|
fmt.Println("boolean under key: ", name)
|
||||||
|
prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool]
|
||||||
|
prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool]
|
||||||
|
prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull]
|
||||||
|
}
|
||||||
|
|
||||||
|
// points to start of the key
|
||||||
|
s.propToNodeMap[name] = keyNode
|
||||||
|
fmt.Println("name", name, "keyNode", keyNode.State)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO: do values + recursion
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *JSONSampler) Apply(logits []float32) ([]float32, error) {
|
||||||
|
if s.schema == nil {
|
||||||
|
return s.pdaSampler.Apply(logits)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch s.pdaSampler.curNode.State {
|
||||||
|
// TODO: doesnt account for multi rune case
|
||||||
|
case StateInObjectKey:
|
||||||
|
if s.propIdx > len(s.schema.Properties)-1 {
|
||||||
|
return nil, fmt.Errorf("propIdx out of bounds")
|
||||||
|
}
|
||||||
|
// fmt.Println("in object key - structured outputs")
|
||||||
|
// TODO: this tracking should probably be coming from a stack to track nested objects
|
||||||
|
// simple case
|
||||||
|
s.propIdx++
|
||||||
|
fmt.Println("propIdx", s.propIdx)
|
||||||
|
prop := s.schema.Properties[s.propIdx]
|
||||||
|
fmt.Println("prop", prop.Name)
|
||||||
|
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
|
||||||
|
fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
|
||||||
|
logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
|
||||||
|
// Will only happen for the last prop - can also be precomputed.
|
||||||
|
if s.propIdx == len(s.schema.Properties)-1 {
|
||||||
|
// todo: if i incremenet propidx then i know im in last value as well
|
||||||
|
switch s.pdaSampler.curNode.State {
|
||||||
|
case StateInObjectEnd:
|
||||||
|
fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
|
||||||
|
s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
|
||||||
|
s.pdaSampler.curNode = NewPDANode(StateTerminate)
|
||||||
|
s.propIdx++
|
||||||
|
|
||||||
|
// TODO: this needs to be optimized in some way, computing mask on the fly is expensive
|
||||||
|
case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
|
||||||
|
fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
|
||||||
|
delete(s.pdaSampler.curNode.TransitionEdges, ',')
|
||||||
|
s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
|
||||||
|
|
||||||
|
s.pdaSampler.CreateMask(s.pdaSampler.curNode)
|
||||||
|
s.propIdx++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.pdaSampler.Apply(logits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
|
||||||
|
tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.schema == nil {
|
||||||
|
// Don't need to update state for unconstrained JSON sampling
|
||||||
|
return tokenSlice, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch s.pdaSampler.curNode.State {
|
||||||
|
case StateInObjectKey:
|
||||||
|
s.propIdx++
|
||||||
|
fmt.Println("propIdx", s.propIdx)
|
||||||
|
prop := s.schema.Properties[s.propIdx]
|
||||||
|
fmt.Println("prop", prop.Name)
|
||||||
|
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
|
||||||
|
// TODO: this does not work - mike
|
||||||
|
// str, err := s.pdaSampler.proc.Decode(tokenSlice)
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
// fmt.Println("str", str)
|
||||||
|
|
||||||
|
return tokenSlice, nil
|
||||||
|
default:
|
||||||
|
return tokenSlice, nil
|
||||||
|
}
|
||||||
|
}
|
339
sample/structured_python.go
Normal file
339
sample/structured_python.go
Normal file
@ -0,0 +1,339 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PythonState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
PythonStateStart PythonState = iota
|
||||||
|
StateInFunction
|
||||||
|
StateInFunctionArgs
|
||||||
|
StateInFunctionArgsType
|
||||||
|
StateInFunctionEnd
|
||||||
|
PStateInString
|
||||||
|
PStateInStringEnd
|
||||||
|
PStateInNumber
|
||||||
|
PStateInList
|
||||||
|
PStateInListEnd
|
||||||
|
PStateInDict
|
||||||
|
PStateInDictEnd
|
||||||
|
PStateInTuple
|
||||||
|
PStateInTupleEnd
|
||||||
|
PStateTerminate
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s PythonState) String() string {
|
||||||
|
switch s {
|
||||||
|
case PythonStateStart:
|
||||||
|
return "PythonStateStart"
|
||||||
|
case StateInFunction:
|
||||||
|
return "StateInFunction"
|
||||||
|
case StateInFunctionArgs:
|
||||||
|
return "StateInFunctionArgs"
|
||||||
|
case StateInFunctionArgsType:
|
||||||
|
return "StateInFunctionArgsType"
|
||||||
|
case StateInFunctionEnd:
|
||||||
|
return "StateInFunctionEnd"
|
||||||
|
case PStateInString:
|
||||||
|
return "PStateInString"
|
||||||
|
case PStateInStringEnd:
|
||||||
|
return "PStateInStringEnd"
|
||||||
|
case PStateInNumber:
|
||||||
|
return "PStateInNumber"
|
||||||
|
case PStateInList:
|
||||||
|
return "PStateInList"
|
||||||
|
case PStateInListEnd:
|
||||||
|
return "PStateInListEnd"
|
||||||
|
case PStateInDict:
|
||||||
|
return "PStateInDict"
|
||||||
|
case PStateInDictEnd:
|
||||||
|
return "PStateInDictEnd"
|
||||||
|
case PStateInTuple:
|
||||||
|
return "PStateInTuple"
|
||||||
|
case PStateInTupleEnd:
|
||||||
|
return "PStateInTupleEnd"
|
||||||
|
case PStateTerminate:
|
||||||
|
return "PStateTerminate"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("PythonState(%d)", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var PythonStates = []PythonState{
|
||||||
|
PythonStateStart,
|
||||||
|
StateInFunction,
|
||||||
|
StateInFunctionArgs,
|
||||||
|
StateInFunctionArgsType,
|
||||||
|
StateInFunctionEnd,
|
||||||
|
PStateInString,
|
||||||
|
PStateInStringEnd,
|
||||||
|
PStateInNumber,
|
||||||
|
PStateInList,
|
||||||
|
PStateInListEnd,
|
||||||
|
PStateInDict,
|
||||||
|
PStateInDictEnd,
|
||||||
|
PStateInTuple,
|
||||||
|
PStateInTupleEnd,
|
||||||
|
PStateTerminate,
|
||||||
|
}
|
||||||
|
|
||||||
|
type Node struct {
|
||||||
|
State PythonState
|
||||||
|
TransitionEdges map[rune]*Node
|
||||||
|
MaskTokenIDToNode map[int32]*Node
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNode(state PythonState) *Node {
|
||||||
|
return &Node{
|
||||||
|
State: state,
|
||||||
|
TransitionEdges: make(map[rune]*Node),
|
||||||
|
MaskTokenIDToNode: make(map[int32]*Node),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type PythonFunction struct {
|
||||||
|
Name string
|
||||||
|
Args []string
|
||||||
|
Types []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type PythonSampler struct {
|
||||||
|
stateToNodes map[PythonState]*Node
|
||||||
|
proc model.TextProcessor
|
||||||
|
decodedToks []string
|
||||||
|
curNode *Node
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
|
||||||
|
s.proc = proc
|
||||||
|
decodedToks := make([]string, len(proc.Vocab().Values))
|
||||||
|
for i := range proc.Vocab().Values {
|
||||||
|
token, err := proc.Decode([]int32{int32(i)})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
decodedToks[i] = token
|
||||||
|
}
|
||||||
|
s.decodedToks = decodedToks
|
||||||
|
s.BuildGraph()
|
||||||
|
for _, function := range functions {
|
||||||
|
prevNode := s.stateToNodes[PythonStateStart]
|
||||||
|
|
||||||
|
for _, r := range function.Name {
|
||||||
|
nextNode := NewNode(StateInFunction)
|
||||||
|
prevNode.TransitionEdges[r] = nextNode
|
||||||
|
if err := s.CreateMask(nextNode); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Println("prevNode", prevNode.State)
|
||||||
|
fmt.Printf("transition edge: %q\n", r)
|
||||||
|
fmt.Println("nextNode", nextNode.State)
|
||||||
|
prevNode = nextNode
|
||||||
|
}
|
||||||
|
prevNode.TransitionEdges['('] = s.stateToNodes[StateInFunctionArgs]
|
||||||
|
s.CreateMask(prevNode)
|
||||||
|
prevNode = s.stateToNodes[StateInFunctionArgs]
|
||||||
|
for i, arg := range function.Args {
|
||||||
|
for _, r := range arg {
|
||||||
|
nextNode := NewNode(StateInFunctionArgs)
|
||||||
|
prevNode.TransitionEdges[r] = nextNode
|
||||||
|
s.CreateMask(prevNode)
|
||||||
|
prevNode = nextNode
|
||||||
|
}
|
||||||
|
prevNode.TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
|
||||||
|
// prevNode = s.stateToNodes[StateInFunctionArgs]
|
||||||
|
prevNode.TransitionEdges['='] = NewNode(StateInFunctionArgsType)
|
||||||
|
s.CreateMask(prevNode)
|
||||||
|
prevNode = prevNode.TransitionEdges['=']
|
||||||
|
switch function.Types[i] {
|
||||||
|
case "string":
|
||||||
|
prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInString]
|
||||||
|
s.CreateMask(prevNode.TransitionEdges['"'])
|
||||||
|
case "number":
|
||||||
|
prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInNumber]
|
||||||
|
s.CreateMask(prevNode.TransitionEdges['"'])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
s.curNode = s.stateToNodes[PythonStateStart]
|
||||||
|
fmt.Println("curNode", s.curNode.State)
|
||||||
|
fmt.Println("transition edges", s.curNode.TransitionEdges)
|
||||||
|
if err := s.CreateMask(s.curNode); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Println("maskTokenIDToNode", s.curNode.MaskTokenIDToNode)
|
||||||
|
for tokenID, node := range s.curNode.MaskTokenIDToNode {
|
||||||
|
fmt.Printf("tokenID: %d, node: %v\n", s.decodedToks[tokenID], node.State)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PythonSampler) BuildGraph() error {
|
||||||
|
s.stateToNodes = make(map[PythonState]*Node)
|
||||||
|
for _, state := range PythonStates {
|
||||||
|
s.stateToNodes[state] = NewNode(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, state := range s.stateToNodes {
|
||||||
|
if err := s.CreateMask(state); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// String
|
||||||
|
s.stateToNodes[PStateInString].TransitionEdges[rune(-1)] = s.stateToNodes[PStateInString]
|
||||||
|
s.stateToNodes[PStateInString].TransitionEdges['"'] = s.stateToNodes[PStateInStringEnd]
|
||||||
|
|
||||||
|
// String end
|
||||||
|
s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
|
||||||
|
s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
||||||
|
// Number
|
||||||
|
for _, r := range validNumberRunes {
|
||||||
|
s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
|
||||||
|
}
|
||||||
|
s.stateToNodes[PStateInNumber].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
||||||
|
s.stateToNodes[PStateInNumber].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
|
||||||
|
s.stateToNodes[PStateInNumber].TransitionEdges[' '] = s.stateToNodes[StateInFunctionArgs]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PythonSampler) ApplyMask(logits []float32) ([]float32, error) {
|
||||||
|
if s.curNode.State == PStateTerminate {
|
||||||
|
logits, err := finish(s, logits)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
}
|
||||||
|
logits, err := s.maskLogits(logits, s.curNode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PythonSampler) UpdateState(token int32) error {
|
||||||
|
mappedString, err := s.proc.Decode([]int32{token})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Printf(">>> mappedString: %q\n", mappedString)
|
||||||
|
|
||||||
|
if s.curNode.State == PStateTerminate {
|
||||||
|
if s.proc.Is(token, model.SpecialEOS) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nextNode, ok := s.curNode.MaskTokenIDToNode[token]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("invalid token: %q", mappedString)
|
||||||
|
}
|
||||||
|
s.curNode = nextNode
|
||||||
|
fmt.Println("curNode", s.curNode.State)
|
||||||
|
for r, node := range s.curNode.TransitionEdges {
|
||||||
|
fmt.Printf("transition edge: %q -> %v\n", r, node.State)
|
||||||
|
}
|
||||||
|
if err := s.CreateMask(s.curNode); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PythonSampler) CreateMask(node *Node) error {
|
||||||
|
if node == nil {
|
||||||
|
return fmt.Errorf("node cannot be nil")
|
||||||
|
}
|
||||||
|
for i := range s.decodedToks {
|
||||||
|
token := s.decodedToks[i]
|
||||||
|
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||||
|
if s.proc.Is(int32(i), model.SpecialEOS) || s.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
curNode := node
|
||||||
|
valid := true
|
||||||
|
consumedSpecialRunes := make(map[rune]bool)
|
||||||
|
for _, r := range token {
|
||||||
|
curNode, valid = isRValid(r, curNode, consumedSpecialRunes)
|
||||||
|
if curNode == nil || !valid {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if valid {
|
||||||
|
if curNode.State == StateInFunction {
|
||||||
|
// fmt.Println("cm curNode", curNode.State)
|
||||||
|
// fmt.Println("cm token", s.decodedToks[i])
|
||||||
|
}
|
||||||
|
node.MaskTokenIDToNode[int32(i)] = curNode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRValid(r rune, curNode *Node, consumedSpecialRunes map[rune]bool) (*Node, bool) {
|
||||||
|
if consumedSpecialRunes[r] {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
specialRune := slices.Contains(stringInvalidRunes, r)
|
||||||
|
if specialRune {
|
||||||
|
if curNode.State == PStateInString || curNode.State == PStateInStringEnd {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for specific rune transition
|
||||||
|
if nextNode, ok := curNode.TransitionEdges[r]; ok {
|
||||||
|
// fmt.Println("next node", nextNode)
|
||||||
|
if specialRune {
|
||||||
|
if curNode.State == nextNode.State {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
consumedSpecialRunes[r] = true
|
||||||
|
}
|
||||||
|
return nextNode, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for sentinel value - if present, any rune is valid
|
||||||
|
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
|
||||||
|
return nextNode, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PythonSampler) maskLogits(logits []float32, node *Node) ([]float32, error) {
|
||||||
|
// Create a new slice with same length as logits, initialized to -Inf
|
||||||
|
maskedLogits := make([]float32, len(logits))
|
||||||
|
for i := range maskedLogits {
|
||||||
|
maskedLogits[i] = float32(math.Inf(-1))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only update values for valid token IDs from the mask map
|
||||||
|
for tokenID := range node.MaskTokenIDToNode {
|
||||||
|
if int(tokenID) < len(logits) {
|
||||||
|
maskedLogits[tokenID] = logits[tokenID]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return maskedLogits, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func finish(s *PythonSampler, logits []float32) ([]float32, error) {
|
||||||
|
for i := range logits {
|
||||||
|
if s.proc.Is(int32(i), model.SpecialEOS) {
|
||||||
|
logits[i] = 1.0
|
||||||
|
} else {
|
||||||
|
logits[i] = float32(math.Inf(-1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user