Compare commits

...

3 Commits

Author SHA1 Message Date
ParthSareen
4450f871db wip 2025-03-25 16:45:27 -07:00
ParthSareen
5ec6bb52a0 prototyping 2025-03-25 15:00:14 -07:00
Blake Mizerany
1fd9967558 grammar: introduce new grammar package
This package provides a way to convert JSON schemas to equivalent EBNF.
It is intended to be a replacement to llama.cpp's schema_to_grammar.

This is still an early version and does not yet support all JSON schema
features. The to-do list includes:

- minumum/maximum constraints on integer types
- minLength/maxLength constraints on string types
- defs and refs
2025-03-24 11:58:06 -07:00
32 changed files with 2974 additions and 17 deletions

22
grammar/bench_test.go Normal file
View File

@ -0,0 +1,22 @@
//go:build go1.24
package grammar
import "testing"
func BenchmarkFromSchema(b *testing.B) {
for tt := range testCases(b) {
b.Run("", func(b *testing.B) {
s := []byte(tt.schema)
b.ReportAllocs()
for b.Loop() {
_, err := FromSchema(nil, s)
if err != nil {
b.Fatalf("GrammarFromSchema: %v", err)
}
}
})
return
}
}

227
grammar/grammar.go Normal file
View File

@ -0,0 +1,227 @@
package grammar
import (
"bytes"
"encoding/json"
"fmt"
"iter"
"strconv"
"github.com/ollama/ollama/grammar/jsonschema"
)
const jsonTerms = `
# Unicode
#
# Unicode characters can be specified directly in the grammar, for example
# hiragana ::= [-], or with escapes: 8-bit (\xXX), 16-bit (\uXXXX) or 32-bit
# (\UXXXXXXXX).
unicode ::= \x{hex}{2} | \u{hex}{4} | \U{hex}{8}
# JSON grammar from RFC 7159
null ::= "null"
object ::= "{" (kv ("," kv)*)? "}"
array ::= "[" (value ("," value)*)? "]"
kv ::= string ":" value
integer ::= "0" | [1-9] [0-9]*
number ::= "-"? integer frac? exp?
frac ::= "." [0-9]+
exp ::= ("e" | "E") ("+" | "-") [0-9]+
string ::= "\"" char* "\""
escape ::= ["/" | "b" | "f" | "n" | "r" | "t" | unicode]
char ::= [^"\\] | escape
space ::= (" " | "\t" | "\n" | "\r")*
hex ::= [0-9] | [a-f] | [A-F]
boolean ::= "true" | "false"
value ::= object | array | string | number | boolean | "null"
# User-defined
`
// FromSchema generates a grammar from a JSON schema.
func FromSchema(buf []byte, jsonSchema []byte) ([]byte, error) {
var s *jsonschema.Schema
if err := json.Unmarshal(jsonSchema, &s); err != nil {
return nil, err
}
var g builder
// "root" is the only rule that is guaranteed to exist, so we start
// with its length for padding, and then adjust it as we go.
g.pad = len("root")
for id := range dependencies("root", s) {
g.pad = max(g.pad, len(id))
}
g.b.WriteString(jsonTerms)
ids := make(map[*jsonschema.Schema]string)
for id, s := range dependencies("root", s) {
ids[s] = id
g.define(id)
if err := fromSchema(&g, ids, s); err != nil {
return nil, err
}
}
g.define("root")
if err := fromSchema(&g, ids, s); err != nil {
return nil, err
}
g.define("") // finalize the last rule
return g.b.Bytes(), nil
}
func fromSchema(g *builder, ids map[*jsonschema.Schema]string, s *jsonschema.Schema) error {
switch typ := s.EffectiveType(); typ {
case "array":
if len(s.PrefixItems) == 0 && s.Items == nil {
g.u("array")
} else {
g.q("[")
for i, s := range s.PrefixItems {
if i > 0 {
g.q(",")
}
g.u(ids[s])
}
if s.Items != nil {
g.u("(")
if len(s.PrefixItems) > 0 {
g.q(",")
}
g.u(ids[s.Items])
g.u(")*")
}
g.q("]")
}
case "object":
if len(s.Properties) == 0 {
g.u("object")
} else {
g.q("{")
for i, p := range s.Properties {
name := ids[p]
if i > 0 {
g.q(",")
}
g.q(p.Name)
g.q(":")
g.u(name)
}
g.q("}")
}
case "number":
buildConstrainedNumber(g, s)
case "string":
if len(s.Enum) == 0 {
g.u("string")
} else {
g.u("(")
for i, e := range s.Enum {
if i > 0 {
g.q("|")
}
g.q(string(e))
}
g.u(")")
}
case "boolean", "value", "null", "integer":
g.u(typ)
default:
return fmt.Errorf("%s: unsupported type %q", s.Name, typ)
}
return nil
}
// dependencies returns a sequence of all child dependencies of the schema in
// post-order.
//
// The first value is the id/pointer to the dependency, and the second value
// is the schema.
func dependencies(id string, s *jsonschema.Schema) iter.Seq2[string, *jsonschema.Schema] {
return func(yield func(string, *jsonschema.Schema) bool) {
for i, p := range s.Properties {
id := fmt.Sprintf("%s_%d", id, i)
for did, d := range dependencies(id, p) {
if !yield(did, d) {
return
}
}
if !yield(id, p) {
return
}
}
for i, p := range s.PrefixItems {
id := fmt.Sprintf("tuple_%d", i)
for did, d := range dependencies(id, p) {
id := fmt.Sprintf("%s_%s", id, did)
if !yield(id, d) {
return
}
}
if !yield(id, p) {
return
}
}
if s.Items != nil {
id := fmt.Sprintf("%s_tuple_%d", id, len(s.PrefixItems))
for did, d := range dependencies(id, s.Items) {
if !yield(did, d) {
return
}
}
if !yield(id, s.Items) {
return
}
}
}
}
type builder struct {
b bytes.Buffer
pad int
rules int
items int
}
// define terminates the current rule, if any, and then either starts a new
// rule or does nothing else if the name is empty.
func (b *builder) define(name string) {
if b.rules > 0 {
b.b.WriteString(";\n")
}
if name == "" {
return
}
fmt.Fprintf(&b.b, "% -*s", b.pad, name)
b.b.WriteString(" ::=")
b.rules++
b.items = 0
}
// quote appends a terminal to the current rule.
func (b *builder) q(s string) {
if b.items > 0 {
b.b.WriteString(" ")
}
b.b.WriteString(" ")
b.b.WriteString(strconv.Quote(s))
}
// u appends a non-terminal to the current rule.
func (b *builder) u(s string) {
if b.items > 0 {
b.b.WriteString(" ")
}
b.b.WriteString(" ")
b.b.WriteString(s)
}
func buildConstrainedNumber(b *builder, s *jsonschema.Schema) {
if s.Minimum == 0 && s.Maximum == 0 {
b.u("TODO")
} else {
b.u("number")
}
}

75
grammar/grammar_test.go Normal file
View File

@ -0,0 +1,75 @@
package grammar
import (
"bufio"
"cmp"
"iter"
"strings"
"testing"
_ "embed"
"github.com/ollama/ollama/grammar/internal/diff"
)
func TestFromSchema(t *testing.T) {
for tt := range testCases(t) {
t.Run(tt.name, func(t *testing.T) {
g, err := FromSchema(nil, []byte(tt.schema))
if err != nil {
t.Fatalf("FromSchema: %v", err)
}
got := string(g)
got = strings.TrimPrefix(got, jsonTerms)
if got != tt.want {
t.Logf("schema:\n%s", tt.schema)
t.Fatal(string(diff.Diff("got", []byte(got), "want", []byte(tt.want))))
}
})
}
}
type testCase struct {
name string
schema string
want string
}
//go:embed testdata/schemas.txt
var tests string
func testCases(t testing.TB) iter.Seq[testCase] {
t.Helper()
return func(yield func(testCase) bool) {
t.Helper()
sc := bufio.NewScanner(strings.NewReader(tests))
name := ""
for sc.Scan() {
line := strings.TrimSpace(sc.Text())
if line == "" {
name = ""
continue
}
if line[0] == '#' {
name = cmp.Or(name, strings.TrimSpace(line[1:]))
continue
}
s := sc.Text()
g := ""
for sc.Scan() {
line = strings.TrimSpace(sc.Text())
if line == "" || line[0] == '#' {
break
}
g += sc.Text() + "\n"
}
if !yield(testCase{name, s, g}) {
return
}
name = strings.TrimSpace(strings.TrimPrefix(line, "#"))
}
if err := sc.Err(); err != nil {
t.Fatalf("error reading tests: %v", err)
}
}
}

View File

