Files
ollama/sample/structured_outputs_test.go
2025-03-27 11:26:49 -07:00

195 lines
3.7 KiB
Go

package sample
import (
"testing"
"github.com/ollama/ollama/model"
)
func TestBuildGraph(t *testing.T) {
tests := []struct {
name string
grammar []byte
wantErr bool
}{
{
name: "empty grammar",
grammar: []byte{},
wantErr: false,
},
{
name: "valid grammar",
grammar: []byte(`root ::= value
value ::= string | number`),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := &Graph{
proc: &mockProcessor{},
grammar: tt.grammar,
rules: make(map[string]string),
}
node := &Node{
TransitionEdges: make(map[rune]*Node),
}
err := g.BuildGraph(node)
if (err != nil) != tt.wantErr {
t.Errorf("BuildGraph() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr {
if len(g.decodedToks) == 0 {
t.Error("Expected decoded tokens, got none")
}
if len(g.rules) == 0 {
t.Error("Expected rules to be populated")
}
}
})
}
}
func TestRootPrefixes(t *testing.T) {
tests := []struct {
name string
grammar []byte
expected map[string]string
}{
{
name: "empty grammar",
grammar: []byte{},
expected: map[string]string{},
},
{
name: "grammar with root prefix",
grammar: []byte(`root ::= value
root_string ::= string`),
expected: map[string]string{
"root": "value",
"root_string": "string",
},
},
{
name: "grammar with comments and empty lines",
grammar: []byte(`# comment
root ::= value
# another comment
root_number ::= number`),
expected: map[string]string{
"root": "value",
"root_number": "number",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := &Graph{
grammar: tt.grammar,
rules: make(map[string]string),
}
g.rootPrefixes()
for k, v := range tt.expected {
if actual, ok := g.rules[k]; !ok || actual != v {
t.Errorf("Expected rule %s = %s, got %s", k, v, actual)
}
}
})
}
}
func TestParseRule(t *testing.T) {
tests := []struct {
name string
rule string
expected string
}{
{
name: "empty rule",
rule: "",
expected: "",
},
{
name: "simple string",
rule: "root ::= \"test_string\"",
expected: "test_string",
},
{
name: "simple string",
rule: "root ::= \"test_string\" | \"test_string2\"",
expected: "test_stringtest_string2",
},
{
name: "integer",
rule: "root ::= [0-9]+",
// TODO: this is infinite acutally
expected: "0123456789",
},
// TODO: handle left recursion
// {
// name: "left recursion",
// rule: "root ::= root \"test_string\"",
// expected: "test_string",
// },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := &Graph{
rules: make(map[string]string),
}
rootNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
curNode := rootNode
g.parseRule(tt.rule, curNode)
sb := ""
for {
if len(curNode.TransitionEdges) == 0 {
break
}
for r, n := range curNode.TransitionEdges {
sb += string(r)
curNode = n
}
t.Logf("sb: %s", sb)
}
if sb != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, sb)
}
})
}
}
// mockProcessor implements the TextProcessor interface for testing
type mockProcessor struct{}
func (m *mockProcessor) Decode(tokens []int32) (string, error) {
return "test", nil
}
func (m *mockProcessor) Vocab() *model.Vocabulary {
return &model.Vocabulary{
Values: []string{"test1", "test2"},
}
}
func (m *mockProcessor) Encode(s string, addSpecial bool) ([]int32, error) {
return []int32{0, 1}, nil
}
func (m *mockProcessor) Is(token int32, special model.Special) bool {
return false
}