sample: wip structured outputs work
This commit is contained in:
parent
040e65abce
commit
e18540fecc
176
sample/state_machine.go
Normal file
176
sample/state_machine.go
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Node struct {
|
||||||
|
TransitionEdges map[rune]*Node
|
||||||
|
}
|
||||||
|
|
||||||
|
type Graph struct {
|
||||||
|
proc model.TextProcessor
|
||||||
|
decodedToks []string
|
||||||
|
curNode *Node
|
||||||
|
grammar []byte
|
||||||
|
rules map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// baseRules is the set of rules that are used to parse the grammar
|
||||||
|
// JSON grammar from RFC 7159
|
||||||
|
var baseRules = map[string]string{
|
||||||
|
"object": "\"{\" (kv (\",\" kv)*)? \"}\"",
|
||||||
|
"array": "\"[\" (value (\",\" value)*)? \"]\"",
|
||||||
|
"string": "\"\\\"\" char* \"\\\"\"",
|
||||||
|
"number": "\"-\"? integer frac? exp?",
|
||||||
|
"kv": "string \":\" value",
|
||||||
|
"integer": "\"0\" | [1-9] [0-9]*",
|
||||||
|
"frac": "\".\" [0-9]+",
|
||||||
|
"exp": "(\"e\" | \"E\") (\"+\" | \"-\") [0-9]+",
|
||||||
|
"escape": "[\"/\" | \"b\" | \"f\" | \"n\" | \"r\" | \"t\" | unicode]",
|
||||||
|
"char": "[^\"\\\\] | escape",
|
||||||
|
"space": "(\" \" | \"\\t\" | \"\\n\" | \"\\r\")*",
|
||||||
|
"hex": "[0-9] | [a-f] | [A-F]",
|
||||||
|
"boolean": "\"true\" | \"false\"",
|
||||||
|
"value": "object | array | string | number | boolean | \"null\"",
|
||||||
|
"null": "\"null\"",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Graph) BuildGraph(node *Node) error {
|
||||||
|
vocab := g.proc.Vocab()
|
||||||
|
decodedToks := make([]string, len(vocab.Values))
|
||||||
|
for i := range vocab.Values {
|
||||||
|
token, err := g.proc.Decode([]int32{int32(i)})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
decodedToks[i] = token
|
||||||
|
}
|
||||||
|
|
||||||
|
g.decodedToks = decodedToks
|
||||||
|
g.rules = baseRules
|
||||||
|
g.rootPrefixes()
|
||||||
|
rootNode := &Node{
|
||||||
|
TransitionEdges: make(map[rune]*Node),
|
||||||
|
}
|
||||||
|
g.parseRule(g.rules["root"], rootNode)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// rootPrefixes extracts all root prefixes from the grammar
|
||||||
|
// and parses the grammar string to extract root prefixes
|
||||||
|
func (g *Graph) rootPrefixes() {
|
||||||
|
lines := bytes.Split(g.grammar, []byte("\n"))
|
||||||
|
for _, line := range lines {
|
||||||
|
line = bytes.TrimSpace(line)
|
||||||
|
if len(line) == 0 || bytes.HasPrefix(line, []byte("#")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := bytes.SplitN(line, []byte("::="), 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleName := string(bytes.TrimSpace(parts[0]))
|
||||||
|
if strings.HasPrefix(ruleName, "root") {
|
||||||
|
g.rules[ruleName] = string(bytes.TrimSpace(parts[1]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseRule parses a grammar rule and returns a Node
|
||||||
|
func (g *Graph) parseRule(rule string, curNode *Node) *Node {
|
||||||
|
/*
|
||||||
|
Here are the special characters in BNF grammar and their functions:
|
||||||
|
::= - Definition operator, means "is defined as"
|
||||||
|
| - Alternation, means "or"
|
||||||
|
* - Zero or more repetitions of preceding element
|
||||||
|
+ - One or more repetitions
|
||||||
|
? - Optional (zero or one occurrence)
|
||||||
|
[] - Character class, matches any single character within brackets
|
||||||
|
[^] - Negated character class, matches any character NOT listed
|
||||||
|
() - Grouping of elements
|
||||||
|
- - Range operator in character classes (e.g., [a-z])
|
||||||
|
"" - Literal string match
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Split rule into tokens by whitespace
|
||||||
|
tokens := strings.Fields(rule)
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
return &Node{
|
||||||
|
TransitionEdges: make(map[rune]*Node),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle integer rule
|
||||||
|
if strings.Contains(rule, "[0-9]+") {
|
||||||
|
// Create node for first digit 1-9
|
||||||
|
firstDigitNode := &Node{
|
||||||
|
TransitionEdges: make(map[rune]*Node),
|
||||||
|
}
|
||||||
|
for r := '1'; r <= '9'; r++ {
|
||||||
|
curNode.TransitionEdges[r] = firstDigitNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create node for subsequent digits 0-9
|
||||||
|
zeroToNineNode := &Node{
|
||||||
|
TransitionEdges: make(map[rune]*Node),
|
||||||
|
}
|
||||||
|
for r := '0'; r <= '9'; r++ {
|
||||||
|
// Loop back to same node for * operator
|
||||||
|
zeroToNineNode.TransitionEdges[r] = zeroToNineNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect first digit to subsequent digits
|
||||||
|
firstDigitNode.TransitionEdges = zeroToNineNode.TransitionEdges
|
||||||
|
|
||||||
|
// Also handle the "0" case
|
||||||
|
if strings.Contains(rule, "\"0\"") {
|
||||||
|
zeroNode := &Node{
|
||||||
|
TransitionEdges: make(map[rune]*Node),
|
||||||
|
}
|
||||||
|
curNode.TransitionEdges['0'] = zeroNode
|
||||||
|
}
|
||||||
|
|
||||||
|
return curNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// recursive case
|
||||||
|
// grammar options
|
||||||
|
// TODO: handle left recursion
|
||||||
|
if strings.Contains(rule, "|") {
|
||||||
|
parts := strings.Split(rule, "|")
|
||||||
|
savedNode := curNode
|
||||||
|
for _, part := range parts {
|
||||||
|
// TODO: add correct transitions
|
||||||
|
g.parseRule(part, savedNode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, token := range tokens {
|
||||||
|
if strings.HasPrefix(token, "\"") && strings.HasSuffix(token, "\"") {
|
||||||
|
token = strings.Trim(token, "\"")
|
||||||
|
|
||||||
|
for _, r := range token {
|
||||||
|
newNode := &Node{
|
||||||
|
TransitionEdges: make(map[rune]*Node),
|
||||||
|
}
|
||||||
|
curNode.TransitionEdges[r] = newNode
|
||||||
|
curNode = newNode
|
||||||
|
}
|
||||||
|
// strNode := &Node{
|
||||||
|
// TransitionEdges: make(map[rune]*Node),
|
||||||
|
// }
|
||||||
|
|
||||||
|
// TODO: length constraint
|
||||||
|
// to self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return curNode
|
||||||
|
}
|
3
sample/structured_outputs.go
Normal file
3
sample/structured_outputs.go
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
type StructuredOutput struct{}
|
194
sample/structured_outputs_test.go
Normal file
194
sample/structured_outputs_test.go
Normal file
@ -0,0 +1,194 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildGraph(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
grammar []byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty grammar",
|
||||||
|
grammar: []byte{},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid grammar",
|
||||||
|
grammar: []byte(`root ::= value
|
||||||
|
value ::= string | number`),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
g := &Graph{
|
||||||
|
proc: &mockProcessor{},
|
||||||
|
grammar: tt.grammar,
|
||||||
|
rules: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
node := &Node{
|
||||||
|
TransitionEdges: make(map[rune]*Node),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := g.BuildGraph(node)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("BuildGraph() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.wantErr {
|
||||||
|
if len(g.decodedToks) == 0 {
|
||||||
|
t.Error("Expected decoded tokens, got none")
|
||||||
|
}
|
||||||
|
if len(g.rules) == 0 {
|
||||||
|
t.Error("Expected rules to be populated")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRootPrefixes(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
grammar []byte
|
||||||
|
expected map[string]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty grammar",
|
||||||
|
grammar: []byte{},
|
||||||
|
expected: map[string]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "grammar with root prefix",
|
||||||
|
grammar: []byte(`root ::= value
|
||||||
|
root_string ::= string`),
|
||||||
|
expected: map[string]string{
|
||||||
|
"root": "value",
|
||||||
|
"root_string": "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "grammar with comments and empty lines",
|
||||||
|
grammar: []byte(`# comment
|
||||||
|
root ::= value
|
||||||
|
|
||||||
|
# another comment
|
||||||
|
root_number ::= number`),
|
||||||
|
expected: map[string]string{
|
||||||
|
"root": "value",
|
||||||
|
"root_number": "number",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
g := &Graph{
|
||||||
|
grammar: tt.grammar,
|
||||||
|
rules: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
g.rootPrefixes()
|
||||||
|
|
||||||
|
for k, v := range tt.expected {
|
||||||
|
if actual, ok := g.rules[k]; !ok || actual != v {
|
||||||
|
t.Errorf("Expected rule %s = %s, got %s", k, v, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseRule(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty rule",
|
||||||
|
rule: "",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple string",
|
||||||
|
rule: "root ::= \"test_string\"",
|
||||||
|
expected: "test_string",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple string",
|
||||||
|
rule: "root ::= \"test_string\" | \"test_string2\"",
|
||||||
|
expected: "test_stringtest_string2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "integer",
|
||||||
|
rule: "root ::= [0-9]+",
|
||||||
|
// TODO: this is infinite acutally
|
||||||
|
expected: "0123456789",
|
||||||
|
},
|
||||||
|
// TODO: handle left recursion
|
||||||
|
// {
|
||||||
|
// name: "left recursion",
|
||||||
|
// rule: "root ::= root \"test_string\"",
|
||||||
|
// expected: "test_string",
|
||||||
|
// },
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
g := &Graph{
|
||||||
|
rules: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
rootNode := &Node{
|
||||||
|
TransitionEdges: make(map[rune]*Node),
|
||||||
|
}
|
||||||
|
curNode := rootNode
|
||||||
|
g.parseRule(tt.rule, curNode)
|
||||||
|
sb := ""
|
||||||
|
for {
|
||||||
|
if len(curNode.TransitionEdges) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
for r, n := range curNode.TransitionEdges {
|
||||||
|
sb += string(r)
|
||||||
|
curNode = n
|
||||||
|
}
|
||||||
|
t.Logf("sb: %s", sb)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sb != tt.expected {
|
||||||
|
t.Errorf("Expected %s, got %s", tt.expected, sb)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockProcessor implements the TextProcessor interface for testing
|
||||||
|
type mockProcessor struct{}
|
||||||
|
|
||||||
|
func (m *mockProcessor) Decode(tokens []int32) (string, error) {
|
||||||
|
return "test", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProcessor) Vocab() *model.Vocabulary {
|
||||||
|
return &model.Vocabulary{
|
||||||
|
Values: []string{"test1", "test2"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProcessor) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||||
|
return []int32{0, 1}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProcessor) Is(token int32, special model.Special) bool {
|
||||||
|
return false
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user