@ -0,0 +1,261 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package diff
import (
"bytes"
"fmt"
"sort"
"strings"
)
// A pair is a pair of values tracked for both the x and y side of a diff.
// It is typically a pair of line indexes.
type pair struct{ x, y int }
// Diff returns an anchored diff of the two texts old and new
// in the “unified diff” format. If old and new are identical,
// Diff returns a nil slice (no output).
//
// Unix diff implementations typically look for a diff with
// the smallest number of lines inserted and removed,
// which can in the worst case take time quadratic in the
// number of lines in the texts. As a result, many implementations
// either can be made to run for a long time or cut off the search
// after a predetermined amount of work.
//
// In contrast, this implementation looks for a diff with the
// smallest number of “unique” lines inserted and removed,
// where unique means a line that appears just once in both old and new.
// We call this an “anchored diff” because the unique lines anchor
// the chosen matching regions. An anchored diff is usually clearer
// than a standard diff, because the algorithm does not try to
// reuse unrelated blank lines or closing braces.
// The algorithm also guarantees to run in O(n log n) time
// instead of the standard O(n²) time.
//
// Some systems call this approach a “patience diff,” named for
// the “patience sorting” algorithm, itself named for a solitaire card game.
// We avoid that name for two reasons. First, the name has been used
// for a few different variants of the algorithm, so it is imprecise.
// Second, the name is frequently interpreted as meaning that you have
// to wait longer (to be patient) for the diff, meaning that it is a slower algorithm,
// when in fact the algorithm is faster than the standard one.
func Diff(oldName string, old []byte, newName string, new []byte) []byte {
if bytes.Equal(old, new) {
return nil
}
x := lines(old)
y := lines(new)
// Print diff header.
var out bytes.Buffer
fmt.Fprintf(&out, "diff %s %s\n", oldName, newName)
fmt.Fprintf(&out, "--- %s\n", oldName)
fmt.Fprintf(&out, "+++ %s\n", newName)
// Loop over matches to consider,
// expanding each match to include surrounding lines,
// and then printing diff chunks.
// To avoid setup/teardown cases outside the loop,
// tgs returns a leading {0,0} and trailing {len(x), len(y)} pair
// in the sequence of matches.
var (
done pair // printed up to x[:done.x] and y[:done.y]
chunk pair // start lines of current chunk
count pair // number of lines from each side in current chunk
ctext []string // lines for current chunk
)
for _, m := range tgs(x, y) {
if m.x < done.x {
// Already handled scanning forward from earlier match.
continue
}
// Expand matching lines as far as possible,
// establishing that x[start.x:end.x] == y[start.y:end.y].
// Note that on the first (or last) iteration we may (or definitely do)
// have an empty match: start.x==end.x and start.y==end.y.
start := m
for start.x > done.x && start.y > done.y && x[start.x-1] == y[start.y-1] {
start.x--
start.y--
}
end := m
for end.x < len(x) && end.y < len(y) && x[end.x] == y[end.y] {
end.x++
end.y++
}
// Emit the mismatched lines before start into this chunk.
// (No effect on first sentinel iteration, when start = {0,0}.)
for _, s := range x[done.x:start.x] {
ctext = append(ctext, "-"+s)
count.x++
}
for _, s := range y[done.y:start.y] {
ctext = append(ctext, "+"+s)
count.y++
}
// If we're not at EOF and have too few common lines,
// the chunk includes all the common lines and continues.
const C = 3 // number of context lines
if (end.x < len(x) || end.y < len(y)) &&
(end.x-start.x < C || (len(ctext) > 0 && end.x-start.x < 2*C)) {
for _, s := range x[start.x:end.x] {
ctext = append(ctext, " "+s)
count.x++
count.y++
}
done = end
continue
}
// End chunk with common lines for context.
if len(ctext) > 0 {
n := end.x - start.x
if n > C {
n = C
}
for _, s := range x[start.x : start.x+n] {
ctext = append(ctext, " "+s)
count.x++
count.y++
}
done = pair{start.x + n, start.y + n}
// Format and emit chunk.
// Convert line numbers to 1-indexed.
// Special case: empty file shows up as 0,0 not 1,0.
if count.x > 0 {
chunk.x++
}
if count.y > 0 {
chunk.y++
}
fmt.Fprintf(&out, "@@ -%d,%d +%d,%d @@\n", chunk.x, count.x, chunk.y, count.y)
for _, s := range ctext {
out.WriteString(s)
}
count.x = 0
count.y = 0
ctext = ctext[:0]
}
// If we reached EOF, we're done.
if end.x >= len(x) && end.y >= len(y) {
break
}
// Otherwise start a new chunk.
chunk = pair{end.x - C, end.y - C}
for _, s := range x[chunk.x:end.x] {
ctext = append(ctext, " "+s)
count.x++
count.y++
}
done = end
}
return out.Bytes()
}
// lines returns the lines in the file x, including newlines.
// If the file does not end in a newline, one is supplied
// along with a warning about the missing newline.
func lines(x []byte) []string {
l := strings.SplitAfter(string(x), "\n")
if l[len(l)-1] == "" {
l = l[:len(l)-1]
} else {
// Treat last line as having a message about the missing newline attached,
// using the same text as BSD/GNU diff (including the leading backslash).
l[len(l)-1] += "\n\\ No newline at end of file\n"
}
return l
}
// tgs returns the pairs of indexes of the longest common subsequence
// of unique lines in x and y, where a unique line is one that appears
// once in x and once in y.
//
// The longest common subsequence algorithm is as described in
// Thomas G. Szymanski, “A Special Case of the Maximal Common
// Subsequence Problem,” Princeton TR #170 (January 1975),
// available at https://research.swtch.com/tgs170.pdf.
func tgs(x, y []string) []pair {
// Count the number of times each string appears in a and b.
// We only care about 0, 1, many, counted as 0, -1, -2
// for the x side and 0, -4, -8 for the y side.
// Using negative numbers now lets us distinguish positive line numbers later.
m := make(map[string]int)
for _, s := range x {
if c := m[s]; c > -2 {
m[s] = c - 1
}
}
for _, s := range y {
if c := m[s]; c > -8 {
m[s] = c - 4
}
}
// Now unique strings can be identified by m[s] = -1+-4.
//
// Gather the indexes of those strings in x and y, building:
// xi[i] = increasing indexes of unique strings in x.
// yi[i] = increasing indexes of unique strings in y.
// inv[i] = index j such that x[xi[i]] = y[yi[j]].
var xi, yi, inv []int
for i, s := range y {
if m[s] == -1+-4 {
m[s] = len(yi)
yi = append(yi, i)
}
}
for i, s := range x {
if j, ok := m[s]; ok && j >= 0 {
xi = append(xi, i)
inv = append(inv, j)
}
}
// Apply Algorithm A from Szymanski's paper.
// In those terms, A = J = inv and B = [0, n).
// We add sentinel pairs {0,0}, and {len(x),len(y)}
// to the returned sequence, to help the processing loop.
J := inv
n := len(xi)
T := make([]int, n)
L := make([]int, n)
for i := range T {
T[i] = n + 1
}
for i := range n {
k := sort.Search(n, func(k int) bool {
return T[k] >= J[i]
})
T[k] = J[i]
L[i] = k + 1
}
k := 0
for _, v := range L {
if k < v {
k = v
}
}
seq := make([]pair, 2+k)
seq[1+k] = pair{len(x), len(y)} // sentinel at end
lastj := n
for i := n - 1; i >= 0; i-- {
if L[i] == k && J[i] < lastj {
seq[k] = pair{xi[i], yi[J[i]]}
k--
}
}
seq[0] = pair{0, 0} // sentinel at start
return seq
}

View File

@ -0,0 +1,44 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package diff
import (
"bytes"
"path/filepath"
"testing"
"golang.org/x/tools/txtar"
)
func clean(text []byte) []byte {
text = bytes.ReplaceAll(text, []byte("$\n"), []byte("\n"))
text = bytes.TrimSuffix(text, []byte("^D\n"))
return text
}
func Test(t *testing.T) {
files, _ := filepath.Glob("testdata/*.txt")
if len(files) == 0 {
t.Fatalf("no testdata")
}
for _, file := range files {
t.Run(filepath.Base(file), func(t *testing.T) {
a, err := txtar.ParseFile(file)
if err != nil {
t.Fatal(err)
}
if len(a.Files) != 3 || a.Files[2].Name != "diff" {
t.Fatalf("%s: want three files, third named \"diff\"", file)
}
diffs := Diff(a.Files[0].Name, clean(a.Files[0].Data), a.Files[1].Name, clean(a.Files[1].Data))
want := clean(a.Files[2].Data)
if !bytes.Equal(diffs, want) {
t.Fatalf("%s: have:\n%s\nwant:\n%s\n%s", file,
diffs, want, Diff("have", diffs, "want", want))
}
})
}
}

View File

@ -0,0 +1,13 @@
-- old --
-- new --
a
b
c
-- diff --
diff old new
--- old
+++ new
@@ -0,0 +1,3 @@
+a
+b
+c

View File

@ -0,0 +1,13 @@
-- old --
a
b
c
-- new --
-- diff --
diff old new
--- old
+++ new
@@ -1,3 +0,0 @@
-a
-b
-c

