Compare commits
3 Commits
main
...
parth/samp
Author | SHA1 | Date | |
---|---|---|---|
![]() |
4450f871db | ||
![]() |
5ec6bb52a0 | ||
![]() |
1fd9967558 |
22
grammar/bench_test.go
Normal file
22
grammar/bench_test.go
Normal file
@ -0,0 +1,22 @@
|
||||
//go:build go1.24
|
||||
|
||||
package grammar
|
||||
|
||||
import "testing"
|
||||
|
||||
func BenchmarkFromSchema(b *testing.B) {
|
||||
for tt := range testCases(b) {
|
||||
b.Run("", func(b *testing.B) {
|
||||
s := []byte(tt.schema)
|
||||
|
||||
b.ReportAllocs()
|
||||
for b.Loop() {
|
||||
_, err := FromSchema(nil, s)
|
||||
if err != nil {
|
||||
b.Fatalf("GrammarFromSchema: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
227
grammar/grammar.go
Normal file
227
grammar/grammar.go
Normal file
@ -0,0 +1,227 @@
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"iter"
|
||||
"strconv"
|
||||
|
||||
"github.com/ollama/ollama/grammar/jsonschema"
|
||||
)
|
||||
|
||||
const jsonTerms = `
|
||||
# Unicode
|
||||
#
|
||||
# Unicode characters can be specified directly in the grammar, for example
|
||||
# hiragana ::= [ぁ-ゟ], or with escapes: 8-bit (\xXX), 16-bit (\uXXXX) or 32-bit
|
||||
# (\UXXXXXXXX).
|
||||
unicode ::= \x{hex}{2} | \u{hex}{4} | \U{hex}{8}
|
||||
|
||||
# JSON grammar from RFC 7159
|
||||
null ::= "null"
|
||||
object ::= "{" (kv ("," kv)*)? "}"
|
||||
array ::= "[" (value ("," value)*)? "]"
|
||||
kv ::= string ":" value
|
||||
integer ::= "0" | [1-9] [0-9]*
|
||||
number ::= "-"? integer frac? exp?
|
||||
frac ::= "." [0-9]+
|
||||
exp ::= ("e" | "E") ("+" | "-") [0-9]+
|
||||
string ::= "\"" char* "\""
|
||||
escape ::= ["/" | "b" | "f" | "n" | "r" | "t" | unicode]
|
||||
char ::= [^"\\] | escape
|
||||
space ::= (" " | "\t" | "\n" | "\r")*
|
||||
hex ::= [0-9] | [a-f] | [A-F]
|
||||
boolean ::= "true" | "false"
|
||||
value ::= object | array | string | number | boolean | "null"
|
||||
|
||||
# User-defined
|
||||
`
|
||||
|
||||
// FromSchema generates a grammar from a JSON schema.
|
||||
func FromSchema(buf []byte, jsonSchema []byte) ([]byte, error) {
|
||||
var s *jsonschema.Schema
|
||||
if err := json.Unmarshal(jsonSchema, &s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var g builder
|
||||
|
||||
// "root" is the only rule that is guaranteed to exist, so we start
|
||||
// with its length for padding, and then adjust it as we go.
|
||||
g.pad = len("root")
|
||||
for id := range dependencies("root", s) {
|
||||
g.pad = max(g.pad, len(id))
|
||||
}
|
||||
|
||||
g.b.WriteString(jsonTerms)
|
||||
|
||||
ids := make(map[*jsonschema.Schema]string)
|
||||
for id, s := range dependencies("root", s) {
|
||||
ids[s] = id
|
||||
g.define(id)
|
||||
if err := fromSchema(&g, ids, s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
g.define("root")
|
||||
if err := fromSchema(&g, ids, s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
g.define("") // finalize the last rule
|
||||
return g.b.Bytes(), nil
|
||||
}
|
||||
|
||||
func fromSchema(g *builder, ids map[*jsonschema.Schema]string, s *jsonschema.Schema) error {
|
||||
switch typ := s.EffectiveType(); typ {
|
||||
case "array":
|
||||
if len(s.PrefixItems) == 0 && s.Items == nil {
|
||||
g.u("array")
|
||||
} else {
|
||||
g.q("[")
|
||||
for i, s := range s.PrefixItems {
|
||||
if i > 0 {
|
||||
g.q(",")
|
||||
}
|
||||
g.u(ids[s])
|
||||
}
|
||||
if s.Items != nil {
|
||||
g.u("(")
|
||||
if len(s.PrefixItems) > 0 {
|
||||
g.q(",")
|
||||
}
|
||||
g.u(ids[s.Items])
|
||||
g.u(")*")
|
||||
}
|
||||
g.q("]")
|
||||
}
|
||||
case "object":
|
||||
if len(s.Properties) == 0 {
|
||||
g.u("object")
|
||||
} else {
|
||||
g.q("{")
|
||||
for i, p := range s.Properties {
|
||||
name := ids[p]
|
||||
if i > 0 {
|
||||
g.q(",")
|
||||
}
|
||||
g.q(p.Name)
|
||||
g.q(":")
|
||||
g.u(name)
|
||||
}
|
||||
g.q("}")
|
||||
}
|
||||
case "number":
|
||||
buildConstrainedNumber(g, s)
|
||||
case "string":
|
||||
if len(s.Enum) == 0 {
|
||||
g.u("string")
|
||||
} else {
|
||||
g.u("(")
|
||||
for i, e := range s.Enum {
|
||||
if i > 0 {
|
||||
g.q("|")
|
||||
}
|
||||
g.q(string(e))
|
||||
}
|
||||
g.u(")")
|
||||
}
|
||||
case "boolean", "value", "null", "integer":
|
||||
g.u(typ)
|
||||
default:
|
||||
return fmt.Errorf("%s: unsupported type %q", s.Name, typ)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// dependencies returns a sequence of all child dependencies of the schema in
|
||||
// post-order.
|
||||
//
|
||||
// The first value is the id/pointer to the dependency, and the second value
|
||||
// is the schema.
|
||||
func dependencies(id string, s *jsonschema.Schema) iter.Seq2[string, *jsonschema.Schema] {
|
||||
return func(yield func(string, *jsonschema.Schema) bool) {
|
||||
for i, p := range s.Properties {
|
||||
id := fmt.Sprintf("%s_%d", id, i)
|
||||
for did, d := range dependencies(id, p) {
|
||||
if !yield(did, d) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if !yield(id, p) {
|
||||
return
|
||||
}
|
||||
}
|
||||
for i, p := range s.PrefixItems {
|
||||
id := fmt.Sprintf("tuple_%d", i)
|
||||
for did, d := range dependencies(id, p) {
|
||||
id := fmt.Sprintf("%s_%s", id, did)
|
||||
if !yield(id, d) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if !yield(id, p) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if s.Items != nil {
|
||||
id := fmt.Sprintf("%s_tuple_%d", id, len(s.PrefixItems))
|
||||
for did, d := range dependencies(id, s.Items) {
|
||||
if !yield(did, d) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if !yield(id, s.Items) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type builder struct {
|
||||
b bytes.Buffer
|
||||
pad int
|
||||
rules int
|
||||
items int
|
||||
}
|
||||
|
||||
// define terminates the current rule, if any, and then either starts a new
|
||||
// rule or does nothing else if the name is empty.
|
||||
func (b *builder) define(name string) {
|
||||
if b.rules > 0 {
|
||||
b.b.WriteString(";\n")
|
||||
}
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(&b.b, "% -*s", b.pad, name)
|
||||
b.b.WriteString(" ::=")
|
||||
b.rules++
|
||||
b.items = 0
|
||||
}
|
||||
|
||||
// quote appends a terminal to the current rule.
|
||||
func (b *builder) q(s string) {
|
||||
if b.items > 0 {
|
||||
b.b.WriteString(" ")
|
||||
}
|
||||
b.b.WriteString(" ")
|
||||
b.b.WriteString(strconv.Quote(s))
|
||||
}
|
||||
|
||||
// u appends a non-terminal to the current rule.
|
||||
func (b *builder) u(s string) {
|
||||
if b.items > 0 {
|
||||
b.b.WriteString(" ")
|
||||
}
|
||||
b.b.WriteString(" ")
|
||||
b.b.WriteString(s)
|
||||
}
|
||||
|
||||
func buildConstrainedNumber(b *builder, s *jsonschema.Schema) {
|
||||
if s.Minimum == 0 && s.Maximum == 0 {
|
||||
b.u("TODO")
|
||||
} else {
|
||||
b.u("number")
|
||||
}
|
||||
}
|
75
grammar/grammar_test.go
Normal file
75
grammar/grammar_test.go
Normal file
@ -0,0 +1,75 @@
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"cmp"
|
||||
"iter"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
_ "embed"
|
||||
|
||||
"github.com/ollama/ollama/grammar/internal/diff"
|
||||
)
|
||||
|
||||
func TestFromSchema(t *testing.T) {
|
||||
for tt := range testCases(t) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
g, err := FromSchema(nil, []byte(tt.schema))
|
||||
if err != nil {
|
||||
t.Fatalf("FromSchema: %v", err)
|
||||
}
|
||||
got := string(g)
|
||||
got = strings.TrimPrefix(got, jsonTerms)
|
||||
if got != tt.want {
|
||||
t.Logf("schema:\n%s", tt.schema)
|
||||
t.Fatal(string(diff.Diff("got", []byte(got), "want", []byte(tt.want))))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
schema string
|
||||
want string
|
||||
}
|
||||
|
||||
//go:embed testdata/schemas.txt
|
||||
var tests string
|
||||
|
||||
func testCases(t testing.TB) iter.Seq[testCase] {
|
||||
t.Helper()
|
||||
return func(yield func(testCase) bool) {
|
||||
t.Helper()
|
||||
sc := bufio.NewScanner(strings.NewReader(tests))
|
||||
name := ""
|
||||
for sc.Scan() {
|
||||
line := strings.TrimSpace(sc.Text())
|
||||
if line == "" {
|
||||
name = ""
|
||||
continue
|
||||
}
|
||||
if line[0] == '#' {
|
||||
name = cmp.Or(name, strings.TrimSpace(line[1:]))
|
||||
continue
|
||||
}
|
||||
s := sc.Text()
|
||||
g := ""
|
||||
for sc.Scan() {
|
||||
line = strings.TrimSpace(sc.Text())
|
||||
if line == "" || line[0] == '#' {
|
||||
break
|
||||
}
|
||||
g += sc.Text() + "\n"
|
||||
}
|
||||
if !yield(testCase{name, s, g}) {
|
||||
return
|
||||
}
|
||||
name = strings.TrimSpace(strings.TrimPrefix(line, "#"))
|
||||
}
|
||||
if err := sc.Err(); err != nil {
|
||||
t.Fatalf("error reading tests: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
261
grammar/internal/diff/diff.go
Normal file
261
grammar/internal/diff/diff.go
Normal file
@ -0,0 +1,261 @@
|
||||
// Copyright 2022 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package diff
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A pair is a pair of values tracked for both the x and y side of a diff.
|
||||
// It is typically a pair of line indexes.
|
||||
type pair struct{ x, y int }
|
||||
|
||||
// Diff returns an anchored diff of the two texts old and new
|
||||
// in the “unified diff” format. If old and new are identical,
|
||||
// Diff returns a nil slice (no output).
|
||||
//
|
||||
// Unix diff implementations typically look for a diff with
|
||||
// the smallest number of lines inserted and removed,
|
||||
// which can in the worst case take time quadratic in the
|
||||
// number of lines in the texts. As a result, many implementations
|
||||
// either can be made to run for a long time or cut off the search
|
||||
// after a predetermined amount of work.
|
||||
//
|
||||
// In contrast, this implementation looks for a diff with the
|
||||
// smallest number of “unique” lines inserted and removed,
|
||||
// where unique means a line that appears just once in both old and new.
|
||||
// We call this an “anchored diff” because the unique lines anchor
|
||||
// the chosen matching regions. An anchored diff is usually clearer
|
||||
// than a standard diff, because the algorithm does not try to
|
||||
// reuse unrelated blank lines or closing braces.
|
||||
// The algorithm also guarantees to run in O(n log n) time
|
||||
// instead of the standard O(n²) time.
|
||||
//
|
||||
// Some systems call this approach a “patience diff,” named for
|
||||
// the “patience sorting” algorithm, itself named for a solitaire card game.
|
||||
// We avoid that name for two reasons. First, the name has been used
|
||||
// for a few different variants of the algorithm, so it is imprecise.
|
||||
// Second, the name is frequently interpreted as meaning that you have
|
||||
// to wait longer (to be patient) for the diff, meaning that it is a slower algorithm,
|
||||
// when in fact the algorithm is faster than the standard one.
|
||||
func Diff(oldName string, old []byte, newName string, new []byte) []byte {
|
||||
if bytes.Equal(old, new) {
|
||||
return nil
|
||||
}
|
||||
x := lines(old)
|
||||
y := lines(new)
|
||||
|
||||
// Print diff header.
|
||||
var out bytes.Buffer
|
||||
fmt.Fprintf(&out, "diff %s %s\n", oldName, newName)
|
||||
fmt.Fprintf(&out, "--- %s\n", oldName)
|
||||
fmt.Fprintf(&out, "+++ %s\n", newName)
|
||||
|
||||
// Loop over matches to consider,
|
||||
// expanding each match to include surrounding lines,
|
||||
// and then printing diff chunks.
|
||||
// To avoid setup/teardown cases outside the loop,
|
||||
// tgs returns a leading {0,0} and trailing {len(x), len(y)} pair
|
||||
// in the sequence of matches.
|
||||
var (
|
||||
done pair // printed up to x[:done.x] and y[:done.y]
|
||||
chunk pair // start lines of current chunk
|
||||
count pair // number of lines from each side in current chunk
|
||||
ctext []string // lines for current chunk
|
||||
)
|
||||
for _, m := range tgs(x, y) {
|
||||
if m.x < done.x {
|
||||
// Already handled scanning forward from earlier match.
|
||||
continue
|
||||
}
|
||||
|
||||
// Expand matching lines as far as possible,
|
||||
// establishing that x[start.x:end.x] == y[start.y:end.y].
|
||||
// Note that on the first (or last) iteration we may (or definitely do)
|
||||
// have an empty match: start.x==end.x and start.y==end.y.
|
||||
start := m
|
||||
for start.x > done.x && start.y > done.y && x[start.x-1] == y[start.y-1] {
|
||||
start.x--
|
||||
start.y--
|
||||
}
|
||||
end := m
|
||||
for end.x < len(x) && end.y < len(y) && x[end.x] == y[end.y] {
|
||||
end.x++
|
||||
end.y++
|
||||
}
|
||||
|
||||
// Emit the mismatched lines before start into this chunk.
|
||||
// (No effect on first sentinel iteration, when start = {0,0}.)
|
||||
for _, s := range x[done.x:start.x] {
|
||||
ctext = append(ctext, "-"+s)
|
||||
count.x++
|
||||
}
|
||||
for _, s := range y[done.y:start.y] {
|
||||
ctext = append(ctext, "+"+s)
|
||||
count.y++
|
||||
}
|
||||
|
||||
// If we're not at EOF and have too few common lines,
|
||||
// the chunk includes all the common lines and continues.
|
||||
const C = 3 // number of context lines
|
||||
if (end.x < len(x) || end.y < len(y)) &&
|
||||
(end.x-start.x < C || (len(ctext) > 0 && end.x-start.x < 2*C)) {
|
||||
for _, s := range x[start.x:end.x] {
|
||||
ctext = append(ctext, " "+s)
|
||||
count.x++
|
||||
count.y++
|
||||
}
|
||||
done = end
|
||||
continue
|
||||
}
|
||||
|
||||
// End chunk with common lines for context.
|
||||
if len(ctext) > 0 {
|
||||
n := end.x - start.x
|
||||
if n > C {
|
||||
n = C
|
||||
}
|
||||
for _, s := range x[start.x : start.x+n] {
|
||||
ctext = append(ctext, " "+s)
|
||||
count.x++
|
||||
count.y++
|
||||
}
|
||||
done = pair{start.x + n, start.y + n}
|
||||
|
||||
// Format and emit chunk.
|
||||
// Convert line numbers to 1-indexed.
|
||||
// Special case: empty file shows up as 0,0 not 1,0.
|
||||
if count.x > 0 {
|
||||
chunk.x++
|
||||
}
|
||||
if count.y > 0 {
|
||||
chunk.y++
|
||||
}
|
||||
fmt.Fprintf(&out, "@@ -%d,%d +%d,%d @@\n", chunk.x, count.x, chunk.y, count.y)
|
||||
for _, s := range ctext {
|
||||
out.WriteString(s)
|
||||
}
|
||||
count.x = 0
|
||||
count.y = 0
|
||||
ctext = ctext[:0]
|
||||
}
|
||||
|
||||
// If we reached EOF, we're done.
|
||||
if end.x >= len(x) && end.y >= len(y) {
|
||||
break
|
||||
}
|
||||
|
||||
// Otherwise start a new chunk.
|
||||
chunk = pair{end.x - C, end.y - C}
|
||||
for _, s := range x[chunk.x:end.x] {
|
||||
ctext = append(ctext, " "+s)
|
||||
count.x++
|
||||
count.y++
|
||||
}
|
||||
done = end
|
||||
}
|
||||
|
||||
return out.Bytes()
|
||||
}
|
||||
|
||||
// lines returns the lines in the file x, including newlines.
|
||||
// If the file does not end in a newline, one is supplied
|
||||
// along with a warning about the missing newline.
|
||||
func lines(x []byte) []string {
|
||||
l := strings.SplitAfter(string(x), "\n")
|
||||
if l[len(l)-1] == "" {
|
||||
l = l[:len(l)-1]
|
||||
} else {
|
||||
// Treat last line as having a message about the missing newline attached,
|
||||
// using the same text as BSD/GNU diff (including the leading backslash).
|
||||
l[len(l)-1] += "\n\\ No newline at end of file\n"
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// tgs returns the pairs of indexes of the longest common subsequence
|
||||
// of unique lines in x and y, where a unique line is one that appears
|
||||
// once in x and once in y.
|
||||
//
|
||||
// The longest common subsequence algorithm is as described in
|
||||
// Thomas G. Szymanski, “A Special Case of the Maximal Common
|
||||
// Subsequence Problem,” Princeton TR #170 (January 1975),
|
||||
// available at https://research.swtch.com/tgs170.pdf.
|
||||
func tgs(x, y []string) []pair {
|
||||
// Count the number of times each string appears in a and b.
|
||||
// We only care about 0, 1, many, counted as 0, -1, -2
|
||||
// for the x side and 0, -4, -8 for the y side.
|
||||
// Using negative numbers now lets us distinguish positive line numbers later.
|
||||
m := make(map[string]int)
|
||||
for _, s := range x {
|
||||
if c := m[s]; c > -2 {
|
||||
m[s] = c - 1
|
||||
}
|
||||
}
|
||||
for _, s := range y {
|
||||
if c := m[s]; c > -8 {
|
||||
m[s] = c - 4
|
||||
}
|
||||
}
|
||||
|
||||
// Now unique strings can be identified by m[s] = -1+-4.
|
||||
//
|
||||
// Gather the indexes of those strings in x and y, building:
|
||||
// xi[i] = increasing indexes of unique strings in x.
|
||||
// yi[i] = increasing indexes of unique strings in y.
|
||||
// inv[i] = index j such that x[xi[i]] = y[yi[j]].
|
||||
var xi, yi, inv []int
|
||||
for i, s := range y {
|
||||
if m[s] == -1+-4 {
|
||||
m[s] = len(yi)
|
||||
yi = append(yi, i)
|
||||
}
|
||||
}
|
||||
for i, s := range x {
|
||||
if j, ok := m[s]; ok && j >= 0 {
|
||||
xi = append(xi, i)
|
||||
inv = append(inv, j)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply Algorithm A from Szymanski's paper.
|
||||
// In those terms, A = J = inv and B = [0, n).
|
||||
// We add sentinel pairs {0,0}, and {len(x),len(y)}
|
||||
// to the returned sequence, to help the processing loop.
|
||||
J := inv
|
||||
n := len(xi)
|
||||
T := make([]int, n)
|
||||
L := make([]int, n)
|
||||
for i := range T {
|
||||
T[i] = n + 1
|
||||
}
|
||||
for i := range n {
|
||||
k := sort.Search(n, func(k int) bool {
|
||||
return T[k] >= J[i]
|
||||
})
|
||||
T[k] = J[i]
|
||||
L[i] = k + 1
|
||||
}
|
||||
k := 0
|
||||
for _, v := range L {
|
||||
if k < v {
|
||||
k = v
|
||||
}
|
||||
}
|
||||
seq := make([]pair, 2+k)
|
||||
seq[1+k] = pair{len(x), len(y)} // sentinel at end
|
||||
lastj := n
|
||||
for i := n - 1; i >= 0; i-- {
|
||||
if L[i] == k && J[i] < lastj {
|
||||
seq[k] = pair{xi[i], yi[J[i]]}
|
||||
k--
|
||||
}
|
||||
}
|
||||
seq[0] = pair{0, 0} // sentinel at start
|
||||
return seq
|
||||
}
|
44
grammar/internal/diff/diff_test.go
Normal file
44
grammar/internal/diff/diff_test.go
Normal file
@ -0,0 +1,44 @@
|
||||
// Copyright 2022 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package diff
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/tools/txtar"
|
||||
)
|
||||
|
||||
func clean(text []byte) []byte {
|
||||
text = bytes.ReplaceAll(text, []byte("$\n"), []byte("\n"))
|
||||
text = bytes.TrimSuffix(text, []byte("^D\n"))
|
||||
return text
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
files, _ := filepath.Glob("testdata/*.txt")
|
||||
if len(files) == 0 {
|
||||
t.Fatalf("no testdata")
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
t.Run(filepath.Base(file), func(t *testing.T) {
|
||||
a, err := txtar.ParseFile(file)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(a.Files) != 3 || a.Files[2].Name != "diff" {
|
||||
t.Fatalf("%s: want three files, third named \"diff\"", file)
|
||||
}
|
||||
diffs := Diff(a.Files[0].Name, clean(a.Files[0].Data), a.Files[1].Name, clean(a.Files[1].Data))
|
||||
want := clean(a.Files[2].Data)
|
||||
if !bytes.Equal(diffs, want) {
|
||||
t.Fatalf("%s: have:\n%s\nwant:\n%s\n%s", file,
|
||||
diffs, want, Diff("have", diffs, "want", want))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
13
grammar/internal/diff/testdata/allnew.txt
vendored
Normal file
13
grammar/internal/diff/testdata/allnew.txt
vendored
Normal file
@ -0,0 +1,13 @@
|
||||
-- old --
|
||||
-- new --
|
||||
a
|
||||
b
|
||||
c
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -0,0 +1,3 @@
|
||||
+a
|
||||
+b
|
||||
+c
|
13
grammar/internal/diff/testdata/allold.txt
vendored
Normal file
13
grammar/internal/diff/testdata/allold.txt
vendored
Normal file
@ -0,0 +1,13 @@
|
||||
-- old --
|
||||
a
|
||||
b
|
||||
c
|
||||
-- new --
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -1,3 +0,0 @@
|
||||
-a
|
||||
-b
|
||||
-c
|
35
grammar/internal/diff/testdata/basic.txt
vendored
Normal file
35
grammar/internal/diff/testdata/basic.txt
vendored
Normal file
@ -0,0 +1,35 @@
|
||||
Example from Hunt and McIlroy, “An Algorithm for Differential File Comparison.”
|
||||
https://www.cs.dartmouth.edu/~doug/diff.pdf
|
||||
|
||||
-- old --
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
-- new --
|
||||
w
|
||||
a
|
||||
b
|
||||
x
|
||||
y
|
||||
z
|
||||
e
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -1,7 +1,7 @@
|
||||
+w
|
||||
a
|
||||
b
|
||||
-c
|
||||
-d
|
||||
+x
|
||||
+y
|
||||
+z
|
||||
e
|
||||
-f
|
||||
-g
|
40
grammar/internal/diff/testdata/dups.txt
vendored
Normal file
40
grammar/internal/diff/testdata/dups.txt
vendored
Normal file
@ -0,0 +1,40 @@
|
||||
-- old --
|
||||
a
|
||||
|
||||
b
|
||||
|
||||
c
|
||||
|
||||
d
|
||||
|
||||
e
|
||||
|
||||
f
|
||||
-- new --
|
||||
a
|
||||
|
||||
B
|
||||
|
||||
C
|
||||
|
||||
d
|
||||
|
||||
e
|
||||
|
||||
f
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -1,8 +1,8 @@
|
||||
a
|
||||
$
|
||||
-b
|
||||
-
|
||||
-c
|
||||
+B
|
||||
+
|
||||
+C
|
||||
$
|
||||
d
|
||||
$
|
38
grammar/internal/diff/testdata/end.txt
vendored
Normal file
38
grammar/internal/diff/testdata/end.txt
vendored
Normal file
@ -0,0 +1,38 @@
|
||||
-- old --
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
eight
|
||||
nine
|
||||
ten
|
||||
eleven
|
||||
-- new --
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
10
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -5,7 +5,6 @@
|
||||
5
|
||||
6
|
||||
7
|
||||
-eight
|
||||
-nine
|
||||
-ten
|
||||
-eleven
|
||||
+8
|
||||
+9
|
||||
+10
|
9
grammar/internal/diff/testdata/eof.txt
vendored
Normal file
9
grammar/internal/diff/testdata/eof.txt
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
-- old --
|
||||
a
|
||||
b
|
||||
c^D
|
||||
-- new --
|
||||
a
|
||||
b
|
||||
c^D
|
||||
-- diff --
|
18
grammar/internal/diff/testdata/eof1.txt
vendored
Normal file
18
grammar/internal/diff/testdata/eof1.txt
vendored
Normal file
@ -0,0 +1,18 @@
|
||||
-- old --
|
||||
a
|
||||
b
|
||||
c
|
||||
-- new --
|
||||
a
|
||||
b
|
||||
c^D
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -1,3 +1,3 @@
|
||||
a
|
||||
b
|
||||
-c
|
||||
+c
|
||||
\ No newline at end of file
|
18
grammar/internal/diff/testdata/eof2.txt
vendored
Normal file
18
grammar/internal/diff/testdata/eof2.txt
vendored
Normal file
@ -0,0 +1,18 @@
|
||||
-- old --
|
||||
a
|
||||
b
|
||||
c^D
|
||||
-- new --
|
||||
a
|
||||
b
|
||||
c
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -1,3 +1,3 @@
|
||||
a
|
||||
b
|
||||
-c
|
||||
\ No newline at end of file
|
||||
+c
|
62
grammar/internal/diff/testdata/long.txt
vendored
Normal file
62
grammar/internal/diff/testdata/long.txt
vendored
Normal file
@ -0,0 +1,62 @@
|
||||
-- old --
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
10
|
||||
11
|
||||
12
|
||||
13
|
||||
14
|
||||
14½
|
||||
15
|
||||
16
|
||||
17
|
||||
18
|
||||
19
|
||||
20
|
||||
-- new --
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
8
|
||||
9
|
||||
10
|
||||
11
|
||||
12
|
||||
13
|
||||
14
|
||||
17
|
||||
18
|
||||
19
|
||||
20
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -4,7 +4,6 @@
|
||||
4
|
||||
5
|
||||
6
|
||||
-7
|
||||
8
|
||||
9
|
||||
10
|
||||
@@ -12,9 +11,6 @@
|
||||
12
|
||||
13
|
||||
14
|
||||
-14½
|
||||
-15
|
||||
-16
|
||||
17
|
||||
18
|
||||
19
|
5
grammar/internal/diff/testdata/same.txt
vendored
Normal file
5
grammar/internal/diff/testdata/same.txt
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
-- old --
|
||||
hello world
|
||||
-- new --
|
||||
hello world
|
||||
-- diff --
|
34
grammar/internal/diff/testdata/start.txt
vendored
Normal file
34
grammar/internal/diff/testdata/start.txt
vendored
Normal file
@ -0,0 +1,34 @@
|
||||
-- old --
|
||||
e
|
||||
pi
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
10
|
||||
-- new --
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
10
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -1,5 +1,6 @@
|
||||
-e
|
||||
-pi
|
||||
+1
|
||||
+2
|
||||
+3
|
||||
4
|
||||
5
|
||||
6
|
40
grammar/internal/diff/testdata/triv.txt
vendored
Normal file
40
grammar/internal/diff/testdata/triv.txt
vendored
Normal file
@ -0,0 +1,40 @@
|
||||
Another example from Hunt and McIlroy,
|
||||
“An Algorithm for Differential File Comparison.”
|
||||
https://www.cs.dartmouth.edu/~doug/diff.pdf
|
||||
|
||||
Anchored diff gives up on finding anything,
|
||||
since there are no unique lines.
|
||||
|
||||
-- old --
|
||||
a
|
||||
b
|
||||
c
|
||||
a
|
||||
b
|
||||
b
|
||||
a
|
||||
-- new --
|
||||
c
|
||||
a
|
||||
b
|
||||
a
|
||||
b
|
||||
c
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -1,7 +1,6 @@
|
||||
-a
|
||||
-b
|
||||
-c
|
||||
-a
|
||||
-b
|
||||
-b
|
||||
-a
|
||||
+c
|
||||
+a
|
||||
+b
|
||||
+a
|
||||
+b
|
||||
+c
|
171
grammar/jsonschema/decode.go
Normal file
171
grammar/jsonschema/decode.go
Normal file
@ -0,0 +1,171 @@
|
||||
package jsonschema
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Schema holds a JSON schema.
|
||||
type Schema struct {
|
||||
// Name is the name of the property. For the parent/root property, this
|
||||
// is "root". For child properties, this is the name of the property.
|
||||
Name string `json:"-"`
|
||||
|
||||
// Type is the type of the property.
|
||||
//
|
||||
// TODO: Union types (e.g. make this a []string).
|
||||
Type string
|
||||
|
||||
// PrefixItems is a list of schemas for each item in a tuple. By
|
||||
// default, the tuple is "closed." unless Items is set to true or a
|
||||
// valid Schema.
|
||||
PrefixItems []*Schema
|
||||
|
||||
// Items is the schema for each item in a list.
|
||||
//
|
||||
// If it is missing, or its JSON value is "null" or "false", it is nil.
|
||||
// If the JSON value is "true", it is set to the empty Schema. If the
|
||||
// JSON value is an object, it will be decoded as a Schema.
|
||||
Items *Schema
|
||||
|
||||
// MinItems specifies the minimum number of items allowed in a list.
|
||||
MinItems int
|
||||
|
||||
// MaxItems specifies the maximum number of items allowed in a list.
|
||||
MaxItems int
|
||||
|
||||
// Properties is the schema for each property of an object.
|
||||
Properties []*Schema
|
||||
|
||||
// Format is the format of the property. This is used to validate the
|
||||
// property against a specific format.
|
||||
//
|
||||
// It is the callers responsibility to validate the property against
|
||||
// the format.
|
||||
Format string
|
||||
|
||||
// Minimum specifies the minimum value for numeric properties.
|
||||
Minimum float64
|
||||
|
||||
// Maximum specifies the maximum value for numeric properties.
|
||||
Maximum float64
|
||||
|
||||
// Enum is a list of valid values for the property.
|
||||
Enum []json.RawMessage
|
||||
}
|
||||
|
||||
func (s *Schema) UnmarshalJSON(data []byte) error {
|
||||
type S Schema
|
||||
w := struct {
|
||||
Properties props
|
||||
Items items
|
||||
*S
|
||||
}{
|
||||
S: (*S)(s),
|
||||
}
|
||||
if err := json.Unmarshal(data, &w); err != nil {
|
||||
return err
|
||||
}
|
||||
if w.Items.set {
|
||||
s.Items = &w.Items.Schema
|
||||
}
|
||||
s.Properties = w.Properties
|
||||
return nil
|
||||
}
|
||||
|
||||
type items struct {
|
||||
Schema
|
||||
set bool
|
||||
}
|
||||
|
||||
func (s *items) UnmarshalJSON(data []byte) error {
|
||||
switch b := data[0]; b {
|
||||
case 't':
|
||||
*s = items{set: true}
|
||||
case '{':
|
||||
type I items
|
||||
if err := json.Unmarshal(data, (*I)(s)); err != nil {
|
||||
return err
|
||||
}
|
||||
s.set = true
|
||||
case 'n', 'f':
|
||||
default:
|
||||
return errors.New("invalid Items")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EffectiveType returns the effective type of the schema. If the Type field is
|
||||
// not empty, it is returned; otherwise:
|
||||
//
|
||||
// - If the schema has both Properties and Items, it returns an empty string.
|
||||
// - If the schema has Properties, it returns "object".
|
||||
// - If the schema has Items, it returns "array".
|
||||
// - If the schema has neither Properties nor Items, it returns "value".
|
||||
//
|
||||
// The returned string is never empty.
|
||||
func (d *Schema) EffectiveType() string {
|
||||
if d.Type == "" {
|
||||
if len(d.Properties) > 0 {
|
||||
return "object"
|
||||
}
|
||||
if len(d.PrefixItems) > 0 || d.Items != nil {
|
||||
return "array"
|
||||
}
|
||||
return "value"
|
||||
}
|
||||
return d.Type
|
||||
}
|
||||
|
||||
// props is an ordered list of properties. The order of the properties
|
||||
// is the order in which they were defined in the schema.
|
||||
type props []*Schema
|
||||
|
||||
var _ json.Unmarshaler = (*props)(nil)
|
||||
|
||||
func (v *props) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
if data[0] != '{' {
|
||||
return errors.New("expected object")
|
||||
}
|
||||
|
||||
d := json.NewDecoder(bytes.NewReader(data))
|
||||
|
||||
// TODO(bmizerany): Consider DisallowUnknownFields. Currently, we, like
|
||||
// llama.cpp, ignore unknown fields, which could be lead to unexpected
|
||||
// behavior for clients of this package, since they may not be aware
|
||||
// that "additionalFields", "itemsPrefix", etc, are being ignored.
|
||||
//
|
||||
// For now, just do what llama.cpp does.
|
||||
|
||||
t, err := d.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t != json.Delim('{') {
|
||||
return errors.New("expected object")
|
||||
}
|
||||
for d.More() {
|
||||
// Use the first token (map key) as the property name, then
|
||||
// decode the rest of the object fields into a Schema and
|
||||
// append.
|
||||
t, err := d.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t == json.Delim('}') {
|
||||
return nil
|
||||
}
|
||||
s := &Schema{
|
||||
Name: t.(string),
|
||||
}
|
||||
if err := d.Decode(s); err != nil {
|
||||
return err
|
||||
}
|
||||
*v = append(*v, s)
|
||||
}
|
||||
return nil
|
||||
}
|
104
grammar/jsonschema/decode_test.go
Normal file
104
grammar/jsonschema/decode_test.go
Normal file
@ -0,0 +1,104 @@
|
||||
package jsonschema
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
const testSchemaBasic = `
|
||||
{
|
||||
"properties": {
|
||||
"tupleClosedEmpty": { "prefixItems": [] },
|
||||
"tupleClosedMissing": { "prefixItems": [{}] },
|
||||
"tupleClosedNull": { "prefixItems": [{}], "items": null },
|
||||
"tupleClosedFalse": { "prefixItems": [{}], "items": false },
|
||||
"tupleOpenTrue": { "prefixItems": [{}], "items": true },
|
||||
"tupleOpenEmpty": { "prefixItems": [{}], "items": {} },
|
||||
"tupleOpenTyped": { "prefixItems": [{}], "items": {"type": "boolean"} },
|
||||
"tupleOpenMax": { "prefixItems": [{}], "items": true, "maxItems": 3},
|
||||
|
||||
"array": { "items": {"type": "number"} },
|
||||
|
||||
"null": { "type": "null" },
|
||||
"string": { "type": "string" },
|
||||
"boolean": { "type": "boolean" }
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
func TestSchemaUnmarshal(t *testing.T) {
|
||||
var got *Schema
|
||||
if err := json.Unmarshal([]byte(testSchemaBasic), &got); err != nil {
|
||||
t.Fatalf("Unmarshal: %v", err)
|
||||
}
|
||||
want := &Schema{
|
||||
Properties: []*Schema{
|
||||
{Name: "tupleClosedEmpty", PrefixItems: []*Schema{}, Items: nil},
|
||||
{Name: "tupleClosedMissing", PrefixItems: []*Schema{{}}, Items: nil},
|
||||
{Name: "tupleClosedNull", PrefixItems: []*Schema{{}}, Items: nil},
|
||||
{Name: "tupleClosedFalse", PrefixItems: []*Schema{{}}, Items: nil},
|
||||
|
||||
{Name: "tupleOpenTrue", PrefixItems: []*Schema{{}}, Items: &Schema{}},
|
||||
{Name: "tupleOpenEmpty", PrefixItems: []*Schema{{}}, Items: &Schema{}},
|
||||
{Name: "tupleOpenTyped", PrefixItems: []*Schema{{}}, Items: &Schema{Type: "boolean"}},
|
||||
{Name: "tupleOpenMax", PrefixItems: []*Schema{{}}, Items: &Schema{}, MaxItems: 3},
|
||||
|
||||
{Name: "array", Items: &Schema{Type: "number"}},
|
||||
|
||||
{Name: "null", Type: "null"},
|
||||
{Name: "string", Type: "string"},
|
||||
{Name: "boolean", Type: "boolean"},
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("(-want, +got)\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveType(t *testing.T) {
|
||||
const schema = `
|
||||
{"properties": {
|
||||
"o": {"type": "object"},
|
||||
"a": {"type": "array"},
|
||||
"n": {"type": "number"},
|
||||
"s": {"type": "string"},
|
||||
"z": {"type": "null"},
|
||||
"b": {"type": "boolean"},
|
||||
|
||||
"t0": {"prefixItems": [{}], "items": {"type": "number"}},
|
||||
"t1": {"items": {"type": "number"}, "maxItems": 3},
|
||||
|
||||
"v": {"maxItems": 3}
|
||||
}}
|
||||
`
|
||||
|
||||
var s *Schema
|
||||
if err := json.Unmarshal([]byte(schema), &s); err != nil {
|
||||
t.Fatalf("json.Unmarshal: %v", err)
|
||||
}
|
||||
|
||||
var got []string
|
||||
for _, p := range s.Properties {
|
||||
got = append(got, p.EffectiveType())
|
||||
}
|
||||
|
||||
want := strings.Fields(`
|
||||
object
|
||||
array
|
||||
number
|
||||
string
|
||||
null
|
||||
boolean
|
||||
array
|
||||
array
|
||||
value
|
||||
`)
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("\ngot:\n\t%v\nwant:\n\t%v", got, want)
|
||||
}
|
||||
}
|
76
grammar/testdata/schemas.txt
vendored
Normal file
76
grammar/testdata/schemas.txt
vendored
Normal file
@ -0,0 +1,76 @@
|
||||
# This file holds tests for JSON schema to EBNF grammar conversions.
|
||||
#
|
||||
# The format is a JSON schema, followed by the expected EBNF grammar. Each test
|
||||
# MAY be preceded by a comment that describes the test (e.g. the test name), followed by
|
||||
# the JSON schema and the expected EBNF grammar. If no comment is present, the test
|
||||
# name the tests number in the file (e.g. "#0", "#1", etc.)
|
||||
#
|
||||
# Blank lines signify the end or start of a new test. Comments can be added
|
||||
# anywhere in the file, but they must be preceded by a '#' character and start at
|
||||
# the beginning of the line.
|
||||
|
||||
# default
|
||||
{}
|
||||
root ::= value;
|
||||
|
||||
{"properties": {}}
|
||||
root ::= value;
|
||||
|
||||
# array
|
||||
{"properties": {"a": {"type": "array", "items": {"type": "string"}}}}
|
||||
root_0_tuple_0 ::= string;
|
||||
root_0 ::= "[" ( root_0_tuple_0 )* "]";
|
||||
root ::= "{" "a" ":" root_0 "}";
|
||||
|
||||
# array with nested array
|
||||
{"type": "array", "items": {"type": "array", "items": {"type": "string"}}}
|
||||
root_tuple_0_tuple_0 ::= string;
|
||||
root_tuple_0 ::= "[" ( root_tuple_0_tuple_0 )* "]";
|
||||
root ::= "[" ( root_tuple_0 )* "]";
|
||||
|
||||
# object
|
||||
{"properties": {"e": {}}}
|
||||
root_0 ::= value;
|
||||
root ::= "{" "e" ":" root_0 "}";
|
||||
|
||||
# object with nested object
|
||||
{"properties": {"o": {"type": "object", "properties": {"e": {}}}}}
|
||||
root_0_0 ::= value;
|
||||
root_0 ::= "{" "e" ":" root_0_0 "}";
|
||||
root ::= "{" "o" ":" root_0 "}";
|
||||
|
||||
# boolean
|
||||
{"type": "boolean"}
|
||||
root ::= boolean;
|
||||
|
||||
# number
|
||||
{"properties": {"n": {"type": "number", "minimum": 123, "maximum": 4567}}}
|
||||
root_0 ::= number;
|
||||
root ::= "{" "n" ":" root_0 "}";
|
||||
|
||||
# string
|
||||
{"type": "string"}
|
||||
root ::= string;
|
||||
|
||||
# string with enum
|
||||
{"type": "string", "enum": ["a", "b", "c"]}
|
||||
root ::= ( "\"a\"" "|" "\"b\"" "|" "\"c\"" );
|
||||
|
||||
# spaces in key
|
||||
{"properties": {"a b": {}}}
|
||||
root_0 ::= value;
|
||||
root ::= "{" "a b" ":" root_0 "}";
|
||||
|
||||
# issue7978
|
||||
{ "type": "object", "properties": { "steps": { "type": "array", "items": { "type": "object", "properties": { "explanation": { "type": "string" }, "output": { "type": "string" } }, "required": [ "explanation", "output" ], "additionalProperties": false } }, "final_answer": { "type": "string" } }, "required": [ "steps", "final_answer" ], "additionalProperties": false }
|
||||
root_0_tuple_0_0 ::= string;
|
||||
root_0_tuple_0_1 ::= string;
|
||||
root_0_tuple_0 ::= "{" "explanation" ":" root_0_tuple_0_0 "," "output" ":" root_0_tuple_0_1 "}";
|
||||
root_0 ::= "[" ( root_0_tuple_0 )* "]";
|
||||
root_1 ::= string;
|
||||
root ::= "{" "steps" ":" root_0 "," "final_answer" ":" root_1 "}";
|
||||
|
||||
# !! # special characters in key
|
||||
# !! {"properties": {"a!b": {}}}
|
||||
# !! !invalid character '!' in key
|
||||
# !!
|
@ -29,6 +29,7 @@ import (
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/grammar"
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
@ -700,9 +701,9 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
}
|
||||
|
||||
// User provided a JSON schema
|
||||
g := llama.SchemaToGrammar(req.Format)
|
||||
if g == nil {
|
||||
return fmt.Errorf("invalid JSON schema in format")
|
||||
g, err := grammar.FromSchema(nil, req.Format)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid JSON schema in format: %w", err)
|
||||
}
|
||||
req.Grammar = string(g)
|
||||
}
|
||||
@ -713,6 +714,11 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
req.Options = &opts
|
||||
}
|
||||
|
||||
if req.Options == nil {
|
||||
opts := api.DefaultOptions()
|
||||
req.Options = &opts
|
||||
}
|
||||
|
||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting completion request due to client closing the connection")
|
||||
@ -727,7 +733,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
|
||||
req.Options.NumPredict = 10 * s.options.NumCtx
|
||||
}
|
||||
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatusRetry(ctx)
|
||||
if err != nil {
|
||||
|
@ -32,6 +32,7 @@ type TextProcessor interface {
|
||||
Encode(s string, addSpecial bool) ([]int32, error)
|
||||
Decode([]int32) (string, error)
|
||||
Is(int32, Special) bool
|
||||
Vocab() *Vocabulary
|
||||
}
|
||||
|
||||
type Vocabulary struct {
|
||||
|
@ -53,6 +53,10 @@ func (spm SentencePieceModel) Is(id int32, special Special) bool {
|
||||
return spm.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (spm SentencePieceModel) Vocab() *Vocabulary {
|
||||
return spm.vocab
|
||||
}
|
||||
|
||||
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
|
||||
|
@ -468,6 +468,20 @@ func (s *Server) processBatch() error {
|
||||
return fmt.Errorf("failed to sample token: %w", err)
|
||||
}
|
||||
|
||||
if seq.sampler.JSONSampler != nil {
|
||||
_, err = seq.sampler.JSONSampler.UpdateState([]int32{token})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update state: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if seq.sampler.PythonSampler != nil {
|
||||
err = seq.sampler.PythonSampler.UpdateState(token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update state: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// if it's an end of sequence token, break
|
||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||
// TODO (jmorganca): we should send this back
|
||||
@ -562,6 +576,21 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// jsonSampler, err := sample.NewJSONSampler(s.model.(model.TextProcessor), nil)
|
||||
// if err != nil {
|
||||
// http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
|
||||
// return
|
||||
// }
|
||||
// jsonSampler = nil
|
||||
pythonSampler := &sample.PythonSampler{}
|
||||
functions := []sample.PythonFunction{
|
||||
{
|
||||
Name: "add_two_strings",
|
||||
Args: []string{"s1", "s2"},
|
||||
Types: []string{"string", "string"},
|
||||
},
|
||||
}
|
||||
pythonSampler.Init(functions, s.model.(model.TextProcessor))
|
||||
sampler := sample.NewSampler(
|
||||
req.Options.Temperature,
|
||||
req.Options.TopK,
|
||||
@ -569,6 +598,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
req.Options.MinP,
|
||||
req.Options.Seed,
|
||||
grammar,
|
||||
nil,
|
||||
pythonSampler,
|
||||
// nil,
|
||||
)
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
|
53
sample/gtf.go
Normal file
53
sample/gtf.go
Normal file
@ -0,0 +1,53 @@
|
||||
package sample
|
||||
|
||||
var DefaultGrammar = map[string]string{
|
||||
"unicode": `\x{hex}{2} | \u{hex}{4} | \U{hex}{8}`,
|
||||
"null": `"null"`,
|
||||
"object": `"{" (kv ("," kv)*)? "}"`,
|
||||
"array": `"[" (value ("," value)*)? "]"`,
|
||||
"kv": `string ":" value`,
|
||||
"integer": `"0" | [1-9] [0-9]*`,
|
||||
"number": `"-"? integer frac? exp?`,
|
||||
"frac": `"." [0-9]+`,
|
||||
"exp": `("e" | "E") ("+" | "-") [0-9]+`,
|
||||
"string": `"\"" char* "\""`,
|
||||
"escape": `["/" | "b" | "f" | "n" | "r" | "t" | unicode]`,
|
||||
"char": `[^"\\] | escape`,
|
||||
"space": `(" " | "\t" | "\n" | "\r")*`,
|
||||
"hex": `[0-9] | [a-f] | [A-F]`,
|
||||
"boolean": `"true" | "false"`,
|
||||
"value": `object | array | string | number | boolean | "null"`,
|
||||
}
|
||||
|
||||
const jsonString = `object | array`
|
||||
|
||||
type StateMachine struct {
|
||||
states map[rune]State
|
||||
}
|
||||
|
||||
type State struct {
|
||||
NextStates []string
|
||||
// bitmask?
|
||||
Mask []bool
|
||||
IsTerminal bool
|
||||
}
|
||||
|
||||
func NewStateMachine(grammar map[string]string, startRule string) *StateMachine {
|
||||
states := make(map[rune]State)
|
||||
|
||||
var cumu string
|
||||
flag := false
|
||||
for _, r := range startRule {
|
||||
if r == '"' {
|
||||
flag = !flag
|
||||
}
|
||||
if flag {
|
||||
cumu += string(r)
|
||||
}
|
||||
}
|
||||
|
||||
sm := &StateMachine{
|
||||
states: states,
|
||||
}
|
||||
return sm
|
||||
}
|
138
sample/gtf_test.go
Normal file
138
sample/gtf_test.go
Normal file
@ -0,0 +1,138 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGrammarParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
grammar map[string]string
|
||||
startRule string
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "simple object",
|
||||
grammar: map[string]string{
|
||||
"object": `"{" "}"`,
|
||||
},
|
||||
startRule: "object",
|
||||
input: "{}",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "simple array",
|
||||
grammar: map[string]string{
|
||||
"array": `"[" "]"`,
|
||||
},
|
||||
startRule: "array",
|
||||
input: "[]",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "character class",
|
||||
grammar: map[string]string{
|
||||
"digit": `[0-9]`,
|
||||
},
|
||||
startRule: "digit",
|
||||
input: "5",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "alternation",
|
||||
grammar: map[string]string{
|
||||
"bool": `"true" | "false"`,
|
||||
},
|
||||
startRule: "bool",
|
||||
input: "true",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "repetition",
|
||||
grammar: map[string]string{
|
||||
"digits": `[0-9]+`,
|
||||
},
|
||||
startRule: "digits",
|
||||
input: "123",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "nested rules",
|
||||
grammar: map[string]string{
|
||||
"value": `object | array`,
|
||||
"object": `"{" "}"`,
|
||||
"array": `"[" "]"`,
|
||||
},
|
||||
startRule: "value",
|
||||
input: "{}",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser := NewParser(tt.grammar)
|
||||
machine, err := parser.Parse(tt.startRule)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse() error = %v", err)
|
||||
}
|
||||
|
||||
matcher := NewMatcher(machine)
|
||||
got, err := matcher.Match(tt.input)
|
||||
if err != nil {
|
||||
t.Fatalf("Match() error = %v", err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("Match() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONGrammar(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{"empty object", "{}", true},
|
||||
{"empty array", "[]", true},
|
||||
{"simple string", `"hello"`, true},
|
||||
{"simple number", "123", true},
|
||||
{"simple boolean", "true", true},
|
||||
{"simple null", "null", true},
|
||||
{"object with string", `{"key": "value"}`, true},
|
||||
{"array with numbers", "[1, 2, 3]", true},
|
||||
{"nested object", `{"obj": {"key": "value"}}`, true},
|
||||
{"nested array", `[1, [2, 3], 4]`, true},
|
||||
{"invalid object", "{", false},
|
||||
{"invalid array", "[1, 2", false},
|
||||
{"invalid string", `"hello`, false},
|
||||
}
|
||||
|
||||
parser := NewParser(DefaultGrammar)
|
||||
machine, err := parser.Parse("value")
|
||||
if err != nil {
|
||||
t.Fatalf("Parse() error = %v", err)
|
||||
}
|
||||
|
||||
matcher := NewMatcher(machine)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := matcher.Match(tt.input)
|
||||
if tt.want {
|
||||
if err != nil {
|
||||
t.Errorf("Match() error = %v", err)
|
||||
}
|
||||
if !got {
|
||||
t.Errorf("Match() = false, want true")
|
||||
}
|
||||
} else {
|
||||
if err == nil && got {
|
||||
t.Errorf("Match() = true, want false")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
160
sample/json_types.go
Normal file
160
sample/json_types.go
Normal file
@ -0,0 +1,160 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type JSONState int
|
||||
|
||||
const (
|
||||
StateStart JSONState = iota
|
||||
StateInObject
|
||||
StateInObjectKey
|
||||
StateInStructuredKey
|
||||
StateInStructuredValue
|
||||
StateNewline
|
||||
StateTab
|
||||
StateSpace
|
||||
StateInString
|
||||
StateInInt
|
||||
StateInFloat
|
||||
StateInBool
|
||||
StateInNull
|
||||
StateInColon
|
||||
StateInComma
|
||||
StateInTab
|
||||
StateInSpaceToValue
|
||||
StateInSpaceEndValue
|
||||
StateInNewlineEndValue
|
||||
StateInObjSpace
|
||||
StateInList
|
||||
StateInListComma
|
||||
StateInValue
|
||||
StateInValueEnd
|
||||
StateInListEnd
|
||||
StateInListObjectEnd
|
||||
StateInNewline
|
||||
StateInNumber
|
||||
StateInNumberEnd
|
||||
StateInStringEnd
|
||||
StateInObjectKeyEnd
|
||||
StateTerminate
|
||||
StateInObjectEnd
|
||||
StateTransitioningToTerminate
|
||||
StateInListStartJSON
|
||||
)
|
||||
|
||||
var JSONStates = []JSONState{
|
||||
StateStart,
|
||||
StateInObject,
|
||||
StateInObjectKey,
|
||||
StateInStructuredKey,
|
||||
StateInStructuredValue,
|
||||
StateNewline,
|
||||
StateTab,
|
||||
StateSpace,
|
||||
StateInString,
|
||||
StateInInt,
|
||||
StateInFloat,
|
||||
StateInBool,
|
||||
StateInNull,
|
||||
StateInColon,
|
||||
StateInComma,
|
||||
StateInTab,
|
||||
StateInSpaceToValue,
|
||||
StateInSpaceEndValue,
|
||||
StateInNewlineEndValue,
|
||||
StateInObjSpace,
|
||||
StateInListStartJSON,
|
||||
StateInList,
|
||||
StateInListComma,
|
||||
StateInValue,
|
||||
StateInValueEnd,
|
||||
StateInListEnd,
|
||||
StateInListObjectEnd,
|
||||
StateInNewline,
|
||||
StateInNumber,
|
||||
StateInNumberEnd,
|
||||
StateInStringEnd,
|
||||
StateInObjectKeyEnd,
|
||||
StateTerminate,
|
||||
StateInObjectEnd,
|
||||
StateTransitioningToTerminate,
|
||||
}
|
||||
|
||||
func (s JSONState) String() string {
|
||||
switch s {
|
||||
case StateStart:
|
||||
return "StateStart"
|
||||
case StateInObject:
|
||||
return "StateInObject"
|
||||
case StateInObjectKey:
|
||||
return "StateInObjectKey"
|
||||
case StateInStructuredKey:
|
||||
return "StateInStructuredKey"
|
||||
case StateInStructuredValue:
|
||||
return "StateInStructuredValue"
|
||||
case StateNewline:
|
||||
return "StateNewline"
|
||||
case StateTab:
|
||||
return "StateTab"
|
||||
case StateSpace:
|
||||
return "StateSpace"
|
||||
case StateInString:
|
||||
return "StateInString"
|
||||
case StateInInt:
|
||||
return "StateInInt"
|
||||
case StateInFloat:
|
||||
return "StateInFloat"
|
||||
case StateInBool:
|
||||
return "StateInBool"
|
||||
case StateInNull:
|
||||
return "StateInNull"
|
||||
case StateInColon:
|
||||
return "StateInColon"
|
||||
case StateInComma:
|
||||
return "StateInComma"
|
||||
case StateInTab:
|
||||
return "StateInTab"
|
||||
case StateInSpaceToValue:
|
||||
return "StateInSpaceToValue"
|
||||
case StateInSpaceEndValue:
|
||||
return "StateInSpaceEndValue"
|
||||
case StateInNewlineEndValue:
|
||||
return "StateInNewlineEndValue"
|
||||
case StateInObjSpace:
|
||||
return "StateInObjSpace"
|
||||
case StateInList:
|
||||
return "StateInList"
|
||||
case StateInListComma:
|
||||
return "StateInListComma"
|
||||
case StateInValue:
|
||||
return "StateInValue"
|
||||
case StateInValueEnd:
|
||||
return "StateInValueEnd"
|
||||
case StateInListEnd:
|
||||
return "StateInListEnd"
|
||||
case StateInListObjectEnd:
|
||||
return "StateInListObjectEnd"
|
||||
case StateInNewline:
|
||||
return "StateInNewline"
|
||||
case StateInNumber:
|
||||
return "StateInNumber"
|
||||
case StateInNumberEnd:
|
||||
return "StateInNumberEnd"
|
||||
case StateInStringEnd:
|
||||
return "StateInStringEnd"
|
||||
case StateInObjectKeyEnd:
|
||||
return "StateInObjectKeyEnd"
|
||||
case StateTerminate:
|
||||
return "StateTerminate"
|
||||
case StateInObjectEnd:
|
||||
return "StateInObjectEnd"
|
||||
case StateTransitioningToTerminate:
|
||||
return "StateTransitioningToTerminate"
|
||||
case StateInListStartJSON:
|
||||
return "StateInListStartJSON"
|
||||
default:
|
||||
return fmt.Sprintf("Unknown state: %d", s)
|
||||
}
|
||||
}
|
327
sample/pushdown_automata.go
Normal file
327
sample/pushdown_automata.go
Normal file
@ -0,0 +1,327 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
/*
|
||||
Key JSON rules to consider:
|
||||
|
||||
1. Whitespace handling:
|
||||
- Need to handle all valid JSON whitespace characters (\r, spaces between tokens)
|
||||
- Current code only handles some whitespace cases
|
||||
|
||||
2. Number validation:
|
||||
- Need proper validation for special number cases like -0
|
||||
- Should handle .5 style decimals
|
||||
- Need limits on scientific notation (e, E)
|
||||
|
||||
3. String escaping:
|
||||
- Currently marks \ as invalid but should allow escaped sequences:
|
||||
- \"
|
||||
- \n
|
||||
- \u1234 unicode escapes
|
||||
|
||||
4. Empty object/array transitions:
|
||||
- Direct {} and [] cases could be more explicit
|
||||
- Need clear transitions for these edge cases
|
||||
|
||||
5. Nested depth limits:
|
||||
- No protection against excessive nesting
|
||||
- Could cause stack overflow with deeply nested structures
|
||||
*/
|
||||
|
||||
// TODO: / should be valid but an escape character
|
||||
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
|
||||
|
||||
var (
|
||||
intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
||||
validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
||||
)
|
||||
|
||||
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
|
||||
|
||||
var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
|
||||
|
||||
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
|
||||
|
||||
type PDA struct {
|
||||
State JSONState
|
||||
TransitionEdges map[rune]*PDA
|
||||
MaskTokenIDToNode map[int32]*PDA
|
||||
}
|
||||
|
||||
func NewPDANode(state JSONState) *PDA {
|
||||
return &PDA{
|
||||
State: state,
|
||||
TransitionEdges: make(map[rune]*PDA),
|
||||
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||
}
|
||||
}
|
||||
|
||||
type PDAGraphBuilder struct {
|
||||
proc model.TextProcessor
|
||||
decodedToks []string
|
||||
stateToNodeMap map[JSONState]*PDA
|
||||
tokenToStatesMap map[int32][]JSONState
|
||||
}
|
||||
|
||||
func (b *PDAGraphBuilder) BuildGraph() error {
|
||||
stateToNodeMap := make(map[JSONState]*PDA)
|
||||
for _, state := range JSONStates {
|
||||
stateToNodeMap[state] = NewPDANode(state)
|
||||
}
|
||||
|
||||
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
|
||||
|
||||
// TODO: update naming here - and revisit values
|
||||
stateToNodeMap[StateInListStartJSON].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateInListStartJSON].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
|
||||
|
||||
stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||
stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||
stateToNodeMap[StateInObject].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
|
||||
// new line
|
||||
stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
|
||||
stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||
// stateToNodeMap[StateInNewline].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
|
||||
// new line end value
|
||||
// stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
|
||||
stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||
// TODO: see if this is needed for formatting
|
||||
stateToNodeMap[StateInObjSpace].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||
|
||||
stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInTab].TransitionEdges['\t'] = stateToNodeMap[StateInNewline]
|
||||
|
||||
stateToNodeMap[StateInObjectKey].TransitionEdges[rune(-1)] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInObjectKey].TransitionEdges['"'] = stateToNodeMap[StateInObjectKeyEnd]
|
||||
|
||||
stateToNodeMap[StateInObjectKeyEnd].TransitionEdges[':'] = stateToNodeMap[StateInColon]
|
||||
|
||||
stateToNodeMap[StateInObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||
stateToNodeMap[StateInObjectEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
|
||||
// where values should be
|
||||
// this could be combined but the probl might change, we're alr doing a skip ahead
|
||||
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||
stateToNodeMap[StateInColon].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
|
||||
|
||||
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
|
||||
|
||||
// Leads to a value
|
||||
stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||
stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap)
|
||||
stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
stateToNodeMap[StateInSpaceToValue].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
|
||||
|
||||
// Values
|
||||
// string node
|
||||
stateToNodeMap[StateInString].TransitionEdges[rune(-1)] = stateToNodeMap[StateInString]
|
||||
stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
|
||||
|
||||
// String end node
|
||||
addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
|
||||
// stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// TODO: add counters for allowable number of decimals, e, E, etc
|
||||
// number node
|
||||
for _, r := range validNumberRunes {
|
||||
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
||||
}
|
||||
addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
|
||||
// stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// list node
|
||||
stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
|
||||
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
||||
// early end
|
||||
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
|
||||
// list end node
|
||||
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
// stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||
stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// empty list
|
||||
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
|
||||
|
||||
// null node
|
||||
for _, r := range validNullRunes {
|
||||
stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
|
||||
}
|
||||
addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
|
||||
stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||
stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// list comma
|
||||
// should point to values
|
||||
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
|
||||
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
||||
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInList]
|
||||
stateToNodeMap[StateInListComma].TransitionEdges['\t'] = stateToNodeMap[StateInList]
|
||||
|
||||
addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
|
||||
|
||||
// list object end
|
||||
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
|
||||
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
// TODO: not sure if this is needed
|
||||
stateToNodeMap[StateInListObjectEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// bool node
|
||||
for _, r := range validBoolRunes {
|
||||
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
|
||||
}
|
||||
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
|
||||
// stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// comma node
|
||||
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||
stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||
// todo: review this space transition
|
||||
// stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||
|
||||
// space end value
|
||||
// stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
b.stateToNodeMap = stateToNodeMap
|
||||
if err := b.preComputeValidStates(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func addEnds(node *PDA, stateToNodeMap map[JSONState]*PDA) {
|
||||
node.TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
}
|
||||
|
||||
func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) {
|
||||
node.TransitionEdges['"'] = stateToNodeMap[StateInString]
|
||||
for _, r := range validNumberRunes {
|
||||
node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
||||
}
|
||||
// TODO(parthsareen): force the output and shift similar to structured outputs
|
||||
node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
|
||||
node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
|
||||
node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
|
||||
}
|
||||
|
||||
func (b *PDAGraphBuilder) preComputeValidStates() error {
|
||||
for _, node := range b.stateToNodeMap {
|
||||
// if node.State == StateInObjectKey {
|
||||
// if len(b.stateToNodeMap[StateInString].MaskTokenIDToNode) > 0 {
|
||||
// b.stateToNodeMap[StateInObjectKey].MaskTokenIDToNode = b.stateToNodeMap[StateInString].MaskTokenIDToNode
|
||||
// fmt.Println("copying string mask to object key mask")
|
||||
// }
|
||||
// }
|
||||
if err := b.CreateMask(node); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *PDAGraphBuilder) preComputeTokenToStatesMap() error {
|
||||
// TODO: make can be somewhere else too
|
||||
b.tokenToStatesMap = make(map[int32][]JSONState)
|
||||
for i, t := range b.decodedToks {
|
||||
for _, r := range t {
|
||||
if r == '"' {
|
||||
b.tokenToStatesMap[int32(i)] = append(b.tokenToStatesMap[int32(i)], StateInString)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: the mask for obj key and string should be the same?
|
||||
func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
|
||||
if node == nil {
|
||||
return fmt.Errorf("node cannot be nil")
|
||||
}
|
||||
for i := range b.decodedToks {
|
||||
token := b.decodedToks[i]
|
||||
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||
if b.proc.Is(int32(i), model.SpecialEOS) || b.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
|
||||
continue
|
||||
}
|
||||
curNode := node
|
||||
valid := true
|
||||
consumedSpecialRunes := make(map[rune]bool)
|
||||
for _, r := range token {
|
||||
curNode, valid = isRuneValid(r, curNode, consumedSpecialRunes)
|
||||
if curNode == nil || !valid {
|
||||
break
|
||||
}
|
||||
}
|
||||
if valid {
|
||||
node.MaskTokenIDToNode[int32(i)] = curNode
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA, bool) {
|
||||
if consumedSpecialRunes[r] {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
specialRune := slices.Contains(stringInvalidRunes, r)
|
||||
if specialRune {
|
||||
if curNode.State == StateInString || curNode.State == StateInObjectKey {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
// Check for specific rune transition
|
||||
if nextNode, ok := curNode.TransitionEdges[r]; ok {
|
||||
// fmt.Println("next node", nextNode)
|
||||
if specialRune {
|
||||
if curNode.State == nextNode.State {
|
||||
return nil, false
|
||||
}
|
||||
consumedSpecialRunes[r] = true
|
||||
}
|
||||
return nextNode, true
|
||||
}
|
||||
|
||||
// Check for sentinel value - if present, any rune is valid
|
||||
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
|
||||
return nextNode, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
264
sample/pushdown_runner.go
Normal file
264
sample/pushdown_runner.go
Normal file
@ -0,0 +1,264 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
// TODO: safety in case of invalid json
|
||||
// TODO: partial JSON matching?
|
||||
// TODO: interfaces to cleanup with return values
|
||||
// TODO this interface shouldn't be the sampler - should just use Sampler
|
||||
// TODO: add penalties for string \n stuff
|
||||
// TODO: minimize number of fwd passes if there is only one match
|
||||
// TODO: greedy sample initially and then backtrack if no match
|
||||
|
||||
type PushdownSampler struct {
|
||||
PDAGraphBuilder
|
||||
curNode *PDA
|
||||
braceStack []rune
|
||||
stateCounter uint32
|
||||
}
|
||||
|
||||
// graph should be built once and reused per tokenizer
|
||||
func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) {
|
||||
start := time.Now()
|
||||
|
||||
fmt.Println("--------------------------------")
|
||||
fmt.Println("PDA sampler")
|
||||
fmt.Println("--------------------------------")
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
before := m.Alloc
|
||||
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
|
||||
|
||||
vocab := proc.Vocab()
|
||||
decodedToks := make([]string, len(vocab.Values))
|
||||
for i := range vocab.Values {
|
||||
token, err := proc.Decode([]int32{int32(i)})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodedToks[i] = token
|
||||
}
|
||||
|
||||
gb := &PDAGraphBuilder{
|
||||
proc: proc,
|
||||
decodedToks: decodedToks,
|
||||
}
|
||||
|
||||
if err := gb.BuildGraph(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
runtime.ReadMemStats(&m)
|
||||
after := m.Alloc
|
||||
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
|
||||
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||
fmt.Printf("Graph build time = %v\n", time.Since(start))
|
||||
|
||||
// TODO: this can be simplified
|
||||
return &PushdownSampler{
|
||||
curNode: gb.stateToNodeMap[StateStart],
|
||||
PDAGraphBuilder: *gb,
|
||||
braceStack: []rune{},
|
||||
stateCounter: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TODO: need to add resampling logic if the first sample was not good
|
||||
// greedy sample + backtrack?
|
||||
func (s *PushdownSampler) Apply(logits []float32) ([]float32, error) {
|
||||
switch s.curNode.State {
|
||||
case StateInString:
|
||||
return s.maskLogits(logits, s.curNode)
|
||||
|
||||
case StateInListEnd:
|
||||
// force finish if no braces left
|
||||
if len(s.braceStack) == 0 {
|
||||
s.curNode = NewPDANode(StateTerminate)
|
||||
return forceFinish(s, logits)
|
||||
}
|
||||
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
case StateTerminate:
|
||||
return forceFinish(s, logits)
|
||||
|
||||
case StateInObjectEnd:
|
||||
// force finish if no braces left
|
||||
if len(s.braceStack) == 0 {
|
||||
s.curNode = NewPDANode(StateTerminate)
|
||||
return forceFinish(s, logits)
|
||||
}
|
||||
|
||||
peek := s.braceStack[len(s.braceStack)-1]
|
||||
if peek == rune('[') {
|
||||
s.curNode = s.stateToNodeMap[StateInListObjectEnd]
|
||||
}
|
||||
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
case StateInComma:
|
||||
peek := s.braceStack[len(s.braceStack)-1]
|
||||
if peek == rune('[') {
|
||||
s.curNode = s.stateToNodeMap[StateInListComma]
|
||||
}
|
||||
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
default:
|
||||
fmt.Println("masking logits current state", s.curNode.State)
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
}
|
||||
|
||||
func forceFinish(s *PushdownSampler, logits []float32) ([]float32, error) {
|
||||
for i := range logits {
|
||||
if s.proc.Is(int32(i), model.SpecialEOS) {
|
||||
logits[i] = 1.0
|
||||
} else {
|
||||
logits[i] = float32(math.Inf(-1))
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
|
||||
func (s *PushdownSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
|
||||
fmt.Println("current state - updating", s.curNode.State)
|
||||
mappedString, err := s.proc.Decode(tokenSlice)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf(">>> mappedString: %q\n", mappedString)
|
||||
|
||||
// Special handling for EOS token in terminate state
|
||||
if s.curNode.State == StateTerminate {
|
||||
for _, tokenID := range tokenSlice {
|
||||
if s.proc.Is(tokenID, model.SpecialEOS) {
|
||||
return tokenSlice, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// flag := -1
|
||||
// endBraceRunes := []rune{'}', ']'}
|
||||
for _, r := range mappedString {
|
||||
// TODO: if this is enabled again, make sure to appropriately handle the state transitions
|
||||
// if slices.Contains(endBraceRunes, r) && len(s.braceStack) == 0 {
|
||||
// fmt.Printf("stack is empty, extra closing brace %c\n", r)
|
||||
// // flag = i
|
||||
// break
|
||||
|
||||
// }
|
||||
if r == rune('{') {
|
||||
s.braceStack = append(s.braceStack, r)
|
||||
}
|
||||
if r == rune('[') {
|
||||
s.braceStack = append(s.braceStack, r)
|
||||
}
|
||||
if r == rune('}') {
|
||||
if len(s.braceStack) == 0 {
|
||||
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
|
||||
}
|
||||
top := s.braceStack[len(s.braceStack)-1]
|
||||
if top != rune('{') {
|
||||
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
|
||||
}
|
||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||
}
|
||||
|
||||
if r == rune(']') {
|
||||
if len(s.braceStack) == 0 {
|
||||
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
|
||||
}
|
||||
top := s.braceStack[len(s.braceStack)-1]
|
||||
if top != rune('[') {
|
||||
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
|
||||
}
|
||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||
}
|
||||
}
|
||||
|
||||
// if flag != -1 {
|
||||
// tokenSlice = tokenSlice[:flag]
|
||||
// }
|
||||
// fmt.Println("flag!", flag)
|
||||
for _, tokenID := range tokenSlice {
|
||||
// transition to the next node
|
||||
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid token: %q", mappedString)
|
||||
}
|
||||
fmt.Println("transitioning to", nextNode.State)
|
||||
|
||||
// TODO: add a penalty for staying in the same state too long
|
||||
if nextNode.State == s.curNode.State {
|
||||
s.stateCounter++
|
||||
} else {
|
||||
s.stateCounter = 0
|
||||
}
|
||||
s.curNode = nextNode
|
||||
fmt.Println("updated curNode state", s.curNode.State)
|
||||
}
|
||||
return tokenSlice, nil
|
||||
}
|
||||
|
||||
// greedy sample + backtrack?
|
||||
func (s *PushdownSampler) maskLogits(logits []float32, node *PDA) ([]float32, error) {
|
||||
// Create a new slice with same length as logits, initialized to -Inf
|
||||
maskedLogits := make([]float32, len(logits))
|
||||
for i := range maskedLogits {
|
||||
maskedLogits[i] = float32(math.Inf(-1))
|
||||
}
|
||||
|
||||
// Only update values for valid token IDs from the mask map
|
||||
for tokenID := range node.MaskTokenIDToNode {
|
||||
if int(tokenID) < len(logits) {
|
||||
maskedLogits[tokenID] = logits[tokenID]
|
||||
}
|
||||
}
|
||||
|
||||
return maskedLogits, nil
|
||||
}
|
||||
|
||||
func (s *PushdownSampler) fastMaskLogits(logits []float32, node *PDA) ([]float32, error) {
|
||||
maxLogit := float32(math.Inf(-1))
|
||||
maxIndex := -1
|
||||
|
||||
// Find the maximum logit value among valid tokens
|
||||
for tokenID := range node.MaskTokenIDToNode {
|
||||
if int(tokenID) < len(logits) && logits[tokenID] > maxLogit {
|
||||
maxLogit = logits[tokenID]
|
||||
maxIndex = int(tokenID)
|
||||
}
|
||||
}
|
||||
|
||||
if maxIndex == -1 {
|
||||
return nil, fmt.Errorf("no valid tokens found in mask")
|
||||
}
|
||||
|
||||
logits[0] = float32(maxIndex)
|
||||
return logits, nil
|
||||
// return maxIndex, nil
|
||||
}
|
@ -17,12 +17,14 @@ type token struct {
|
||||
}
|
||||
|
||||
type Sampler struct {
|
||||
rng *rand.Rand
|
||||
topK int
|
||||
topP float32
|
||||
minP float32
|
||||
temperature float32
|
||||
grammar *Grammar
|
||||
rng *rand.Rand
|
||||
topK int
|
||||
topP float32
|
||||
minP float32
|
||||
temperature float32
|
||||
grammar *Grammar
|
||||
JSONSampler *JSONSampler
|
||||
PythonSampler *PythonSampler
|
||||
}
|
||||
|
||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||
@ -30,6 +32,19 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||
return -1, errors.New("sample: no logits provided to sample")
|
||||
}
|
||||
|
||||
var err error
|
||||
if s.JSONSampler != nil {
|
||||
logits, err = s.JSONSampler.Apply(logits)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
if s.PythonSampler != nil {
|
||||
logits, err = s.PythonSampler.ApplyMask(logits)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
tokens := make([]token, len(logits))
|
||||
for i := range logits {
|
||||
tokens[i].id = int32(i)
|
||||
@ -127,7 +142,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
||||
}
|
||||
|
||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar, jsonSampler *JSONSampler, pythonSampler *PythonSampler) Sampler {
|
||||
var rng *rand.Rand
|
||||
if seed != -1 {
|
||||
// PCG requires two parameters: sequence and stream
|
||||
@ -155,12 +170,14 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
|
||||
}
|
||||
|
||||
return Sampler{
|
||||
rng: rng,
|
||||
topK: topK,
|
||||
topP: topP,
|
||||
minP: minP,
|
||||
temperature: temperature,
|
||||
grammar: grammar,
|
||||
rng: rng,
|
||||
topK: topK,
|
||||
topP: topP,
|
||||
minP: minP,
|
||||
temperature: temperature,
|
||||
grammar: grammar,
|
||||
JSONSampler: jsonSampler,
|
||||
PythonSampler: pythonSampler,
|
||||
}
|
||||
}
|
||||
|
||||
|
299
sample/structured_outputs.go
Normal file
299
sample/structured_outputs.go
Normal file
@ -0,0 +1,299 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/grammar/jsonschema"
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type JSONSampler struct {
|
||||
schema *jsonschema.Schema
|
||||
propIdx int
|
||||
propToNodeMap map[string]*PDA
|
||||
pdaSampler *PushdownSampler
|
||||
decodedToks []string
|
||||
}
|
||||
|
||||
func NewJSONSampler(proc model.TextProcessor, schema *jsonschema.Schema) (*JSONSampler, error) {
|
||||
slog.Info("NewJSONSampler", "schema", schema)
|
||||
if proc == nil {
|
||||
return nil, fmt.Errorf("TextProcessor cannot be nil")
|
||||
}
|
||||
|
||||
pdaSampler, err := NewPushdownSampler(proc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create PushdownSampler: %w", err)
|
||||
}
|
||||
|
||||
if schema == nil {
|
||||
return &JSONSampler{
|
||||
schema: nil,
|
||||
propIdx: -1,
|
||||
propToNodeMap: nil,
|
||||
pdaSampler: pdaSampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// fmt.Println("schema not nil")
|
||||
so := &JSONSampler{
|
||||
schema: schema,
|
||||
propIdx: -1,
|
||||
propToNodeMap: make(map[string]*PDA),
|
||||
pdaSampler: pdaSampler,
|
||||
}
|
||||
|
||||
so.schemaToGraph()
|
||||
|
||||
// Benchmark token decoding
|
||||
start := time.Now()
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
before := m.Alloc
|
||||
|
||||
vocab := proc.Vocab()
|
||||
decodedToks := make([]string, len(vocab.Values))
|
||||
for i := range vocab.Values {
|
||||
token, err := proc.Decode([]int32{int32(i)})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodedToks[i] = token
|
||||
}
|
||||
so.decodedToks = decodedToks
|
||||
|
||||
runtime.ReadMemStats(&m)
|
||||
after := m.Alloc
|
||||
fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||
fmt.Printf("Token decode time = %v\n", time.Since(start))
|
||||
|
||||
fmt.Println("--------------------------------")
|
||||
fmt.Println("SOSampler")
|
||||
fmt.Println("--------------------------------")
|
||||
// Benchmark this section
|
||||
start = time.Now()
|
||||
runtime.ReadMemStats(&m)
|
||||
before = m.Alloc
|
||||
|
||||
// TODO: still messed up
|
||||
// TODO: recursion use case
|
||||
// key masks
|
||||
for _, prop := range so.schema.Properties {
|
||||
node := so.propToNodeMap[prop.Name]
|
||||
// propName -> node
|
||||
curState := node.State
|
||||
fromNode := node
|
||||
so.pdaSampler.CreateMask(fromNode)
|
||||
for curState == StateInStructuredKey {
|
||||
// there is only one edge
|
||||
for r, toNode := range fromNode.TransitionEdges {
|
||||
fmt.Println("rune", r, "edge", toNode.State)
|
||||
so.pdaSampler.CreateMask(toNode)
|
||||
fmt.Printf("created mask for %c\n", r)
|
||||
curState = toNode.State
|
||||
fmt.Println("next state", curState)
|
||||
// TODO: theres an extra gen for " right now
|
||||
fromNode = toNode
|
||||
}
|
||||
}
|
||||
|
||||
if curState != StateInColon {
|
||||
return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState)
|
||||
}
|
||||
|
||||
// so.pdaSampler.CreateMask(fromNode)
|
||||
|
||||
fromNode = fromNode.TransitionEdges[' ']
|
||||
|
||||
so.pdaSampler.CreateMask(fromNode)
|
||||
curState = fromNode.State
|
||||
for _, toNode := range fromNode.TransitionEdges {
|
||||
fmt.Println("toNode", toNode.State)
|
||||
}
|
||||
}
|
||||
|
||||
// runtime.ReadMemStats(&m)
|
||||
// after = m.Alloc
|
||||
// fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||
// fmt.Printf("Mask creation time = %v\n", time.Since(start))
|
||||
// fmt.Println("--------------------------------")
|
||||
|
||||
return so, nil
|
||||
}
|
||||
|
||||
func (s *JSONSampler) schemaToGraph() {
|
||||
schemaType := s.schema.EffectiveType()
|
||||
switch schemaType {
|
||||
case "object":
|
||||
// TODO: see if we need to connect these to the JSON graph
|
||||
|
||||
// each prop is a key
|
||||
for _, prop := range s.schema.Properties {
|
||||
// name of key
|
||||
name := prop.Name
|
||||
keyNode := &PDA{
|
||||
State: StateInStructuredKey, // this is unchanging, will impact sampling
|
||||
TransitionEdges: make(map[rune]*PDA),
|
||||
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||
}
|
||||
|
||||
prevNode := keyNode
|
||||
for _, r := range name {
|
||||
runeNode := &PDA{
|
||||
State: StateInStructuredKey, // this is unchanging, will impact sampling
|
||||
TransitionEdges: make(map[rune]*PDA),
|
||||
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||
}
|
||||
// fmt.Println("runeNode created", runeNode.State)
|
||||
// fmt.Printf("runeNode created %c\n", r)
|
||||
|
||||
// since alloc on heap connections wil still map
|
||||
prevNode.TransitionEdges[r] = runeNode
|
||||
prevNode = runeNode
|
||||
}
|
||||
|
||||
// point to end of object key node after all chars are done
|
||||
// prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
|
||||
|
||||
// link to value node
|
||||
// Create a node for the end of the key (after the closing quote)
|
||||
stringEndNode := &PDA{
|
||||
State: StateInStructuredKey,
|
||||
TransitionEdges: make(map[rune]*PDA),
|
||||
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||
}
|
||||
prevNode.TransitionEdges['"'] = stringEndNode
|
||||
prevNode = stringEndNode
|
||||
|
||||
// Add transition for colon after key
|
||||
colonNode := &PDA{
|
||||
State: StateInColon,
|
||||
TransitionEdges: make(map[rune]*PDA),
|
||||
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||
}
|
||||
prevNode.TransitionEdges[':'] = colonNode
|
||||
prevNode = colonNode
|
||||
|
||||
// Add transition for space after colon
|
||||
spaceNode := &PDA{
|
||||
State: StateInSpaceToValue,
|
||||
TransitionEdges: make(map[rune]*PDA),
|
||||
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||
}
|
||||
prevNode.TransitionEdges[' '] = spaceNode
|
||||
prevNode = spaceNode
|
||||
|
||||
value := prop.Type
|
||||
switch value {
|
||||
case "object":
|
||||
fmt.Println("object under key: ", name)
|
||||
case "array":
|
||||
fmt.Println("array under key: ", name)
|
||||
case "string":
|
||||
fmt.Println("string under key: ", name)
|
||||
prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString]
|
||||
case "number":
|
||||
fmt.Println("number under key: ", name)
|
||||
for _, r := range validNumberRunes {
|
||||
prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber]
|
||||
}
|
||||
case "boolean":
|
||||
fmt.Println("boolean under key: ", name)
|
||||
prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool]
|
||||
prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool]
|
||||
prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull]
|
||||
}
|
||||
|
||||
// points to start of the key
|
||||
s.propToNodeMap[name] = keyNode
|
||||
fmt.Println("name", name, "keyNode", keyNode.State)
|
||||
}
|
||||
}
|
||||
// TODO: do values + recursion
|
||||
}
|
||||
|
||||
func (s *JSONSampler) Apply(logits []float32) ([]float32, error) {
|
||||
if s.schema == nil {
|
||||
return s.pdaSampler.Apply(logits)
|
||||
}
|
||||
|
||||
switch s.pdaSampler.curNode.State {
|
||||
// TODO: doesnt account for multi rune case
|
||||
case StateInObjectKey:
|
||||
if s.propIdx > len(s.schema.Properties)-1 {
|
||||
return nil, fmt.Errorf("propIdx out of bounds")
|
||||
}
|
||||
// fmt.Println("in object key - structured outputs")
|
||||
// TODO: this tracking should probably be coming from a stack to track nested objects
|
||||
// simple case
|
||||
s.propIdx++
|
||||
fmt.Println("propIdx", s.propIdx)
|
||||
prop := s.schema.Properties[s.propIdx]
|
||||
fmt.Println("prop", prop.Name)
|
||||
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
|
||||
fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
|
||||
logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
default:
|
||||
|
||||
// Will only happen for the last prop - can also be precomputed.
|
||||
if s.propIdx == len(s.schema.Properties)-1 {
|
||||
// todo: if i incremenet propidx then i know im in last value as well
|
||||
switch s.pdaSampler.curNode.State {
|
||||
case StateInObjectEnd:
|
||||
fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
|
||||
s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
|
||||
s.pdaSampler.curNode = NewPDANode(StateTerminate)
|
||||
s.propIdx++
|
||||
|
||||
// TODO: this needs to be optimized in some way, computing mask on the fly is expensive
|
||||
case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
|
||||
fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
|
||||
delete(s.pdaSampler.curNode.TransitionEdges, ',')
|
||||
s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
|
||||
|
||||
s.pdaSampler.CreateMask(s.pdaSampler.curNode)
|
||||
s.propIdx++
|
||||
}
|
||||
}
|
||||
return s.pdaSampler.Apply(logits)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
|
||||
tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.schema == nil {
|
||||
// Don't need to update state for unconstrained JSON sampling
|
||||
return tokenSlice, nil
|
||||
}
|
||||
|
||||
switch s.pdaSampler.curNode.State {
|
||||
case StateInObjectKey:
|
||||
s.propIdx++
|
||||
fmt.Println("propIdx", s.propIdx)
|
||||
prop := s.schema.Properties[s.propIdx]
|
||||
fmt.Println("prop", prop.Name)
|
||||
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
|
||||
// TODO: this does not work - mike
|
||||
// str, err := s.pdaSampler.proc.Decode(tokenSlice)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// fmt.Println("str", str)
|
||||
|
||||
return tokenSlice, nil
|
||||
default:
|
||||
return tokenSlice, nil
|
||||
}
|
||||
}
|
352
sample/structured_python.go
Normal file
352
sample/structured_python.go
Normal file
@ -0,0 +1,352 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type PythonState int
|
||||
|
||||
const (
|
||||
PythonStateStart PythonState = iota
|
||||
StateInFunction
|
||||
StateInFunctionArgs
|
||||
StateInFunctionArgsType
|
||||
StateInFunctionEnd
|
||||
PStateInString
|
||||
PStateInStringEnd
|
||||
PStateInNumber
|
||||
PStateInList
|
||||
PStateInListEnd
|
||||
PStateInDict
|
||||
PStateInDictEnd
|
||||
PStateInTuple
|
||||
PStateInTupleEnd
|
||||
PStateTerminate
|
||||
)
|
||||
|
||||
func (s PythonState) String() string {
|
||||
switch s {
|
||||
case PythonStateStart:
|
||||
return "PythonStateStart"
|
||||
case StateInFunction:
|
||||
return "StateInFunction"
|
||||
case StateInFunctionArgs:
|
||||
return "StateInFunctionArgs"
|
||||
case StateInFunctionArgsType:
|
||||
return "StateInFunctionArgsType"
|
||||
case StateInFunctionEnd:
|
||||
return "StateInFunctionEnd"
|
||||
case PStateInString:
|
||||
return "PStateInString"
|
||||
case PStateInStringEnd:
|
||||
return "PStateInStringEnd"
|
||||
case PStateInNumber:
|
||||
return "PStateInNumber"
|
||||
case PStateInList:
|
||||
return "PStateInList"
|
||||
case PStateInListEnd:
|
||||
return "PStateInListEnd"
|
||||
case PStateInDict:
|
||||
return "PStateInDict"
|
||||
case PStateInDictEnd:
|
||||
return "PStateInDictEnd"
|
||||
case PStateInTuple:
|
||||
return "PStateInTuple"
|
||||
case PStateInTupleEnd:
|
||||
return "PStateInTupleEnd"
|
||||
case PStateTerminate:
|
||||
return "PStateTerminate"
|
||||
default:
|
||||
return fmt.Sprintf("PythonState(%d)", s)
|
||||
}
|
||||
}
|
||||
|
||||
var PythonStates = []PythonState{
|
||||
PythonStateStart,
|
||||
StateInFunction,
|
||||
StateInFunctionArgs,
|
||||
StateInFunctionArgsType,
|
||||
StateInFunctionEnd,
|
||||
PStateInString,
|
||||
PStateInStringEnd,
|
||||
PStateInNumber,
|
||||
PStateInList,
|
||||
PStateInListEnd,
|
||||
PStateInDict,
|
||||
PStateInDictEnd,
|
||||
PStateInTuple,
|
||||
PStateInTupleEnd,
|
||||
PStateTerminate,
|
||||
}
|
||||
|
||||
type Node struct {
|
||||
State PythonState
|
||||
TransitionEdges map[rune]*Node
|
||||
MaskTokenIDToNode map[int32]*Node
|
||||
}
|
||||
|
||||
func NewNode(state PythonState) *Node {
|
||||
return &Node{
|
||||
State: state,
|
||||
TransitionEdges: make(map[rune]*Node),
|
||||
MaskTokenIDToNode: make(map[int32]*Node),
|
||||
}
|
||||
}
|
||||
|
||||
type PythonFunction struct {
|
||||
Name string
|
||||
Args []string
|
||||
Types []string
|
||||
}
|
||||
|
||||
type PythonSampler struct {
|
||||
stateToNodes map[PythonState]*Node
|
||||
proc model.TextProcessor
|
||||
decodedToks []string
|
||||
curNode *Node
|
||||
completed int
|
||||
functions []PythonFunction
|
||||
}
|
||||
|
||||
func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
|
||||
s.proc = proc
|
||||
s.functions = functions
|
||||
decodedToks := make([]string, len(proc.Vocab().Values))
|
||||
for i := range proc.Vocab().Values {
|
||||
token, err := proc.Decode([]int32{int32(i)})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decodedToks[i] = token
|
||||
}
|
||||
s.decodedToks = decodedToks
|
||||
s.BuildGraph()
|
||||
for _, function := range functions {
|
||||
prevNode := s.stateToNodes[PythonStateStart]
|
||||
|
||||
for _, r := range function.Name {
|
||||
nextNode := NewNode(StateInFunction)
|
||||
prevNode.TransitionEdges[r] = nextNode
|
||||
if err := s.CreateMask(nextNode); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("prevNode", prevNode.State)
|
||||
fmt.Printf("transition edge: %q\n", r)
|
||||
fmt.Println("nextNode", nextNode.State)
|
||||
prevNode = nextNode
|
||||
}
|
||||
prevNode.TransitionEdges['('] = s.stateToNodes[StateInFunctionArgs]
|
||||
s.CreateMask(prevNode)
|
||||
prevNode = s.stateToNodes[StateInFunctionArgs]
|
||||
for i, arg := range function.Args {
|
||||
for _, r := range arg {
|
||||
nextNode := NewNode(StateInFunctionArgs)
|
||||
prevNode.TransitionEdges[r] = nextNode
|
||||
s.CreateMask(prevNode)
|
||||
prevNode = nextNode
|
||||
}
|
||||
prevNode.TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
|
||||
// prevNode = s.stateToNodes[StateInFunctionArgs]
|
||||
prevNode.TransitionEdges['='] = NewNode(StateInFunctionArgsType)
|
||||
s.CreateMask(prevNode)
|
||||
prevNode = prevNode.TransitionEdges['=']
|
||||
switch function.Types[i] {
|
||||
case "string":
|
||||
prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInString]
|
||||
s.CreateMask(prevNode.TransitionEdges['"'])
|
||||
case "number":
|
||||
prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInNumber]
|
||||
s.CreateMask(prevNode.TransitionEdges['"'])
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
s.curNode = s.stateToNodes[PythonStateStart]
|
||||
fmt.Println("curNode", s.curNode.State)
|
||||
fmt.Println("transition edges", s.curNode.TransitionEdges)
|
||||
if err := s.CreateMask(s.curNode); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("maskTokenIDToNode", s.curNode.MaskTokenIDToNode)
|
||||
for tokenID, node := range s.curNode.MaskTokenIDToNode {
|
||||
fmt.Printf("tokenID: %d, node: %v\n", s.decodedToks[tokenID], node.State)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PythonSampler) BuildGraph() error {
|
||||
s.stateToNodes = make(map[PythonState]*Node)
|
||||
for _, state := range PythonStates {
|
||||
s.stateToNodes[state] = NewNode(state)
|
||||
}
|
||||
|
||||
for _, state := range s.stateToNodes {
|
||||
if err := s.CreateMask(state); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// String
|
||||
s.stateToNodes[PStateInString].TransitionEdges[rune(-1)] = s.stateToNodes[PStateInString]
|
||||
s.stateToNodes[PStateInString].TransitionEdges['"'] = s.stateToNodes[PStateInStringEnd]
|
||||
|
||||
// String end
|
||||
s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
|
||||
// s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
||||
// Number
|
||||
for _, r := range validNumberRunes {
|
||||
s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
|
||||
}
|
||||
s.stateToNodes[PStateInNumber].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
||||
s.stateToNodes[PStateInNumber].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
|
||||
s.stateToNodes[PStateInNumber].TransitionEdges[' '] = s.stateToNodes[StateInFunctionArgs]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PythonSampler) ApplyMask(logits []float32) ([]float32, error) {
|
||||
if s.curNode.State == PStateTerminate {
|
||||
logits, err := finish(s, logits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
|
||||
func (s *PythonSampler) UpdateState(token int32) error {
|
||||
mappedString, err := s.proc.Decode([]int32{token})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf(">>> mappedString: %q\n", mappedString)
|
||||
|
||||
if s.curNode.State == PStateTerminate {
|
||||
if s.proc.Is(token, model.SpecialEOS) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
nextNode, ok := s.curNode.MaskTokenIDToNode[token]
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid token: %q", mappedString)
|
||||
}
|
||||
|
||||
if mappedString == "\"" {
|
||||
if s.curNode.State == PStateInStringEnd {
|
||||
s.completed++
|
||||
}
|
||||
if s.completed == len(s.functions) {
|
||||
s.curNode.TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
||||
s.CreateMask(s.curNode)
|
||||
}
|
||||
}
|
||||
s.curNode = nextNode
|
||||
fmt.Println("curNode", s.curNode.State)
|
||||
for r, node := range s.curNode.TransitionEdges {
|
||||
fmt.Printf("transition edge: %q -> %v\n", r, node.State)
|
||||
}
|
||||
if err := s.CreateMask(s.curNode); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PythonSampler) CreateMask(node *Node) error {
|
||||
if node == nil {
|
||||
return fmt.Errorf("node cannot be nil")
|
||||
}
|
||||
for i := range s.decodedToks {
|
||||
token := s.decodedToks[i]
|
||||
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||
if s.proc.Is(int32(i), model.SpecialEOS) || s.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
|
||||
continue
|
||||
}
|
||||
curNode := node
|
||||
valid := true
|
||||
consumedSpecialRunes := make(map[rune]bool)
|
||||
for _, r := range token {
|
||||
curNode, valid = isRValid(r, curNode, consumedSpecialRunes)
|
||||
if curNode == nil || !valid {
|
||||
break
|
||||
}
|
||||
}
|
||||
if valid {
|
||||
if curNode.State == StateInFunction {
|
||||
// fmt.Println("cm curNode", curNode.State)
|
||||
// fmt.Println("cm token", s.decodedToks[i])
|
||||
}
|
||||
node.MaskTokenIDToNode[int32(i)] = curNode
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isRValid(r rune, curNode *Node, consumedSpecialRunes map[rune]bool) (*Node, bool) {
|
||||
if consumedSpecialRunes[r] {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
specialRune := slices.Contains(stringInvalidRunes, r)
|
||||
if specialRune {
|
||||
if curNode.State == PStateInString || curNode.State == PStateInStringEnd {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
// Check for specific rune transition
|
||||
if nextNode, ok := curNode.TransitionEdges[r]; ok {
|
||||
// fmt.Println("next node", nextNode)
|
||||
if specialRune {
|
||||
if curNode.State == nextNode.State {
|
||||
return nil, false
|
||||
}
|
||||
consumedSpecialRunes[r] = true
|
||||
}
|
||||
return nextNode, true
|
||||
}
|
||||
|
||||
// Check for sentinel value - if present, any rune is valid
|
||||
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
|
||||
return nextNode, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (s *PythonSampler) maskLogits(logits []float32, node *Node) ([]float32, error) {
|
||||
// Create a new slice with same length as logits, initialized to -Inf
|
||||
maskedLogits := make([]float32, len(logits))
|
||||
for i := range maskedLogits {
|
||||
maskedLogits[i] = float32(math.Inf(-1))
|
||||
}
|
||||
|
||||
// Only update values for valid token IDs from the mask map
|
||||
for tokenID := range node.MaskTokenIDToNode {
|
||||
if int(tokenID) < len(logits) {
|
||||
maskedLogits[tokenID] = logits[tokenID]
|
||||
}
|
||||
}
|
||||
|
||||
return maskedLogits, nil
|
||||
}
|
||||
|
||||
func finish(s *PythonSampler, logits []float32) ([]float32, error) {
|
||||
for i := range logits {
|
||||
if s.proc.Is(int32(i), model.SpecialEOS) {
|
||||
logits[i] = 1.0
|
||||
} else {
|
||||
logits[i] = float32(math.Inf(-1))
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user