View File

@ -0,0 +1,35 @@
Example from Hunt and McIlroy, “An Algorithm for Differential File Comparison.”
https://www.cs.dartmouth.edu/~doug/diff.pdf
-- old --
a
b
c
d
e
f
g
-- new --
w
a
b
x
y
z
e
-- diff --
diff old new
--- old
+++ new
@@ -1,7 +1,7 @@
+w
a
b
-c
-d
+x
+y
+z
e
-f
-g

40
grammar/internal/diff/testdata/dups.txt vendored Normal file
View File

@ -0,0 +1,40 @@
-- old --
a
b
c
d
e
f
-- new --
a
B
C
d
e
f
-- diff --
diff old new
--- old
+++ new
@@ -1,8 +1,8 @@
a
$
-b
-
-c
+B
+
+C
$
d
$

38
grammar/internal/diff/testdata/end.txt vendored Normal file
View File

@ -0,0 +1,38 @@
-- old --
1
2
3
4
5
6
7
eight
nine
ten
eleven
-- new --
1
2
3
4
5
6
7
8
9
10
-- diff --
diff old new
--- old
+++ new
@@ -5,7 +5,6 @@
5
6
7
-eight
-nine
-ten
-eleven
+8
+9
+10

View File

@ -0,0 +1,9 @@
-- old --
a
b
c^D
-- new --
a
b
c^D
-- diff --

18
grammar/internal/diff/testdata/eof1.txt vendored Normal file
View File

@ -0,0 +1,18 @@
-- old --
a
b
c
-- new --
a
b
c^D
-- diff --
diff old new
--- old
+++ new
@@ -1,3 +1,3 @@
a
b
-c
+c
\ No newline at end of file

18
grammar/internal/diff/testdata/eof2.txt vendored Normal file
View File

@ -0,0 +1,18 @@
-- old --
a
b
c^D
-- new --
a
b
c
-- diff --
diff old new
--- old
+++ new
@@ -1,3 +1,3 @@
a
b
-c
\ No newline at end of file
+c

62
grammar/internal/diff/testdata/long.txt vendored Normal file
View File

@ -0,0 +1,62 @@
-- old --
1
2
3
4
5
6
7
8
9
10
11
12
13
14
14½
15
16
17
18
19
20
-- new --
1
2
3
4
5
6
8
9
10
11
12
13
14
17
18
19
20
-- diff --
diff old new
--- old
+++ new
@@ -4,7 +4,6 @@
4
5
6
-7
8
9
10
@@ -12,9 +11,6 @@
12
13
14
-14½
-15
-16
17
18
19

View File

@ -0,0 +1,5 @@
-- old --
hello world
-- new --
hello world
-- diff --

View File

@ -0,0 +1,34 @@
-- old --
e
pi
4
5
6
7
8
9
10
-- new --
1
2
3
4
5
6
7
8
9
10
-- diff --
diff old new
--- old
+++ new
@@ -1,5 +1,6 @@
-e
-pi
+1
+2
+3
4
5
6

40
grammar/internal/diff/testdata/triv.txt vendored Normal file
View File

@ -0,0 +1,40 @@
Another example from Hunt and McIlroy,
“An Algorithm for Differential File Comparison.”
https://www.cs.dartmouth.edu/~doug/diff.pdf
Anchored diff gives up on finding anything,
since there are no unique lines.
-- old --
a
b
c
a
b
b
a
-- new --
c
a
b
a
b
c
-- diff --
diff old new
--- old
+++ new
@@ -1,7 +1,6 @@
-a
-b
-c
-a
-b
-b
-a
+c
+a
+b
+a
+b
+c

View File

@ -0,0 +1,171 @@
package jsonschema
import (
"bytes"
"encoding/json"
"errors"
)
// Schema holds a JSON schema.
type Schema struct {
// Name is the name of the property. For the parent/root property, this
// is "root". For child properties, this is the name of the property.
Name string `json:"-"`
// Type is the type of the property.
//
// TODO: Union types (e.g. make this a []string).
Type string
// PrefixItems is a list of schemas for each item in a tuple. By
// default, the tuple is "closed." unless Items is set to true or a
// valid Schema.
PrefixItems []*Schema
// Items is the schema for each item in a list.
//
// If it is missing, or its JSON value is "null" or "false", it is nil.
// If the JSON value is "true", it is set to the empty Schema. If the
// JSON value is an object, it will be decoded as a Schema.
Items *Schema
// MinItems specifies the minimum number of items allowed in a list.
MinItems int
// MaxItems specifies the maximum number of items allowed in a list.
MaxItems int
// Properties is the schema for each property of an object.
Properties []*Schema
// Format is the format of the property. This is used to validate the
// property against a specific format.
//
// It is the callers responsibility to validate the property against
// the format.
Format string
// Minimum specifies the minimum value for numeric properties.
Minimum float64
// Maximum specifies the maximum value for numeric properties.
Maximum float64
// Enum is a list of valid values for the property.
Enum []json.RawMessage
}
func (s *Schema) UnmarshalJSON(data []byte) error {
type S Schema
w := struct {
Properties props
Items items
*S
}{
S: (*S)(s),
}
if err := json.Unmarshal(data, &w); err != nil {
return err
}
if w.Items.set {
s.Items = &w.Items.Schema
}
s.Properties = w.Properties
return nil
}
type items struct {
Schema
set bool
}
func (s *items) UnmarshalJSON(data []byte) error {
switch b := data[0]; b {
case 't':
*s = items{set: true}
case '{':
type I items
if err := json.Unmarshal(data, (*I)(s)); err != nil {
return err
}
s.set = true
case 'n', 'f':
default:
return errors.New("invalid Items")
}
return nil
}
// EffectiveType returns the effective type of the schema. If the Type field is
// not empty, it is returned; otherwise:
//
// - If the schema has both Properties and Items, it returns an empty string.
// - If the schema has Properties, it returns "object".
// - If the schema has Items, it returns "array".
// - If the schema has neither Properties nor Items, it returns "value".
//
// The returned string is never empty.
func (d *Schema) EffectiveType() string {
if d.Type == "" {
if len(d.Properties) > 0 {
return "object"
}
if len(d.PrefixItems) > 0 || d.Items != nil {
return "array"
}
return "value"
}
return d.Type
}
// props is an ordered list of properties. The order of the properties
// is the order in which they were defined in the schema.
type props []*Schema
var _ json.Unmarshaler = (*props)(nil)
func (v *props) UnmarshalJSON(data []byte) error {
if len(data) == 0 {
return nil
}
if data[0] != '{' {
return errors.New("expected object")
}
d := json.NewDecoder(bytes.NewReader(data))
// TODO(bmizerany): Consider DisallowUnknownFields. Currently, we, like
// llama.cpp, ignore unknown fields, which could be lead to unexpected
// behavior for clients of this package, since they may not be aware
// that "additionalFields", "itemsPrefix", etc, are being ignored.
//
// For now, just do what llama.cpp does.
t, err := d.Token()
if err != nil {
return err
}
if t != json.Delim('{') {
return errors.New("expected object")
}
for d.More() {
// Use the first token (map key) as the property name, then
// decode the rest of the object fields into a Schema and
// append.
t, err := d.Token()
if err != nil {
return err
}
if t == json.Delim('}') {
return nil
}
s := &Schema{
Name: t.(string),
}
if err := d.Decode(s); err != nil {
return err
}
*v = append(*v, s)
}
return nil
}

View File

@ -0,0 +1,104 @@
package jsonschema
import (
"encoding/json"
"reflect"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
const testSchemaBasic = `
{
"properties": {
"tupleClosedEmpty": { "prefixItems": [] },
"tupleClosedMissing": { "prefixItems": [{}] },
"tupleClosedNull": { "prefixItems": [{}], "items": null },
"tupleClosedFalse": { "prefixItems": [{}], "items": false },
"tupleOpenTrue": { "prefixItems": [{}], "items": true },
"tupleOpenEmpty": { "prefixItems": [{}], "items": {} },
"tupleOpenTyped": { "prefixItems": [{}], "items": {"type": "boolean"} },
"tupleOpenMax": { "prefixItems": [{}], "items": true, "maxItems": 3},
"array": { "items": {"type": "number"} },
"null": { "type": "null" },
"string": { "type": "string" },
"boolean": { "type": "boolean" }
}
}
`
func TestSchemaUnmarshal(t *testing.T) {
var got *Schema
if err := json.Unmarshal([]byte(testSchemaBasic), &got); err != nil {
t.Fatalf("Unmarshal: %v", err)
}
want := &Schema{
Properties: []*Schema{
{Name: "tupleClosedEmpty", PrefixItems: []*Schema{}, Items: nil},
{Name: "tupleClosedMissing", PrefixItems: []*Schema{{}}, Items: nil},
{Name: "tupleClosedNull", PrefixItems: []*Schema{{}}, Items: nil},
{Name: "tupleClosedFalse", PrefixItems: []*Schema{{}}, Items: nil},
{Name: "tupleOpenTrue", PrefixItems: []*Schema{{}}, Items: &Schema{}},
{Name: "tupleOpenEmpty", PrefixItems: []*Schema{{}}, Items: &Schema{}},
{Name: "tupleOpenTyped", PrefixItems: []*Schema{{}}, Items: &Schema{Type: "boolean"}},
{Name: "tupleOpenMax", PrefixItems: []*Schema{{}}, Items: &Schema{}, MaxItems: 3},
{Name: "array", Items: &Schema{Type: "number"}},
{Name: "null", Type: "null"},
{Name: "string", Type: "string"},
{Name: "boolean", Type: "boolean"},
},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("(-want, +got)\n%s", diff)
}
}
func TestEffectiveType(t *testing.T) {
const schema = `
{"properties": {
"o": {"type": "object"},
"a": {"type": "array"},
"n": {"type": "number"},
"s": {"type": "string"},
"z": {"type": "null"},
"b": {"type": "boolean"},
"t0": {"prefixItems": [{}], "items": {"type": "number"}},
"t1": {"items": {"type": "number"}, "maxItems": 3},
"v": {"maxItems": 3}
}}
`
var s *Schema
if err := json.Unmarshal([]byte(schema), &s); err != nil {
t.Fatalf("json.Unmarshal: %v", err)
}
var got []string
for _, p := range s.Properties {
got = append(got, p.EffectiveType())
}
want := strings.Fields(`
object
array
number
string
null
boolean
array
array
value
`)
if !reflect.DeepEqual(want, got) {
t.Errorf("\ngot:\n\t%v\nwant:\n\t%v", got, want)
}
}

76
grammar/testdata/schemas.txt vendored Normal file
View File

@ -0,0 +1,76 @@
# This file holds tests for JSON schema to EBNF grammar conversions.
#
# The format is a JSON schema, followed by the expected EBNF grammar. Each test
# MAY be preceded by a comment that describes the test (e.g. the test name), followed by
# the JSON schema and the expected EBNF grammar. If no comment is present, the test
# name the tests number in the file (e.g. "#0", "#1", etc.)
#
# Blank lines signify the end or start of a new test. Comments can be added
# anywhere in the file, but they must be preceded by a '#' character and start at
# the beginning of the line.
# default
{}
root ::= value;
{"properties": {}}
root ::= value;
# array
{"properties": {"a": {"type": "array", "items": {"type": "string"}}}}
root_0_tuple_0 ::= string;
root_0 ::= "[" ( root_0_tuple_0 )* "]";
root ::= "{" "a" ":" root_0 "}";
# array with nested array
{"type": "array", "items": {"type": "array", "items": {"type": "string"}}}
root_tuple_0_tuple_0 ::= string;
root_tuple_0 ::= "[" ( root_tuple_0_tuple_0 )* "]";
root ::= "[" ( root_tuple_0 )* "]";
# object
{"properties": {"e": {}}}
root_0 ::= value;
root ::= "{" "e" ":" root_0 "}";
# object with nested object
{"properties": {"o": {"type": "object", "properties": {"e": {}}}}}
root_0_0 ::= value;
root_0 ::= "{" "e" ":" root_0_0 "}";
root ::= "{" "o" ":" root_0 "}";
# boolean
{"type": "boolean"}
root ::= boolean;
# number
{"properties": {"n": {"type": "number", "minimum": 123, "maximum": 4567}}}
root_0 ::= number;
root ::= "{" "n" ":" root_0 "}";
# string
{"type": "string"}
root ::= string;
# string with enum
{"type": "string", "enum": ["a", "b", "c"]}
root ::= ( "\"a\"" "|" "\"b\"" "|" "\"c\"" );
# spaces in key
{"properties": {"a b": {}}}
root_0 ::= value;
root ::= "{" "a b" ":" root_0 "}";
# issue7978
{ "type": "object", "properties": { "steps": { "type": "array", "items": { "type": "object", "properties": { "explanation": { "type": "string" }, "output": { "type": "string" } }, "required": [ "explanation", "output" ], "additionalProperties": false } }, "final_answer": { "type": "string" } }, "required": [ "steps", "final_answer" ], "additionalProperties": false }
root_0_tuple_0_0 ::= string;
root_0_tuple_0_1 ::= string;
root_0_tuple_0 ::= "{" "explanation" ":" root_0_tuple_0_0 "," "output" ":" root_0_tuple_0_1 "}";
root_0 ::= "[" ( root_0_tuple_0 )* "]";
root_1 ::= string;
root ::= "{" "steps" ":" root_0 "," "final_answer" ":" root_1 "}";
# !! # special characters in key
# !! {"properties": {"a!b": {}}}
# !! !invalid character '!' in key
# !!

View File

@ -29,6 +29,7 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/grammar"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model"
)
@ -700,9 +701,9 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}
// User provided a JSON schema
g := llama.SchemaToGrammar(req.Format)
if g == nil {
return fmt.Errorf("invalid JSON schema in format")
g, err := grammar.FromSchema(nil, req.Format)
if err != nil {
return fmt.Errorf("invalid JSON schema in format: %w", err)
}
req.Grammar = string(g)
}
@ -713,6 +714,11 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
req.Options = &opts
}
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
}
if err := s.sem.Acquire(ctx, 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
@ -727,7 +733,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
req.Options.NumPredict = 10 * s.options.NumCtx
}
// Make sure the server is ready
status, err := s.getServerStatusRetry(ctx)
if err != nil {

View File

@ -32,6 +32,7 @@ type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
Vocab() *Vocabulary
}
type Vocabulary struct {

View File

@ -53,6 +53,10 @@ func (spm SentencePieceModel) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm SentencePieceModel) Vocab() *Vocabulary {
return spm.vocab
}
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
return func(yield func(string) bool) {
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {

View File

@ -468,6 +468,20 @@ func (s *Server) processBatch() error {
return fmt.Errorf("failed to sample token: %w", err)
}
if seq.sampler.JSONSampler != nil {
_, err = seq.sampler.JSONSampler.UpdateState([]int32{token})
if err != nil {
return fmt.Errorf("failed to update state: %w", err)
}
}
if seq.sampler.PythonSampler != nil {
err = seq.sampler.PythonSampler.UpdateState(token)
if err != nil {
return fmt.Errorf("failed to update state: %w", err)
}
}
// if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
// TODO (jmorganca): we should send this back
@ -562,6 +576,21 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
}
// jsonSampler, err := sample.NewJSONSampler(s.model.(model.TextProcessor), nil)
// if err != nil {
// http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
// return
// }
// jsonSampler = nil
pythonSampler := &sample.PythonSampler{}
functions := []sample.PythonFunction{
{
Name: "add_two_strings",
Args: []string{"s1", "s2"},
Types: []string{"string", "string"},
},
}
pythonSampler.Init(functions, s.model.(model.TextProcessor))
sampler := sample.NewSampler(
req.Options.Temperature,
req.Options.TopK,
@ -569,6 +598,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
req.Options.MinP,
req.Options.Seed,
grammar,
nil,
pythonSampler,
// nil,
)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{

53
sample/gtf.go Normal file
View File

@ -0,0 +1,53 @@
package sample
var DefaultGrammar = map[string]string{
"unicode": `\x{hex}{2} | \u{hex}{4} | \U{hex}{8}`,
"null": `"null"`,
"object": `"{" (kv ("," kv)*)? "}"`,
"array": `"[" (value ("," value)*)? "]"`,
"kv": `string ":" value`,
"integer": `"0" | [1-9] [0-9]*`,
"number": `"-"? integer frac? exp?`,
"frac": `"." [0-9]+`,
"exp": `("e" | "E") ("+" | "-") [0-9]+`,
"string": `"\"" char* "\""`,
"escape": `["/" | "b" | "f" | "n" | "r" | "t" | unicode]`,
"char": `[^"\\] | escape`,
"space": `(" " | "\t" | "\n" | "\r")*`,
"hex": `[0-9] | [a-f] | [A-F]`,
"boolean": `"true" | "false"`,
"value": `object | array | string | number | boolean | "null"`,
}
const jsonString = `object | array`
type StateMachine struct {
states map[rune]State
}
type State struct {
NextStates []string
// bitmask?
Mask []bool
IsTerminal bool
}
func NewStateMachine(grammar map[string]string, startRule string) *StateMachine {
states := make(map[rune]State)
var cumu string
flag := false
for _, r := range startRule {
if r == '"' {
flag = !flag
}
if flag {
cumu += string(r)
}
}
sm := &StateMachine{
states: states,
}
return sm
}

138
sample/gtf_test.go Normal file
View File

@ -0,0 +1,138 @@
package sample
import (
"testing"
)
func TestGrammarParsing(t *testing.T) {
tests := []struct {
name string
grammar map[string]string
startRule string
input string
want bool
}{
{
name: "simple object",
grammar: map[string]string{
"object": `"{" "}"`,
},
startRule: "object",
input: "{}",
want: true,
},
{
name: "simple array",
grammar: map[string]string{
"array": `"[" "]"`,
},
startRule: "array",
input: "[]",
want: true,
},
{
name: "character class",
grammar: map[string]string{
"digit": `[0-9]`,
},
startRule: "digit",
input: "5",
want: true,
},
{
name: "alternation",
grammar: map[string]string{
"bool": `"true" | "false"`,
},
startRule: "bool",
input: "true",
want: true,
},
{
name: "repetition",
grammar: map[string]string{
"digits": `[0-9]+`,
},
startRule: "digits",
input: "123",
want: true,
},
{
name: "nested rules",
grammar: map[string]string{
"value": `object | array`,
"object": `"{" "}"`,
"array": `"[" "]"`,
},
startRule: "value",
input: "{}",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := NewParser(tt.grammar)
machine, err := parser.Parse(tt.startRule)
if err != nil {
t.Fatalf("Parse() error = %v", err)
}
matcher := NewMatcher(machine)
got, err := matcher.Match(tt.input)
if err != nil {
t.Fatalf("Match() error = %v", err)
}
if got != tt.want {
t.Errorf("Match() = %v, want %v", got, tt.want)
}
})
}
}
func TestJSONGrammar(t *testing.T) {
tests := []struct {
name string
input string
want bool
}{
{"empty object", "{}", true},
{"empty array", "[]", true},
{"simple string", `"hello"`, true},
{"simple number", "123", true},
{"simple boolean", "true", true},
{"simple null", "null", true},
{"object with string", `{"key": "value"}`, true},
{"array with numbers", "[1, 2, 3]", true},
{"nested object", `{"obj": {"key": "value"}}`, true},
{"nested array", `[1, [2, 3], 4]`, true},
{"invalid object", "{", false},
{"invalid array", "[1, 2", false},
{"invalid string", `"hello`, false},
}
parser := NewParser(DefaultGrammar)
machine, err := parser.Parse("value")
if err != nil {
t.Fatalf("Parse() error = %v", err)
}
matcher := NewMatcher(machine)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := matcher.Match(tt.input)
if tt.want {
if err != nil {
t.Errorf("Match() error = %v", err)
}
if !got {
t.Errorf("Match() = false, want true")
}
} else {
if err == nil && got {
t.Errorf("Match() = true, want false")
}
}
})
}
}

160
sample/json_types.go Normal file
View File

@ -0,0 +1,160 @@
package sample
import (
"fmt"
)
type JSONState int
const (
StateStart JSONState = iota
StateInObject
StateInObjectKey
StateInStructuredKey
StateInStructuredValue
StateNewline
StateTab
StateSpace
StateInString
StateInInt
StateInFloat
StateInBool
StateInNull
StateInColon
StateInComma
StateInTab
StateInSpaceToValue
StateInSpaceEndValue
StateInNewlineEndValue
StateInObjSpace
StateInList
StateInListComma
StateInValue
StateInValueEnd
StateInListEnd
StateInListObjectEnd
StateInNewline
StateInNumber
StateInNumberEnd
StateInStringEnd
StateInObjectKeyEnd
StateTerminate
StateInObjectEnd
StateTransitioningToTerminate
StateInListStartJSON
)
var JSONStates = []JSONState{
StateStart,
StateInObject,
StateInObjectKey,
StateInStructuredKey,
StateInStructuredValue,
StateNewline,
StateTab,
StateSpace,
StateInString,
StateInInt,
StateInFloat,
StateInBool,
StateInNull,
StateInColon,
StateInComma,
StateInTab,
StateInSpaceToValue,
StateInSpaceEndValue,
StateInNewlineEndValue,
StateInObjSpace,
StateInListStartJSON,
StateInList,
StateInListComma,
StateInValue,
StateInValueEnd,
StateInListEnd,
StateInListObjectEnd,
StateInNewline,
StateInNumber,
StateInNumberEnd,
StateInStringEnd,
StateInObjectKeyEnd,
StateTerminate,
StateInObjectEnd,
StateTransitioningToTerminate,
}
func (s JSONState) String() string {
switch s {
case StateStart:
return "StateStart"
case StateInObject:
return "StateInObject"
case StateInObjectKey:
return "StateInObjectKey"
case StateInStructuredKey:
return "StateInStructuredKey"
case StateInStructuredValue:
return "StateInStructuredValue"
case StateNewline:
return "StateNewline"
case StateTab:
return "StateTab"
case StateSpace:
return "StateSpace"
case StateInString:
return "StateInString"
case StateInInt:
return "StateInInt"
case StateInFloat:
return "StateInFloat"
case StateInBool:
return "StateInBool"
case StateInNull:
return "StateInNull"
case StateInColon:
return "StateInColon"
case StateInComma:
return "StateInComma"
case StateInTab:
return "StateInTab"
case StateInSpaceToValue:
return "StateInSpaceToValue"
case StateInSpaceEndValue:
return "StateInSpaceEndValue"
case StateInNewlineEndValue:
return "StateInNewlineEndValue"
case StateInObjSpace:
return "StateInObjSpace"
case StateInList:
return "StateInList"
case StateInListComma:
return "StateInListComma"
case StateInValue:
return "StateInValue"
case StateInValueEnd:
return "StateInValueEnd"
case StateInListEnd:
return "StateInListEnd"
case StateInListObjectEnd:
return "StateInListObjectEnd"
case StateInNewline:
return "StateInNewline"
case StateInNumber:
return "StateInNumber"
case StateInNumberEnd:
return "StateInNumberEnd"
case StateInStringEnd:
return "StateInStringEnd"
case StateInObjectKeyEnd:
return "StateInObjectKeyEnd"
case StateTerminate:
return "StateTerminate"
case StateInObjectEnd:
return "StateInObjectEnd"
case StateTransitioningToTerminate:
return "StateTransitioningToTerminate"
case StateInListStartJSON:
return "StateInListStartJSON"
default:
return fmt.Sprintf("Unknown state: %d", s)
}
}

327
sample/pushdown_automata.go Normal file
View File

@ -0,0 +1,327 @@
package sample
import (
"fmt"
"slices"
"github.com/ollama/ollama/model"
)
/*
Key JSON rules to consider:
1. Whitespace handling:
- Need to handle all valid JSON whitespace characters (\r, spaces between tokens)
- Current code only handles some whitespace cases
2. Number validation:
- Need proper validation for special number cases like -0
- Should handle .5 style decimals
- Need limits on scientific notation (e, E)
3. String escaping:
- Currently marks \ as invalid but should allow escaped sequences:
- \"
- \n
- \u1234 unicode escapes
4. Empty object/array transitions:
- Direct {} and [] cases could be more explicit
- Need clear transitions for these edge cases
5. Nested depth limits:
- No protection against excessive nesting
- Could cause stack overflow with deeply nested structures
*/
// TODO: / should be valid but an escape character
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
var (
intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
)
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
type PDA struct {
State JSONState
TransitionEdges map[rune]*PDA
MaskTokenIDToNode map[int32]*PDA
}
func NewPDANode(state JSONState) *PDA {
return &PDA{
State: state,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
}
type PDAGraphBuilder struct {
proc model.TextProcessor
decodedToks []string
stateToNodeMap map[JSONState]*PDA
tokenToStatesMap map[int32][]JSONState
}
func (b *PDAGraphBuilder) BuildGraph() error {
stateToNodeMap := make(map[JSONState]*PDA)
for _, state := range JSONStates {
stateToNodeMap[state] = NewPDANode(state)
}
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
// TODO: update naming here - and revisit values
stateToNodeMap[StateInListStartJSON].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInListStartJSON].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
stateToNodeMap[StateInObject].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
// new line
stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
// stateToNodeMap[StateInNewline].TransitionEdges['{'] = stateToNodeMap[StateInObject]
// new line end value
// stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
// TODO: see if this is needed for formatting
stateToNodeMap[StateInObjSpace].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInTab].TransitionEdges['\t'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInObjectKey].TransitionEdges[rune(-1)] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObjectKey].TransitionEdges['"'] = stateToNodeMap[StateInObjectKeyEnd]
stateToNodeMap[StateInObjectKeyEnd].TransitionEdges[':'] = stateToNodeMap[StateInColon]
stateToNodeMap[StateInObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInObjectEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
// where values should be
// this could be combined but the probl might change, we're alr doing a skip ahead
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInColon].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
// Leads to a value
stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap)
stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInSpaceToValue].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
// Values
// string node
stateToNodeMap[StateInString].TransitionEdges[rune(-1)] = stateToNodeMap[StateInString]
stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
// String end node
addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
// stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// TODO: add counters for allowable number of decimals, e, E, etc
// number node
for _, r := range validNumberRunes {
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
}
addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
// stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// list node
stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
// early end
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
// list end node
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
// stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// empty list
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
// null node
for _, r := range validNullRunes {
stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
}
addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// list comma
// should point to values
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInList]
stateToNodeMap[StateInListComma].TransitionEdges['\t'] = stateToNodeMap[StateInList]
addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
// list object end
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
// TODO: not sure if this is needed
stateToNodeMap[StateInListObjectEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// bool node
for _, r := range validBoolRunes {
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
}
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
// stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// comma node
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
// todo: review this space transition
// stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
// space end value
// stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
b.stateToNodeMap = stateToNodeMap
if err := b.preComputeValidStates(); err != nil {
return err
}
return nil
}
func addEnds(node *PDA, stateToNodeMap map[JSONState]*PDA) {
node.TransitionEdges[','] = stateToNodeMap[StateInComma]
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
}
func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) {
node.TransitionEdges['"'] = stateToNodeMap[StateInString]
for _, r := range validNumberRunes {
node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
}
// TODO(parthsareen): force the output and shift similar to structured outputs
node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
}
func (b *PDAGraphBuilder) preComputeValidStates() error {
for _, node := range b.stateToNodeMap {
// if node.State == StateInObjectKey {
// if len(b.stateToNodeMap[StateInString].MaskTokenIDToNode) > 0 {
// b.stateToNodeMap[StateInObjectKey].MaskTokenIDToNode = b.stateToNodeMap[StateInString].MaskTokenIDToNode
// fmt.Println("copying string mask to object key mask")
// }
// }
if err := b.CreateMask(node); err != nil {
return err
}
}
return nil
}
func (b *PDAGraphBuilder) preComputeTokenToStatesMap() error {
// TODO: make can be somewhere else too
b.tokenToStatesMap = make(map[int32][]JSONState)
for i, t := range b.decodedToks {
for _, r := range t {
if r == '"' {
b.tokenToStatesMap[int32(i)] = append(b.tokenToStatesMap[int32(i)], StateInString)
}
}
}
return nil
}
// TODO: the mask for obj key and string should be the same?
func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
if node == nil {
return fmt.Errorf("node cannot be nil")
}
for i := range b.decodedToks {
token := b.decodedToks[i]
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
if b.proc.Is(int32(i), model.SpecialEOS) || b.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
continue
}
curNode := node
valid := true
consumedSpecialRunes := make(map[rune]bool)
for _, r := range token {
curNode, valid = isRuneValid(r, curNode, consumedSpecialRunes)
if curNode == nil || !valid {
break
}
}
if valid {
node.MaskTokenIDToNode[int32(i)] = curNode
}
}
return nil
}
func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA, bool) {
if consumedSpecialRunes[r] {
return nil, false
}
specialRune := slices.Contains(stringInvalidRunes, r)
if specialRune {
if curNode.State == StateInString || curNode.State == StateInObjectKey {
return nil, false
}
}
// Check for specific rune transition
if nextNode, ok := curNode.TransitionEdges[r]; ok {
// fmt.Println("next node", nextNode)
if specialRune {
if curNode.State == nextNode.State {
return nil, false
}
consumedSpecialRunes[r] = true
}
return nextNode, true
}
// Check for sentinel value - if present, any rune is valid
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
return nextNode, true
}
return nil, false
}

264
sample/pushdown_runner.go Normal file
View File

@ -0,0 +1,264 @@
package sample
import (
"fmt"
"math"
"runtime"
"time"
"github.com/ollama/ollama/model"
)
// TODO: safety in case of invalid json
// TODO: partial JSON matching?
// TODO: interfaces to cleanup with return values
// TODO this interface shouldn't be the sampler - should just use Sampler
// TODO: add penalties for string \n stuff
// TODO: minimize number of fwd passes if there is only one match
// TODO: greedy sample initially and then backtrack if no match
type PushdownSampler struct {
PDAGraphBuilder
curNode *PDA
braceStack []rune
stateCounter uint32
}
// graph should be built once and reused per tokenizer
func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) {
start := time.Now()
fmt.Println("--------------------------------")
fmt.Println("PDA sampler")
fmt.Println("--------------------------------")
var m runtime.MemStats
runtime.ReadMemStats(&m)
before := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
vocab := proc.Vocab()
decodedToks := make([]string, len(vocab.Values))
for i := range vocab.Values {
token, err := proc.Decode([]int32{int32(i)})
if err != nil {
return nil, err
}
decodedToks[i] = token
}
gb := &PDAGraphBuilder{
proc: proc,
decodedToks: decodedToks,
}
if err := gb.BuildGraph(); err != nil {
return nil, err
}
runtime.ReadMemStats(&m)
after := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Graph build time = %v\n", time.Since(start))
// TODO: this can be simplified
return &PushdownSampler{
curNode: gb.stateToNodeMap[StateStart],
PDAGraphBuilder: *gb,
braceStack: []rune{},
stateCounter: 0,
}, nil
}
// TODO: need to add resampling logic if the first sample was not good
// greedy sample + backtrack?
func (s *PushdownSampler) Apply(logits []float32) ([]float32, error) {
switch s.curNode.State {
case StateInString:
return s.maskLogits(logits, s.curNode)
case StateInListEnd:
// force finish if no braces left
if len(s.braceStack) == 0 {
s.curNode = NewPDANode(StateTerminate)
return forceFinish(s, logits)
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
case StateTerminate:
return forceFinish(s, logits)
case StateInObjectEnd:
// force finish if no braces left
if len(s.braceStack) == 0 {
s.curNode = NewPDANode(StateTerminate)
return forceFinish(s, logits)
}
peek := s.braceStack[len(s.braceStack)-1]
if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListObjectEnd]
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
case StateInComma:
peek := s.braceStack[len(s.braceStack)-1]
if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListComma]
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
default:
fmt.Println("masking logits current state", s.curNode.State)
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
}
}
func forceFinish(s *PushdownSampler, logits []float32) ([]float32, error) {
for i := range logits {
if s.proc.Is(int32(i), model.SpecialEOS) {
logits[i] = 1.0
} else {
logits[i] = float32(math.Inf(-1))
}
}
return logits, nil
}
func (s *PushdownSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
fmt.Println("current state - updating", s.curNode.State)
mappedString, err := s.proc.Decode(tokenSlice)
if err != nil {
return nil, err
}
fmt.Printf(">>> mappedString: %q\n", mappedString)
// Special handling for EOS token in terminate state
if s.curNode.State == StateTerminate {
for _, tokenID := range tokenSlice {
if s.proc.Is(tokenID, model.SpecialEOS) {
return tokenSlice, nil
}
}
}
// flag := -1
// endBraceRunes := []rune{'}', ']'}
for _, r := range mappedString {
// TODO: if this is enabled again, make sure to appropriately handle the state transitions
// if slices.Contains(endBraceRunes, r) && len(s.braceStack) == 0 {
// fmt.Printf("stack is empty, extra closing brace %c\n", r)
// // flag = i
// break
// }
if r == rune('{') {
s.braceStack = append(s.braceStack, r)
}
if r == rune('[') {
s.braceStack = append(s.braceStack, r)
}
if r == rune('}') {
if len(s.braceStack) == 0 {
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
}
top := s.braceStack[len(s.braceStack)-1]
if top != rune('{') {
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
}
s.braceStack = s.braceStack[:len(s.braceStack)-1]
}
if r == rune(']') {
if len(s.braceStack) == 0 {
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
}
top := s.braceStack[len(s.braceStack)-1]
if top != rune('[') {
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
}
s.braceStack = s.braceStack[:len(s.braceStack)-1]
}
}
// if flag != -1 {
// tokenSlice = tokenSlice[:flag]
// }
// fmt.Println("flag!", flag)
for _, tokenID := range tokenSlice {
// transition to the next node
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
if !ok {
return nil, fmt.Errorf("invalid token: %q", mappedString)
}
fmt.Println("transitioning to", nextNode.State)
// TODO: add a penalty for staying in the same state too long
if nextNode.State == s.curNode.State {
s.stateCounter++
} else {
s.stateCounter = 0
}
s.curNode = nextNode
fmt.Println("updated curNode state", s.curNode.State)
}
return tokenSlice, nil
}
// greedy sample + backtrack?
func (s *PushdownSampler) maskLogits(logits []float32, node *PDA) ([]float32, error) {
// Create a new slice with same length as logits, initialized to -Inf
maskedLogits := make([]float32, len(logits))
for i := range maskedLogits {
maskedLogits[i] = float32(math.Inf(-1))
}
// Only update values for valid token IDs from the mask map
for tokenID := range node.MaskTokenIDToNode {
if int(tokenID) < len(logits) {
maskedLogits[tokenID] = logits[tokenID]
}
}
return maskedLogits, nil
}
func (s *PushdownSampler) fastMaskLogits(logits []float32, node *PDA) ([]float32, error) {
maxLogit := float32(math.Inf(-1))
maxIndex := -1
// Find the maximum logit value among valid tokens
for tokenID := range node.MaskTokenIDToNode {
if int(tokenID) < len(logits) && logits[tokenID] > maxLogit {
maxLogit = logits[tokenID]
maxIndex = int(tokenID)
}
}
if maxIndex == -1 {
return nil, fmt.Errorf("no valid tokens found in mask")
}
logits[0] = float32(maxIndex)
return logits, nil
// return maxIndex, nil
}

View File

@ -17,12 +17,14 @@ type token struct {
}
type Sampler struct {
rng *rand.Rand
topK int
topP float32
minP float32
temperature float32
grammar *Grammar
rng *rand.Rand
topK int
topP float32
minP float32
temperature float32
grammar *Grammar
JSONSampler *JSONSampler
PythonSampler *PythonSampler
}
func (s *Sampler) Sample(logits []float32) (int32, error) {
@ -30,6 +32,19 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
return -1, errors.New("sample: no logits provided to sample")
}
var err error
if s.JSONSampler != nil {
logits, err = s.JSONSampler.Apply(logits)
if err != nil {
return -1, err
}
}
if s.PythonSampler != nil {
logits, err = s.PythonSampler.ApplyMask(logits)
if err != nil {
return -1, err
}
}
tokens := make([]token, len(logits))
for i := range logits {
tokens[i].id = int32(i)
@ -127,7 +142,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
}
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar, jsonSampler *JSONSampler, pythonSampler *PythonSampler) Sampler {
var rng *rand.Rand
if seed != -1 {
// PCG requires two parameters: sequence and stream
@ -155,12 +170,14 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
}
return Sampler{
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
grammar: grammar,
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
grammar: grammar,
JSONSampler: jsonSampler,
PythonSampler: pythonSampler,
}
}

View File

@ -0,0 +1,299 @@
package sample
import (
"fmt"
"log/slog"
"runtime"
"time"
"github.com/ollama/ollama/grammar/jsonschema"
"github.com/ollama/ollama/model"
)
type JSONSampler struct {
schema *jsonschema.Schema
propIdx int
propToNodeMap map[string]*PDA
pdaSampler *PushdownSampler
decodedToks []string
}
func NewJSONSampler(proc model.TextProcessor, schema *jsonschema.Schema) (*JSONSampler, error) {
slog.Info("NewJSONSampler", "schema", schema)
if proc == nil {
return nil, fmt.Errorf("TextProcessor cannot be nil")
}
pdaSampler, err := NewPushdownSampler(proc)
if err != nil {
return nil, fmt.Errorf("failed to create PushdownSampler: %w", err)
}
if schema == nil {
return &JSONSampler{
schema: nil,
propIdx: -1,
propToNodeMap: nil,
pdaSampler: pdaSampler,
}, nil
}
// fmt.Println("schema not nil")
so := &JSONSampler{
schema: schema,
propIdx: -1,
propToNodeMap: make(map[string]*PDA),
pdaSampler: pdaSampler,
}
so.schemaToGraph()
// Benchmark token decoding
start := time.Now()
var m runtime.MemStats
runtime.ReadMemStats(&m)
before := m.Alloc
vocab := proc.Vocab()
decodedToks := make([]string, len(vocab.Values))
for i := range vocab.Values {
token, err := proc.Decode([]int32{int32(i)})
if err != nil {
return nil, err
}
decodedToks[i] = token
}
so.decodedToks = decodedToks
runtime.ReadMemStats(&m)
after := m.Alloc
fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Token decode time = %v\n", time.Since(start))
fmt.Println("--------------------------------")
fmt.Println("SOSampler")
fmt.Println("--------------------------------")
// Benchmark this section
start = time.Now()
runtime.ReadMemStats(&m)
before = m.Alloc
// TODO: still messed up
// TODO: recursion use case
// key masks
for _, prop := range so.schema.Properties {
node := so.propToNodeMap[prop.Name]
// propName -> node
curState := node.State
fromNode := node
so.pdaSampler.CreateMask(fromNode)
for curState == StateInStructuredKey {
// there is only one edge
for r, toNode := range fromNode.TransitionEdges {
fmt.Println("rune", r, "edge", toNode.State)
so.pdaSampler.CreateMask(toNode)
fmt.Printf("created mask for %c\n", r)
curState = toNode.State
fmt.Println("next state", curState)
// TODO: theres an extra gen for " right now
fromNode = toNode
}
}
if curState != StateInColon {
return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState)
}
// so.pdaSampler.CreateMask(fromNode)
fromNode = fromNode.TransitionEdges[' ']
so.pdaSampler.CreateMask(fromNode)
curState = fromNode.State
for _, toNode := range fromNode.TransitionEdges {
fmt.Println("toNode", toNode.State)
}
}
// runtime.ReadMemStats(&m)
// after = m.Alloc
// fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
// fmt.Printf("Mask creation time = %v\n", time.Since(start))
// fmt.Println("--------------------------------")
return so, nil
}
func (s *JSONSampler) schemaToGraph() {
schemaType := s.schema.EffectiveType()
switch schemaType {
case "object":
// TODO: see if we need to connect these to the JSON graph
// each prop is a key
for _, prop := range s.schema.Properties {
// name of key
name := prop.Name
keyNode := &PDA{
State: StateInStructuredKey, // this is unchanging, will impact sampling
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode := keyNode
for _, r := range name {
runeNode := &PDA{
State: StateInStructuredKey, // this is unchanging, will impact sampling
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
// fmt.Println("runeNode created", runeNode.State)
// fmt.Printf("runeNode created %c\n", r)
// since alloc on heap connections wil still map
prevNode.TransitionEdges[r] = runeNode
prevNode = runeNode
}
// point to end of object key node after all chars are done
// prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
// link to value node
// Create a node for the end of the key (after the closing quote)
stringEndNode := &PDA{
State: StateInStructuredKey,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode.TransitionEdges['"'] = stringEndNode
prevNode = stringEndNode
// Add transition for colon after key
colonNode := &PDA{
State: StateInColon,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode.TransitionEdges[':'] = colonNode
prevNode = colonNode
// Add transition for space after colon
spaceNode := &PDA{
State: StateInSpaceToValue,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode.TransitionEdges[' '] = spaceNode
prevNode = spaceNode
value := prop.Type
switch value {
case "object":
fmt.Println("object under key: ", name)
case "array":
fmt.Println("array under key: ", name)
case "string":
fmt.Println("string under key: ", name)
prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString]
case "number":
fmt.Println("number under key: ", name)
for _, r := range validNumberRunes {
prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber]
}
case "boolean":
fmt.Println("boolean under key: ", name)
prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool]
prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool]
prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull]
}
// points to start of the key
s.propToNodeMap[name] = keyNode
fmt.Println("name", name, "keyNode", keyNode.State)
}
}
// TODO: do values + recursion
}
func (s *JSONSampler) Apply(logits []float32) ([]float32, error) {
if s.schema == nil {
return s.pdaSampler.Apply(logits)
}
switch s.pdaSampler.curNode.State {
// TODO: doesnt account for multi rune case
case StateInObjectKey:
if s.propIdx > len(s.schema.Properties)-1 {
return nil, fmt.Errorf("propIdx out of bounds")
}
// fmt.Println("in object key - structured outputs")
// TODO: this tracking should probably be coming from a stack to track nested objects
// simple case
s.propIdx++
fmt.Println("propIdx", s.propIdx)
prop := s.schema.Properties[s.propIdx]
fmt.Println("prop", prop.Name)
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
if err != nil {
return nil, err
}
return logits, nil
default:
// Will only happen for the last prop - can also be precomputed.
if s.propIdx == len(s.schema.Properties)-1 {
// todo: if i incremenet propidx then i know im in last value as well
switch s.pdaSampler.curNode.State {
case StateInObjectEnd:
fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
s.pdaSampler.curNode = NewPDANode(StateTerminate)
s.propIdx++
// TODO: this needs to be optimized in some way, computing mask on the fly is expensive
case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
delete(s.pdaSampler.curNode.TransitionEdges, ',')
s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
s.pdaSampler.CreateMask(s.pdaSampler.curNode)
s.propIdx++
}
}
return s.pdaSampler.Apply(logits)
}
}
func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice)
if err != nil {
return nil, err
}
if s.schema == nil {
// Don't need to update state for unconstrained JSON sampling
return tokenSlice, nil
}
switch s.pdaSampler.curNode.State {
case StateInObjectKey:
s.propIdx++
fmt.Println("propIdx", s.propIdx)
prop := s.schema.Properties[s.propIdx]
fmt.Println("prop", prop.Name)
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
// TODO: this does not work - mike
// str, err := s.pdaSampler.proc.Decode(tokenSlice)
// if err != nil {
// return nil, err
// }
// fmt.Println("str", str)
return tokenSlice, nil
default:
return tokenSlice, nil
}
}

352
sample/structured_python.go Normal file
View File

@ -0,0 +1,352 @@
package sample
import (
"fmt"
"math"
"slices"
"github.com/ollama/ollama/model"
)
type PythonState int
const (
PythonStateStart PythonState = iota
StateInFunction
StateInFunctionArgs
StateInFunctionArgsType
StateInFunctionEnd
PStateInString
PStateInStringEnd
PStateInNumber
PStateInList
PStateInListEnd
PStateInDict
PStateInDictEnd
PStateInTuple
PStateInTupleEnd
PStateTerminate
)
func (s PythonState) String() string {
switch s {
case PythonStateStart:
return "PythonStateStart"
case StateInFunction:
return "StateInFunction"
case StateInFunctionArgs:
return "StateInFunctionArgs"
case StateInFunctionArgsType:
return "StateInFunctionArgsType"
case StateInFunctionEnd:
return "StateInFunctionEnd"
case PStateInString:
return "PStateInString"
case PStateInStringEnd:
return "PStateInStringEnd"
case PStateInNumber:
return "PStateInNumber"
case PStateInList:
return "PStateInList"
case PStateInListEnd:
return "PStateInListEnd"
case PStateInDict:
return "PStateInDict"
case PStateInDictEnd:
return "PStateInDictEnd"
case PStateInTuple:
return "PStateInTuple"
case PStateInTupleEnd:
return "PStateInTupleEnd"
case PStateTerminate:
return "PStateTerminate"
default:
return fmt.Sprintf("PythonState(%d)", s)
}
}
var PythonStates = []PythonState{
PythonStateStart,
StateInFunction,
StateInFunctionArgs,
StateInFunctionArgsType,
StateInFunctionEnd,
PStateInString,
PStateInStringEnd,
PStateInNumber,
PStateInList,
PStateInListEnd,
PStateInDict,
PStateInDictEnd,
PStateInTuple,
PStateInTupleEnd,
PStateTerminate,
}
type Node struct {
State PythonState
TransitionEdges map[rune]*Node
MaskTokenIDToNode map[int32]*Node
}
func NewNode(state PythonState) *Node {
return &Node{
State: state,
TransitionEdges: make(map[rune]*Node),
MaskTokenIDToNode: make(map[int32]*Node),
}
}
type PythonFunction struct {
Name string
Args []string
Types []string
}
type PythonSampler struct {
stateToNodes map[PythonState]*Node
proc model.TextProcessor
decodedToks []string
curNode *Node
completed int
functions []PythonFunction
}
func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
s.proc = proc
s.functions = functions
decodedToks := make([]string, len(proc.Vocab().Values))
for i := range proc.Vocab().Values {
token, err := proc.Decode([]int32{int32(i)})
if err != nil {
return err
}
decodedToks[i] = token
}
s.decodedToks = decodedToks
s.BuildGraph()
for _, function := range functions {
prevNode := s.stateToNodes[PythonStateStart]
for _, r := range function.Name {
nextNode := NewNode(StateInFunction)
prevNode.TransitionEdges[r] = nextNode
if err := s.CreateMask(nextNode); err != nil {
return err
}
fmt.Println("prevNode", prevNode.State)
fmt.Printf("transition edge: %q\n", r)
fmt.Println("nextNode", nextNode.State)
prevNode = nextNode
}
prevNode.TransitionEdges['('] = s.stateToNodes[StateInFunctionArgs]
s.CreateMask(prevNode)
prevNode = s.stateToNodes[StateInFunctionArgs]
for i, arg := range function.Args {
for _, r := range arg {
nextNode := NewNode(StateInFunctionArgs)
prevNode.TransitionEdges[r] = nextNode
s.CreateMask(prevNode)
prevNode = nextNode
}
prevNode.TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
// prevNode = s.stateToNodes[StateInFunctionArgs]
prevNode.TransitionEdges['='] = NewNode(StateInFunctionArgsType)
s.CreateMask(prevNode)
prevNode = prevNode.TransitionEdges['=']
switch function.Types[i] {
case "string":
prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInString]
s.CreateMask(prevNode.TransitionEdges['"'])
case "number":
prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInNumber]
s.CreateMask(prevNode.TransitionEdges['"'])
}
}
}
s.curNode = s.stateToNodes[PythonStateStart]
fmt.Println("curNode", s.curNode.State)
fmt.Println("transition edges", s.curNode.TransitionEdges)
if err := s.CreateMask(s.curNode); err != nil {
return err
}
fmt.Println("maskTokenIDToNode", s.curNode.MaskTokenIDToNode)
for tokenID, node := range s.curNode.MaskTokenIDToNode {
fmt.Printf("tokenID: %d, node: %v\n", s.decodedToks[tokenID], node.State)
}
return nil
}
func (s *PythonSampler) BuildGraph() error {
s.stateToNodes = make(map[PythonState]*Node)
for _, state := range PythonStates {
s.stateToNodes[state] = NewNode(state)
}
for _, state := range s.stateToNodes {
if err := s.CreateMask(state); err != nil {
return err
}
}
// String
s.stateToNodes[PStateInString].TransitionEdges[rune(-1)] = s.stateToNodes[PStateInString]
s.stateToNodes[PStateInString].TransitionEdges['"'] = s.stateToNodes[PStateInStringEnd]
// String end
s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
// s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
// Number
for _, r := range validNumberRunes {
s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
}
s.stateToNodes[PStateInNumber].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
s.stateToNodes[PStateInNumber].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
s.stateToNodes[PStateInNumber].TransitionEdges[' '] = s.stateToNodes[StateInFunctionArgs]
return nil
}
func (s *PythonSampler) ApplyMask(logits []float32) ([]float32, error) {
if s.curNode.State == PStateTerminate {
logits, err := finish(s, logits)
if err != nil {
return nil, err
}
return logits, nil
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
}
func (s *PythonSampler) UpdateState(token int32) error {
mappedString, err := s.proc.Decode([]int32{token})
if err != nil {
return err
}
fmt.Printf(">>> mappedString: %q\n", mappedString)
if s.curNode.State == PStateTerminate {
if s.proc.Is(token, model.SpecialEOS) {
return nil
}
}
nextNode, ok := s.curNode.MaskTokenIDToNode[token]
if !ok {
return fmt.Errorf("invalid token: %q", mappedString)
}
if mappedString == "\"" {
if s.curNode.State == PStateInStringEnd {
s.completed++
}
if s.completed == len(s.functions) {
s.curNode.TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
s.CreateMask(s.curNode)
}
}
s.curNode = nextNode
fmt.Println("curNode", s.curNode.State)
for r, node := range s.curNode.TransitionEdges {
fmt.Printf("transition edge: %q -> %v\n", r, node.State)
}
if err := s.CreateMask(s.curNode); err != nil {
return err
}
return nil
}
func (s *PythonSampler) CreateMask(node *Node) error {
if node == nil {
return fmt.Errorf("node cannot be nil")
}
for i := range s.decodedToks {
token := s.decodedToks[i]
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
if s.proc.Is(int32(i), model.SpecialEOS) || s.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
continue
}
curNode := node
valid := true
consumedSpecialRunes := make(map[rune]bool)
for _, r := range token {
curNode, valid = isRValid(r, curNode, consumedSpecialRunes)
if curNode == nil || !valid {
break
}
}
if valid {
if curNode.State == StateInFunction {
// fmt.Println("cm curNode", curNode.State)
// fmt.Println("cm token", s.decodedToks[i])
}
node.MaskTokenIDToNode[int32(i)] = curNode
}
}
return nil
}
func isRValid(r rune, curNode *Node, consumedSpecialRunes map[rune]bool) (*Node, bool) {
if consumedSpecialRunes[r] {
return nil, false
}
specialRune := slices.Contains(stringInvalidRunes, r)
if specialRune {
if curNode.State == PStateInString || curNode.State == PStateInStringEnd {
return nil, false
}
}
// Check for specific rune transition
if nextNode, ok := curNode.TransitionEdges[r]; ok {
// fmt.Println("next node", nextNode)
if specialRune {
if curNode.State == nextNode.State {
return nil, false
}
consumedSpecialRunes[r] = true
}
return nextNode, true
}
// Check for sentinel value - if present, any rune is valid
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
return nextNode, true
}
return nil, false
}
func (s *PythonSampler) maskLogits(logits []float32, node *Node) ([]float32, error) {
// Create a new slice with same length as logits, initialized to -Inf
maskedLogits := make([]float32, len(logits))
for i := range maskedLogits {
maskedLogits[i] = float32(math.Inf(-1))
}
// Only update values for valid token IDs from the mask map
for tokenID := range node.MaskTokenIDToNode {
if int(tokenID) < len(logits) {
maskedLogits[tokenID] = logits[tokenID]
}
}
return maskedLogits, nil
}
func finish(s *PythonSampler, logits []float32) ([]float32, error) {
for i := range logits {
if s.proc.Is(int32(i), model.SpecialEOS) {
logits[i] = 1.0
} else {
logits[i] = float32(math.Inf(-1))
}
}
return logits, nil
}