Compare commits
26 Commits
v0.6.2
...
parth/samp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4450f871db | ||
|
|
5ec6bb52a0 | ||
|
|
1fd9967558 | ||
|
|
131f0355a5 | ||
|
|
ce929984a3 | ||
|
|
4b34930a31 | ||
|
|
74bd09652d | ||
|
|
fb6252d786 | ||
|
|
c794fef2f2 | ||
|
|
00ebda8cc4 | ||
|
|
d14ce75b95 | ||
|
|
2d6eac9084 | ||
|
|
3ed7ad3ab3 | ||
|
|
6d1103048e | ||
|
|
0ff28758b3 | ||
|
|
d3e9ca3eda | ||
|
|
0fbfcf3c9c | ||
|
|
0c220935bd | ||
|
|
ffbfe833da | ||
|
|
42a14f7f63 | ||
|
|
f8c3dbe5b5 | ||
|
|
b078dd157c | ||
|
|
2ddacd7516 | ||
|
|
da0e345200 | ||
|
|
df94175a0f | ||
|
|
61a8825216 |
@@ -512,6 +512,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
|
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
|
||||||
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
|
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
|
||||||
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
|
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
|
||||||
|
- [Ollama for D](https://github.com/kassane/ollama-d)
|
||||||
|
|
||||||
### Mobile
|
### Mobile
|
||||||
|
|
||||||
|
|||||||
178
benchmark/server_benchmark_test.go
Normal file
178
benchmark/server_benchmark_test.go
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
package benchmark
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Command line flags.
var modelFlag string // model to benchmark, set via -m

func init() {
	flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
	// Shown as the placeholder default in -help output.
	flag.Lookup("m").DefValue = "model"
}

// modelName returns the model name from flags, failing the test if not set.
func modelName(b *testing.B) string {
	if modelFlag != "" {
		return modelFlag
	}
	b.Fatal("Error: -m flag is required for benchmark tests")
	return "" // unreachable: Fatal stops the benchmark
}
|
||||||
|
|
||||||
|
// TestCase describes one benchmark scenario: the prompt to send and the
// maximum number of tokens to generate for it.
type TestCase struct {
	name      string // sub-benchmark name
	prompt    string // prompt text sent to the model
	maxTokens int    // generation cap, passed to the server as num_predict
}
|
||||||
|
|
||||||
|
// runGenerateBenchmark contains the common generate and metrics logic
|
||||||
|
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
|
||||||
|
start := time.Now()
|
||||||
|
var ttft time.Duration
|
||||||
|
var metrics api.Metrics
|
||||||
|
|
||||||
|
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||||
|
if ttft == 0 && resp.Response != "" {
|
||||||
|
ttft = time.Since(start)
|
||||||
|
}
|
||||||
|
if resp.Done {
|
||||||
|
metrics = resp.Metrics
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Report custom metrics as part of the benchmark results
|
||||||
|
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
|
||||||
|
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
|
||||||
|
|
||||||
|
// Token throughput metrics
|
||||||
|
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
|
||||||
|
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
|
||||||
|
b.ReportMetric(promptThroughput, "prompt_tok/s")
|
||||||
|
b.ReportMetric(genThroughput, "gen_tok/s")
|
||||||
|
|
||||||
|
// Token counts
|
||||||
|
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
|
||||||
|
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkColdStart runs benchmarks with model loading from cold state
|
||||||
|
func BenchmarkColdStart(b *testing.B) {
|
||||||
|
client := setup(b)
|
||||||
|
tests := []TestCase{
|
||||||
|
{"short_prompt", "Write a long story", 100},
|
||||||
|
{"medium_prompt", "Write a detailed economic analysis", 500},
|
||||||
|
{"long_prompt", "Write a comprehensive AI research paper", 1000},
|
||||||
|
}
|
||||||
|
m := modelName(b)
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Set number of tokens as our throughput metric
|
||||||
|
b.SetBytes(int64(tt.maxTokens))
|
||||||
|
|
||||||
|
for b.Loop() {
|
||||||
|
b.StopTimer()
|
||||||
|
// Ensure model is unloaded before each iteration
|
||||||
|
unload(client, m, b)
|
||||||
|
b.StartTimer()
|
||||||
|
|
||||||
|
req := &api.GenerateRequest{
|
||||||
|
Model: m,
|
||||||
|
Prompt: tt.prompt,
|
||||||
|
Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
|
||||||
|
}
|
||||||
|
|
||||||
|
runGenerateBenchmark(b, ctx, client, req)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkWarmStart runs benchmarks with pre-loaded model.
// The model is warmed up before the timed loop, so iterations measure
// steady-state inference without the load-from-disk cost.
func BenchmarkWarmStart(b *testing.B) {
	client := setup(b)
	tests := []TestCase{
		{"short_prompt", "Write a long story", 100},
		{"medium_prompt", "Write a detailed economic analysis", 500},
		{"long_prompt", "Write a comprehensive AI research paper", 1000},
	}
	m := modelName(b)

	for _, tt := range tests {
		b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
			ctx := context.Background()

			// Pre-warm the model
			warmup(client, m, tt.prompt, b)

			// Set number of tokens as our throughput metric
			b.SetBytes(int64(tt.maxTokens))

			for b.Loop() {
				req := &api.GenerateRequest{
					Model:   m,
					Prompt:  tt.prompt,
					Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
				}

				runGenerateBenchmark(b, ctx, client, req)
			}
		})
	}
}
|
||||||
|
|
||||||
|
// setup verifies server and model availability.
// It fails the benchmark immediately if the server cannot be reached or the
// requested model is not present, so individual benchmarks can assume both.
func setup(b *testing.B) *api.Client {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		b.Fatal(err)
	}
	if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
		b.Fatalf("Model unavailable: %v", err)
	}

	return client
}
|
||||||
|
|
||||||
|
// warmup ensures the model is loaded and warmed up
|
||||||
|
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
|
||||||
|
for range 3 {
|
||||||
|
err := client.Generate(
|
||||||
|
context.Background(),
|
||||||
|
&api.GenerateRequest{
|
||||||
|
Model: model,
|
||||||
|
Prompt: prompt,
|
||||||
|
Options: map[string]interface{}{"num_predict": 50, "temperature": 0.1},
|
||||||
|
},
|
||||||
|
func(api.GenerateResponse) error { return nil },
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
b.Logf("Error during model warm-up: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// unload forces model unloading using KeepAlive: 0 parameter.
// Errors are logged rather than fatal: a failed unload only degrades the
// cold-start measurement, it does not invalidate the run.
func unload(client *api.Client, model string, b *testing.B) {
	req := &api.GenerateRequest{
		Model:     model,
		KeepAlive: &api.Duration{Duration: 0},
	}
	if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
		b.Logf("Unload error: %v", err)
	}
	// Give the server a moment to actually release the model before the
	// next timed iteration starts.
	time.Sleep(1 * time.Second)
}
|
||||||
@@ -703,6 +703,8 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
|||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
var v string
|
var v string
|
||||||
switch vData := resp.ModelInfo[k].(type) {
|
switch vData := resp.ModelInfo[k].(type) {
|
||||||
|
case bool:
|
||||||
|
v = fmt.Sprintf("%t", vData)
|
||||||
case string:
|
case string:
|
||||||
v = vData
|
v = vData
|
||||||
case float64:
|
case float64:
|
||||||
|
|||||||
@@ -87,6 +87,8 @@ func TestShowInfo(t *testing.T) {
|
|||||||
ModelInfo: map[string]any{
|
ModelInfo: map[string]any{
|
||||||
"general.architecture": "test",
|
"general.architecture": "test",
|
||||||
"general.parameter_count": float64(8_000_000_000),
|
"general.parameter_count": float64(8_000_000_000),
|
||||||
|
"some.true_bool": true,
|
||||||
|
"some.false_bool": false,
|
||||||
"test.context_length": float64(1000),
|
"test.context_length": float64(1000),
|
||||||
"test.embedding_length": float64(11434),
|
"test.embedding_length": float64(11434),
|
||||||
},
|
},
|
||||||
@@ -111,6 +113,8 @@ func TestShowInfo(t *testing.T) {
|
|||||||
Metadata
|
Metadata
|
||||||
general.architecture test
|
general.architecture test
|
||||||
general.parameter_count 8e+09
|
general.parameter_count 8e+09
|
||||||
|
some.false_bool false
|
||||||
|
some.true_bool true
|
||||||
test.context_length 1000
|
test.context_length 1000
|
||||||
test.embedding_length 11434
|
test.embedding_length 11434
|
||||||
|
|
||||||
|
|||||||
@@ -201,7 +201,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
|||||||
case "CohereForCausalLM":
|
case "CohereForCausalLM":
|
||||||
conv = &commandrModel{}
|
conv = &commandrModel{}
|
||||||
default:
|
default:
|
||||||
return errors.New("unsupported architecture")
|
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(bts, conv); err != nil {
|
if err := json.Unmarshal(bts, conv); err != nil {
|
||||||
|
|||||||
@@ -558,6 +558,10 @@ Final response:
|
|||||||
{
|
{
|
||||||
"model": "llama3.2",
|
"model": "llama3.2",
|
||||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": ""
|
||||||
|
},
|
||||||
"done": true,
|
"done": true,
|
||||||
"total_duration": 4883583458,
|
"total_duration": 4883583458,
|
||||||
"load_duration": 1334875,
|
"load_duration": 1334875,
|
||||||
|
|||||||
59
docs/benchmark.md
Normal file
59
docs/benchmark.md
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
# Benchmark
|
||||||
|
|
||||||
|
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
|
||||||
|
|
||||||
|
## When to use
|
||||||
|
|
||||||
|
Run these benchmarks when:
|
||||||
|
- Making changes to the model inference engine
|
||||||
|
- Modifying model loading/unloading logic
|
||||||
|
- Changing prompt processing or token generation code
|
||||||
|
- Implementing a new model architecture
|
||||||
|
- Testing performance across different hardware setups
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
|
||||||
|
## Usage and Examples
|
||||||
|
|
||||||
|
>[!NOTE]
|
||||||
|
>All commands must be run from the root directory of the Ollama project.
|
||||||
|
|
||||||
|
Basic syntax:
|
||||||
|
```bash
|
||||||
|
go test -bench=. ./benchmark/... -m $MODEL_NAME
|
||||||
|
```
|
||||||
|
|
||||||
|
Required flags:
|
||||||
|
- `-bench=.`: Run all benchmarks
|
||||||
|
- `-m`: Model name to benchmark
|
||||||
|
|
||||||
|
Optional flags:
|
||||||
|
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
|
||||||
|
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
|
||||||
|
|
||||||
|
Common usage patterns:
|
||||||
|
|
||||||
|
Single benchmark run with a model specified:
|
||||||
|
```bash
|
||||||
|
go test -bench=. ./benchmark/... -m llama3.3
|
||||||
|
```
|
||||||
|
|
||||||
|
## Output metrics
|
||||||
|
|
||||||
|
The benchmark reports several key metrics:
|
||||||
|
|
||||||
|
- `gen_tok/s`: Generated tokens per second
|
||||||
|
- `prompt_tok/s`: Prompt processing tokens per second
|
||||||
|
- `ttft_ms`: Time to first token in milliseconds
|
||||||
|
- `load_ms`: Model load time in milliseconds
|
||||||
|
- `gen_tokens`: Total tokens generated
|
||||||
|
- `prompt_tokens`: Total prompt tokens processed
|
||||||
|
|
||||||
|
Each benchmark runs two scenarios:
|
||||||
|
- Cold start: Model is loaded from disk for each test
|
||||||
|
- Warm start: Model is pre-loaded in memory
|
||||||
|
|
||||||
|
Three scenarios of increasing length are tested for each case. Note that the
token counts below are generation caps (`num_predict`), not prompt lengths:

- Short prompt (up to 100 generated tokens)
- Medium prompt (up to 500 generated tokens)
- Long prompt (up to 1000 generated tokens)
|
||||||
22
grammar/bench_test.go
Normal file
22
grammar/bench_test.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
//go:build go1.24
|
||||||
|
|
||||||
|
package grammar
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func BenchmarkFromSchema(b *testing.B) {
|
||||||
|
for tt := range testCases(b) {
|
||||||
|
b.Run("", func(b *testing.B) {
|
||||||
|
s := []byte(tt.schema)
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
for b.Loop() {
|
||||||
|
_, err := FromSchema(nil, s)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("GrammarFromSchema: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
227
grammar/grammar.go
Normal file
227
grammar/grammar.go
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
package grammar
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"iter"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/grammar/jsonschema"
|
||||||
|
)
|
||||||
|
|
||||||
|
// jsonTerms is the fixed grammar prelude shared by every generated grammar:
// unicode escapes plus the basic JSON terminals from RFC 7159. Generated
// user-defined rules are appended after it.
//
// NOTE(review): the exp rule requires an explicit "+" or "-" sign, but
// RFC 7159 makes the exponent sign optional (e.g. "1e5") — confirm whether
// this restriction is intended.
const jsonTerms = `
# Unicode
#
# Unicode characters can be specified directly in the grammar, for example
# hiragana ::= [ぁ-ゟ], or with escapes: 8-bit (\xXX), 16-bit (\uXXXX) or 32-bit
# (\UXXXXXXXX).
unicode ::= \x{hex}{2} | \u{hex}{4} | \U{hex}{8}

# JSON grammar from RFC 7159
null ::= "null"
object ::= "{" (kv ("," kv)*)? "}"
array ::= "[" (value ("," value)*)? "]"
kv ::= string ":" value
integer ::= "0" | [1-9] [0-9]*
number ::= "-"? integer frac? exp?
frac ::= "." [0-9]+
exp ::= ("e" | "E") ("+" | "-") [0-9]+
string ::= "\"" char* "\""
escape ::= ["/" | "b" | "f" | "n" | "r" | "t" | unicode]
char ::= [^"\\] | escape
space ::= (" " | "\t" | "\n" | "\r")*
hex ::= [0-9] | [a-f] | [A-F]
boolean ::= "true" | "false"
value ::= object | array | string | number | boolean | "null"

# User-defined
`
|
||||||
|
|
||||||
|
// FromSchema generates a grammar from a JSON schema.
//
// NOTE(review): the buf parameter is unused — the grammar is built in a
// fresh internal buffer. Presumably reserved for caller-supplied storage;
// confirm before relying on it.
func FromSchema(buf []byte, jsonSchema []byte) ([]byte, error) {
	var s *jsonschema.Schema
	if err := json.Unmarshal(jsonSchema, &s); err != nil {
		return nil, err
	}

	var g builder

	// "root" is the only rule that is guaranteed to exist, so we start
	// with its length for padding, and then adjust it as we go.
	g.pad = len("root")
	for id := range dependencies("root", s) {
		g.pad = max(g.pad, len(id))
	}

	g.b.WriteString(jsonTerms)

	// Emit all dependency rules first (post-order) and remember each
	// sub-schema's rule name so root's body can reference them.
	ids := make(map[*jsonschema.Schema]string)
	for id, s := range dependencies("root", s) {
		ids[s] = id
		g.define(id)
		if err := fromSchema(&g, ids, s); err != nil {
			return nil, err
		}
	}
	g.define("root")
	if err := fromSchema(&g, ids, s); err != nil {
		return nil, err
	}
	g.define("") // finalize the last rule
	return g.b.Bytes(), nil
}
|
||||||
|
|
||||||
|
// fromSchema writes the rule body for schema s into g. Child schemas are
// referenced by the rule names previously recorded in ids. It returns an
// error for schema types it does not support.
func fromSchema(g *builder, ids map[*jsonschema.Schema]string, s *jsonschema.Schema) error {
	switch typ := s.EffectiveType(); typ {
	case "array":
		if len(s.PrefixItems) == 0 && s.Items == nil {
			// No item schemas: any JSON array.
			g.u("array")
		} else {
			// Tuple form: the fixed PrefixItems, then zero or more Items.
			g.q("[")
			for i, s := range s.PrefixItems {
				if i > 0 {
					g.q(",")
				}
				g.u(ids[s])
			}
			if s.Items != nil {
				g.u("(")
				if len(s.PrefixItems) > 0 {
					g.q(",")
				}
				g.u(ids[s.Items])
				g.u(")*")
			}
			g.q("]")
		}
	case "object":
		if len(s.Properties) == 0 {
			// No properties: any JSON object.
			g.u("object")
		} else {
			// NOTE(review): all listed properties are emitted as required
			// and in declaration order; optional properties are not modeled.
			g.q("{")
			for i, p := range s.Properties {
				name := ids[p]
				if i > 0 {
					g.q(",")
				}
				g.q(p.Name)
				g.q(":")
				g.u(name)
			}
			g.q("}")
		}
	case "number":
		buildConstrainedNumber(g, s)
	case "string":
		if len(s.Enum) == 0 {
			g.u("string")
		} else {
			// Enum: alternation of the literal enum values.
			g.u("(")
			for i, e := range s.Enum {
				if i > 0 {
					g.q("|")
				}
				g.q(string(e))
			}
			g.u(")")
		}
	case "boolean", "value", "null", "integer":
		// These map directly onto terminals defined in jsonTerms.
		g.u(typ)
	default:
		return fmt.Errorf("%s: unsupported type %q", s.Name, typ)
	}
	return nil
}
|
||||||
|
|
||||||
|
// dependencies returns a sequence of all child dependencies of the schema in
// post-order.
//
// The first value is the id/pointer to the dependency, and the second value
// is the schema.
func dependencies(id string, s *jsonschema.Schema) iter.Seq2[string, *jsonschema.Schema] {
	return func(yield func(string, *jsonschema.Schema) bool) {
		// Properties: child ids extend the parent id ("<id>_<i>").
		for i, p := range s.Properties {
			id := fmt.Sprintf("%s_%d", id, i)
			for did, d := range dependencies(id, p) {
				if !yield(did, d) {
					return
				}
			}
			if !yield(id, p) {
				return
			}
		}
		// Prefix items.
		// NOTE(review): unlike properties, these ids do not include the
		// parent id ("tuple_<i>"), so tuples nested under different parents
		// could collide — confirm this is intended.
		for i, p := range s.PrefixItems {
			id := fmt.Sprintf("tuple_%d", i)
			for did, d := range dependencies(id, p) {
				id := fmt.Sprintf("%s_%s", id, did)
				if !yield(id, d) {
					return
				}
			}
			if !yield(id, p) {
				return
			}
		}
		// Items: the repeated tail of a tuple-form array.
		if s.Items != nil {
			id := fmt.Sprintf("%s_tuple_%d", id, len(s.PrefixItems))
			for did, d := range dependencies(id, s.Items) {
				if !yield(did, d) {
					return
				}
			}
			if !yield(id, s.Items) {
				return
			}
		}
	}
}
|
||||||
|
|
||||||
|
// builder incrementally assembles grammar text. A rule is opened with
// define and extended with q (quoted terminals) and u (non-terminals).
type builder struct {
	b     bytes.Buffer // accumulated grammar text
	pad   int          // column width used to left-align rule names
	rules int          // number of rules started so far
	items int          // items appended to the current rule
}

// define terminates the current rule, if any, and then either starts a new
// rule or does nothing else if the name is empty.
func (b *builder) define(name string) {
	if b.rules > 0 {
		b.b.WriteString(";\n")
	}
	if name == "" {
		return
	}
	fmt.Fprintf(&b.b, "% -*s", b.pad, name)
	b.b.WriteString(" ::=")
	b.rules++
	b.items = 0
}

// q appends a quoted terminal to the current rule.
func (b *builder) q(s string) {
	b.sep()
	b.b.WriteString(strconv.Quote(s))
}

// u appends a non-terminal to the current rule.
func (b *builder) u(s string) {
	b.sep()
	b.b.WriteString(s)
}

// sep writes the spacing that precedes every item appended to a rule.
func (b *builder) sep() {
	if b.items > 0 {
		b.b.WriteString(" ")
	}
	b.b.WriteString(" ")
}
|
||||||
|
|
||||||
|
func buildConstrainedNumber(b *builder, s *jsonschema.Schema) {
|
||||||
|
if s.Minimum == 0 && s.Maximum == 0 {
|
||||||
|
b.u("TODO")
|
||||||
|
} else {
|
||||||
|
b.u("number")
|
||||||
|
}
|
||||||
|
}
|
||||||
75
grammar/grammar_test.go
Normal file
75
grammar/grammar_test.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
package grammar
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"cmp"
|
||||||
|
"iter"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
_ "embed"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/grammar/internal/diff"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestFromSchema checks generated grammars against the golden expectations
// embedded in testdata/schemas.txt. The shared jsonTerms prelude is stripped
// before comparison so golden files contain only the generated rules.
func TestFromSchema(t *testing.T) {
	for tt := range testCases(t) {
		t.Run(tt.name, func(t *testing.T) {
			g, err := FromSchema(nil, []byte(tt.schema))
			if err != nil {
				t.Fatalf("FromSchema: %v", err)
			}
			got := string(g)
			got = strings.TrimPrefix(got, jsonTerms)
			if got != tt.want {
				t.Logf("schema:\n%s", tt.schema)
				// Show a unified diff of got vs. want on mismatch.
				t.Fatal(string(diff.Diff("got", []byte(got), "want", []byte(tt.want))))
			}
		})
	}
}
|
||||||
|
|
||||||
|
// testCase is one schema/golden-grammar pair parsed from testdata/schemas.txt.
type testCase struct {
	name   string // taken from the preceding "# ..." comment; may be empty
	schema string // one-line JSON schema input
	want   string // expected generated grammar, without the jsonTerms prelude
}
|
||||||
|
|
||||||
|
//go:embed testdata/schemas.txt
var tests string

// testCases yields the schema/grammar pairs parsed from testdata/schemas.txt.
// The file format is: an optional "# name" comment, a one-line schema, then
// the expected grammar lines up to a blank line or the next "#" comment.
func testCases(t testing.TB) iter.Seq[testCase] {
	t.Helper()
	return func(yield func(testCase) bool) {
		t.Helper()
		sc := bufio.NewScanner(strings.NewReader(tests))
		name := ""
		for sc.Scan() {
			line := strings.TrimSpace(sc.Text())
			if line == "" {
				// Blank line resets any pending case name.
				name = ""
				continue
			}
			if line[0] == '#' {
				// First comment wins as the name; later ones are ignored.
				name = cmp.Or(name, strings.TrimSpace(line[1:]))
				continue
			}
			// Current line is the schema; collect grammar lines until a
			// blank line or the next comment.
			s := sc.Text()
			g := ""
			for sc.Scan() {
				line = strings.TrimSpace(sc.Text())
				if line == "" || line[0] == '#' {
					break
				}
				g += sc.Text() + "\n"
			}
			if !yield(testCase{name, s, g}) {
				return
			}
			// If the inner loop stopped on a comment, that comment names
			// the next case (TrimPrefix is a no-op for a blank line).
			name = strings.TrimSpace(strings.TrimPrefix(line, "#"))
		}
		if err := sc.Err(); err != nil {
			t.Fatalf("error reading tests: %v", err)
		}
	}
}
|
||||||
261
grammar/internal/diff/diff.go
Normal file
261
grammar/internal/diff/diff.go
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
// Copyright 2022 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package diff
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A pair is a pair of values tracked for both the x and y side of a diff.
// It is typically a pair of line indexes.
type pair struct{ x, y int }
|
||||||
|
|
||||||
|
// Diff returns an anchored diff of the two texts old and new
|
||||||
|
// in the “unified diff” format. If old and new are identical,
|
||||||
|
// Diff returns a nil slice (no output).
|
||||||
|
//
|
||||||
|
// Unix diff implementations typically look for a diff with
|
||||||
|
// the smallest number of lines inserted and removed,
|
||||||
|
// which can in the worst case take time quadratic in the
|
||||||
|
// number of lines in the texts. As a result, many implementations
|
||||||
|
// either can be made to run for a long time or cut off the search
|
||||||
|
// after a predetermined amount of work.
|
||||||
|
//
|
||||||
|
// In contrast, this implementation looks for a diff with the
|
||||||
|
// smallest number of “unique” lines inserted and removed,
|
||||||
|
// where unique means a line that appears just once in both old and new.
|
||||||
|
// We call this an “anchored diff” because the unique lines anchor
|
||||||
|
// the chosen matching regions. An anchored diff is usually clearer
|
||||||
|
// than a standard diff, because the algorithm does not try to
|
||||||
|
// reuse unrelated blank lines or closing braces.
|
||||||
|
// The algorithm also guarantees to run in O(n log n) time
|
||||||
|
// instead of the standard O(n²) time.
|
||||||
|
//
|
||||||
|
// Some systems call this approach a “patience diff,” named for
|
||||||
|
// the “patience sorting” algorithm, itself named for a solitaire card game.
|
||||||
|
// We avoid that name for two reasons. First, the name has been used
|
||||||
|
// for a few different variants of the algorithm, so it is imprecise.
|
||||||
|
// Second, the name is frequently interpreted as meaning that you have
|
||||||
|
// to wait longer (to be patient) for the diff, meaning that it is a slower algorithm,
|
||||||
|
// when in fact the algorithm is faster than the standard one.
|
||||||
|
func Diff(oldName string, old []byte, newName string, new []byte) []byte {
	if bytes.Equal(old, new) {
		return nil
	}
	// Each element includes its trailing newline (see lines).
	x := lines(old)
	y := lines(new)

	// Print diff header.
	var out bytes.Buffer
	fmt.Fprintf(&out, "diff %s %s\n", oldName, newName)
	fmt.Fprintf(&out, "--- %s\n", oldName)
	fmt.Fprintf(&out, "+++ %s\n", newName)

	// Loop over matches to consider,
	// expanding each match to include surrounding lines,
	// and then printing diff chunks.
	// To avoid setup/teardown cases outside the loop,
	// tgs returns a leading {0,0} and trailing {len(x), len(y)} pair
	// in the sequence of matches.
	var (
		done  pair     // printed up to x[:done.x] and y[:done.y]
		chunk pair     // start lines of current chunk
		count pair     // number of lines from each side in current chunk
		ctext []string // lines for current chunk
	)
	for _, m := range tgs(x, y) {
		if m.x < done.x {
			// Already handled scanning forward from earlier match.
			continue
		}

		// Expand matching lines as far as possible,
		// establishing that x[start.x:end.x] == y[start.y:end.y].
		// Note that on the first (or last) iteration we may (or definitely do)
		// have an empty match: start.x==end.x and start.y==end.y.
		start := m
		for start.x > done.x && start.y > done.y && x[start.x-1] == y[start.y-1] {
			start.x--
			start.y--
		}
		end := m
		for end.x < len(x) && end.y < len(y) && x[end.x] == y[end.y] {
			end.x++
			end.y++
		}

		// Emit the mismatched lines before start into this chunk.
		// (No effect on first sentinel iteration, when start = {0,0}.)
		for _, s := range x[done.x:start.x] {
			ctext = append(ctext, "-"+s)
			count.x++
		}
		for _, s := range y[done.y:start.y] {
			ctext = append(ctext, "+"+s)
			count.y++
		}

		// If we're not at EOF and have too few common lines,
		// the chunk includes all the common lines and continues.
		const C = 3 // number of context lines
		if (end.x < len(x) || end.y < len(y)) &&
			(end.x-start.x < C || (len(ctext) > 0 && end.x-start.x < 2*C)) {
			for _, s := range x[start.x:end.x] {
				ctext = append(ctext, " "+s)
				count.x++
				count.y++
			}
			done = end
			continue
		}

		// End chunk with common lines for context.
		if len(ctext) > 0 {
			n := end.x - start.x
			if n > C {
				n = C
			}
			for _, s := range x[start.x : start.x+n] {
				ctext = append(ctext, " "+s)
				count.x++
				count.y++
			}
			done = pair{start.x + n, start.y + n}

			// Format and emit chunk.
			// Convert line numbers to 1-indexed.
			// Special case: empty file shows up as 0,0 not 1,0.
			if count.x > 0 {
				chunk.x++
			}
			if count.y > 0 {
				chunk.y++
			}
			fmt.Fprintf(&out, "@@ -%d,%d +%d,%d @@\n", chunk.x, count.x, chunk.y, count.y)
			for _, s := range ctext {
				out.WriteString(s)
			}
			count.x = 0
			count.y = 0
			ctext = ctext[:0]
		}

		// If we reached EOF, we're done.
		if end.x >= len(x) && end.y >= len(y) {
			break
		}

		// Otherwise start a new chunk.
		chunk = pair{end.x - C, end.y - C}
		for _, s := range x[chunk.x:end.x] {
			ctext = append(ctext, " "+s)
			count.x++
			count.y++
		}
		done = end
	}

	return out.Bytes()
}
|
||||||
|
|
||||||
|
// lines splits x into its lines, each keeping its trailing newline.
// A file that does not end in a newline has the BSD/GNU diff warning
// "\ No newline at end of file" (leading backslash included) attached
// to its final line.
func lines(x []byte) []string {
	split := strings.SplitAfter(string(x), "\n")
	last := len(split) - 1
	if split[last] != "" {
		// Final line lacked a newline: append the standard warning text.
		split[last] += "\n\\ No newline at end of file\n"
		return split
	}
	// Drop the empty remainder produced by a trailing newline.
	return split[:last]
}
|
||||||
|
|
||||||
|
// tgs returns the pairs of indexes of the longest common subsequence
// of unique lines in x and y, where a unique line is one that appears
// once in x and once in y.
//
// The longest common subsequence algorithm is as described in
// Thomas G. Szymanski, “A Special Case of the Maximal Common
// Subsequence Problem,” Princeton TR #170 (January 1975),
// available at https://research.swtch.com/tgs170.pdf.
func tgs(x, y []string) []pair {
	// Count the number of times each string appears in a and b.
	// We only care about 0, 1, many, counted as 0, -1, -2
	// for the x side and 0, -4, -8 for the y side.
	// Using negative numbers now lets us distinguish positive line numbers later.
	m := make(map[string]int)
	for _, s := range x {
		if c := m[s]; c > -2 {
			m[s] = c - 1
		}
	}
	for _, s := range y {
		if c := m[s]; c > -8 {
			m[s] = c - 4
		}
	}

	// Now unique strings can be identified by m[s] = -1+-4.
	//
	// Gather the indexes of those strings in x and y, building:
	// xi[i] = increasing indexes of unique strings in x.
	// yi[i] = increasing indexes of unique strings in y.
	// inv[i] = index j such that x[xi[i]] = y[yi[j]].
	var xi, yi, inv []int
	for i, s := range y {
		if m[s] == -1+-4 {
			// Repurpose the count slot to record the index into yi.
			m[s] = len(yi)
			yi = append(yi, i)
		}
	}
	for i, s := range x {
		if j, ok := m[s]; ok && j >= 0 {
			xi = append(xi, i)
			inv = append(inv, j)
		}
	}

	// Apply Algorithm A from Szymanski's paper.
	// In those terms, A = J = inv and B = [0, n).
	// We add sentinel pairs {0,0}, and {len(x),len(y)}
	// to the returned sequence, to help the processing loop.
	J := inv
	n := len(xi)
	T := make([]int, n)
	L := make([]int, n)
	for i := range T {
		T[i] = n + 1
	}
	for i := range n {
		k := sort.Search(n, func(k int) bool {
			return T[k] >= J[i]
		})
		T[k] = J[i]
		L[i] = k + 1
	}
	k := 0
	for _, v := range L {
		if k < v {
			k = v
		}
	}
	seq := make([]pair, 2+k)
	seq[1+k] = pair{len(x), len(y)} // sentinel at end
	lastj := n
	for i := n - 1; i >= 0; i-- {
		if L[i] == k && J[i] < lastj {
			seq[k] = pair{xi[i], yi[J[i]]}
			k--
		}
	}
	seq[0] = pair{0, 0} // sentinel at start
	return seq
}
|
||||||
44
grammar/internal/diff/diff_test.go
Normal file
44
grammar/internal/diff/diff_test.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
// Copyright 2022 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package diff
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/tools/txtar"
|
||||||
|
)
|
||||||
|
|
||||||
|
func clean(text []byte) []byte {
|
||||||
|
text = bytes.ReplaceAll(text, []byte("$\n"), []byte("\n"))
|
||||||
|
text = bytes.TrimSuffix(text, []byte("^D\n"))
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test(t *testing.T) {
|
||||||
|
files, _ := filepath.Glob("testdata/*.txt")
|
||||||
|
if len(files) == 0 {
|
||||||
|
t.Fatalf("no testdata")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, file := range files {
|
||||||
|
t.Run(filepath.Base(file), func(t *testing.T) {
|
||||||
|
a, err := txtar.ParseFile(file)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(a.Files) != 3 || a.Files[2].Name != "diff" {
|
||||||
|
t.Fatalf("%s: want three files, third named \"diff\"", file)
|
||||||
|
}
|
||||||
|
diffs := Diff(a.Files[0].Name, clean(a.Files[0].Data), a.Files[1].Name, clean(a.Files[1].Data))
|
||||||
|
want := clean(a.Files[2].Data)
|
||||||
|
if !bytes.Equal(diffs, want) {
|
||||||
|
t.Fatalf("%s: have:\n%s\nwant:\n%s\n%s", file,
|
||||||
|
diffs, want, Diff("have", diffs, "want", want))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
13
grammar/internal/diff/testdata/allnew.txt
vendored
Normal file
13
grammar/internal/diff/testdata/allnew.txt
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
-- old --
|
||||||
|
-- new --
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c
|
||||||
|
-- diff --
|
||||||
|
diff old new
|
||||||
|
--- old
|
||||||
|
+++ new
|
||||||
|
@@ -0,0 +1,3 @@
|
||||||
|
+a
|
||||||
|
+b
|
||||||
|
+c
|
||||||
13
grammar/internal/diff/testdata/allold.txt
vendored
Normal file
13
grammar/internal/diff/testdata/allold.txt
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
-- old --
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c
|
||||||
|
-- new --
|
||||||
|
-- diff --
|
||||||
|
diff old new
|
||||||
|
--- old
|
||||||
|
+++ new
|
||||||
|
@@ -1,3 +0,0 @@
|
||||||
|
-a
|
||||||
|
-b
|
||||||
|
-c
|
||||||
35
grammar/internal/diff/testdata/basic.txt
vendored
Normal file
35
grammar/internal/diff/testdata/basic.txt
vendored
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
Example from Hunt and McIlroy, “An Algorithm for Differential File Comparison.”
|
||||||
|
https://www.cs.dartmouth.edu/~doug/diff.pdf
|
||||||
|
|
||||||
|
-- old --
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c
|
||||||
|
d
|
||||||
|
e
|
||||||
|
f
|
||||||
|
g
|
||||||
|
-- new --
|
||||||
|
w
|
||||||
|
a
|
||||||
|
b
|
||||||
|
x
|
||||||
|
y
|
||||||
|
z
|
||||||
|
e
|
||||||
|
-- diff --
|
||||||
|
diff old new
|
||||||
|
--- old
|
||||||
|
+++ new
|
||||||
|
@@ -1,7 +1,7 @@
|
||||||
|
+w
|
||||||
|
a
|
||||||
|
b
|
||||||
|
-c
|
||||||
|
-d
|
||||||
|
+x
|
||||||
|
+y
|
||||||
|
+z
|
||||||
|
e
|
||||||
|
-f
|
||||||
|
-g
|
||||||
40
grammar/internal/diff/testdata/dups.txt
vendored
Normal file
40
grammar/internal/diff/testdata/dups.txt
vendored
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
-- old --
|
||||||
|
a
|
||||||
|
|
||||||
|
b
|
||||||
|
|
||||||
|
c
|
||||||
|
|
||||||
|
d
|
||||||
|
|
||||||
|
e
|
||||||
|
|
||||||
|
f
|
||||||
|
-- new --
|
||||||
|
a
|
||||||
|
|
||||||
|
B
|
||||||
|
|
||||||
|
C
|
||||||
|
|
||||||
|
d
|
||||||
|
|
||||||
|
e
|
||||||
|
|
||||||
|
f
|
||||||
|
-- diff --
|
||||||
|
diff old new
|
||||||
|
--- old
|
||||||
|
+++ new
|
||||||
|
@@ -1,8 +1,8 @@
|
||||||
|
a
|
||||||
|
$
|
||||||
|
-b
|
||||||
|
-
|
||||||
|
-c
|
||||||
|
+B
|
||||||
|
+
|
||||||
|
+C
|
||||||
|
$
|
||||||
|
d
|
||||||
|
$
|
||||||
38
grammar/internal/diff/testdata/end.txt
vendored
Normal file
38
grammar/internal/diff/testdata/end.txt
vendored
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
-- old --
|
||||||
|
1
|
||||||
|
2
|
||||||
|
3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
|
7
|
||||||
|
eight
|
||||||
|
nine
|
||||||
|
ten
|
||||||
|
eleven
|
||||||
|
-- new --
|
||||||
|
1
|
||||||
|
2
|
||||||
|
3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
|
7
|
||||||
|
8
|
||||||
|
9
|
||||||
|
10
|
||||||
|
-- diff --
|
||||||
|
diff old new
|
||||||
|
--- old
|
||||||
|
+++ new
|
||||||
|
@@ -5,7 +5,6 @@
|
||||||
|
5
|
||||||
|
6
|
||||||
|
7
|
||||||
|
-eight
|
||||||
|
-nine
|
||||||
|
-ten
|
||||||
|
-eleven
|
||||||
|
+8
|
||||||
|
+9
|
||||||
|
+10
|
||||||
9
grammar/internal/diff/testdata/eof.txt
vendored
Normal file
9
grammar/internal/diff/testdata/eof.txt
vendored
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
-- old --
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c^D
|
||||||
|
-- new --
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c^D
|
||||||
|
-- diff --
|
||||||
18
grammar/internal/diff/testdata/eof1.txt
vendored
Normal file
18
grammar/internal/diff/testdata/eof1.txt
vendored
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
-- old --
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c
|
||||||
|
-- new --
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c^D
|
||||||
|
-- diff --
|
||||||
|
diff old new
|
||||||
|
--- old
|
||||||
|
+++ new
|
||||||
|
@@ -1,3 +1,3 @@
|
||||||
|
a
|
||||||
|
b
|
||||||
|
-c
|
||||||
|
+c
|
||||||
|
\ No newline at end of file
|
||||||
18
grammar/internal/diff/testdata/eof2.txt
vendored
Normal file
18
grammar/internal/diff/testdata/eof2.txt
vendored
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
-- old --
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c^D
|
||||||
|
-- new --
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c
|
||||||
|
-- diff --
|
||||||
|
diff old new
|
||||||
|
--- old
|
||||||
|
+++ new
|
||||||
|
@@ -1,3 +1,3 @@
|
||||||
|
a
|
||||||
|
b
|
||||||
|
-c
|
||||||
|
\ No newline at end of file
|
||||||
|
+c
|
||||||
62
grammar/internal/diff/testdata/long.txt
vendored
Normal file
62
grammar/internal/diff/testdata/long.txt
vendored
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
-- old --
|
||||||
|
1
|
||||||
|
2
|
||||||
|
3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
|
7
|
||||||
|
8
|
||||||
|
9
|
||||||
|
10
|
||||||
|
11
|
||||||
|
12
|
||||||
|
13
|
||||||
|
14
|
||||||
|
14½
|
||||||
|
15
|
||||||
|
16
|
||||||
|
17
|
||||||
|
18
|
||||||
|
19
|
||||||
|
20
|
||||||
|
-- new --
|
||||||
|
1
|
||||||
|
2
|
||||||
|
3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
|
8
|
||||||
|
9
|
||||||
|
10
|
||||||
|
11
|
||||||
|
12
|
||||||
|
13
|
||||||
|
14
|
||||||
|
17
|
||||||
|
18
|
||||||
|
19
|
||||||
|
20
|
||||||
|
-- diff --
|
||||||
|
diff old new
|
||||||
|
--- old
|
||||||
|
+++ new
|
||||||
|
@@ -4,7 +4,6 @@
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
|
-7
|
||||||
|
8
|
||||||
|
9
|
||||||
|
10
|
||||||
|
@@ -12,9 +11,6 @@
|
||||||
|
12
|
||||||
|
13
|
||||||
|
14
|
||||||
|
-14½
|
||||||
|
-15
|
||||||
|
-16
|
||||||
|
17
|
||||||
|
18
|
||||||
|
19
|
||||||
5
grammar/internal/diff/testdata/same.txt
vendored
Normal file
5
grammar/internal/diff/testdata/same.txt
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
-- old --
|
||||||
|
hello world
|
||||||
|
-- new --
|
||||||
|
hello world
|
||||||
|
-- diff --
|
||||||
34
grammar/internal/diff/testdata/start.txt
vendored
Normal file
34
grammar/internal/diff/testdata/start.txt
vendored
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
-- old --
|
||||||
|
e
|
||||||
|
pi
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
|
7
|
||||||
|
8
|
||||||
|
9
|
||||||
|
10
|
||||||
|
-- new --
|
||||||
|
1
|
||||||
|
2
|
||||||
|
3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
|
7
|
||||||
|
8
|
||||||
|
9
|
||||||
|
10
|
||||||
|
-- diff --
|
||||||
|
diff old new
|
||||||
|
--- old
|
||||||
|
+++ new
|
||||||
|
@@ -1,5 +1,6 @@
|
||||||
|
-e
|
||||||
|
-pi
|
||||||
|
+1
|
||||||
|
+2
|
||||||
|
+3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
40
grammar/internal/diff/testdata/triv.txt
vendored
Normal file
40
grammar/internal/diff/testdata/triv.txt
vendored
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
Another example from Hunt and McIlroy,
|
||||||
|
“An Algorithm for Differential File Comparison.”
|
||||||
|
https://www.cs.dartmouth.edu/~doug/diff.pdf
|
||||||
|
|
||||||
|
Anchored diff gives up on finding anything,
|
||||||
|
since there are no unique lines.
|
||||||
|
|
||||||
|
-- old --
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c
|
||||||
|
a
|
||||||
|
b
|
||||||
|
b
|
||||||
|
a
|
||||||
|
-- new --
|
||||||
|
c
|
||||||
|
a
|
||||||
|
b
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c
|
||||||
|
-- diff --
|
||||||
|
diff old new
|
||||||
|
--- old
|
||||||
|
+++ new
|
||||||
|
@@ -1,7 +1,6 @@
|
||||||
|
-a
|
||||||
|
-b
|
||||||
|
-c
|
||||||
|
-a
|
||||||
|
-b
|
||||||
|
-b
|
||||||
|
-a
|
||||||
|
+c
|
||||||
|
+a
|
||||||
|
+b
|
||||||
|
+a
|
||||||
|
+b
|
||||||
|
+c
|
||||||
171
grammar/jsonschema/decode.go
Normal file
171
grammar/jsonschema/decode.go
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
package jsonschema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Schema holds a JSON schema.
|
||||||
|
type Schema struct {
|
||||||
|
// Name is the name of the property. For the parent/root property, this
|
||||||
|
// is "root". For child properties, this is the name of the property.
|
||||||
|
Name string `json:"-"`
|
||||||
|
|
||||||
|
// Type is the type of the property.
|
||||||
|
//
|
||||||
|
// TODO: Union types (e.g. make this a []string).
|
||||||
|
Type string
|
||||||
|
|
||||||
|
// PrefixItems is a list of schemas for each item in a tuple. By
|
||||||
|
// default, the tuple is "closed." unless Items is set to true or a
|
||||||
|
// valid Schema.
|
||||||
|
PrefixItems []*Schema
|
||||||
|
|
||||||
|
// Items is the schema for each item in a list.
|
||||||
|
//
|
||||||
|
// If it is missing, or its JSON value is "null" or "false", it is nil.
|
||||||
|
// If the JSON value is "true", it is set to the empty Schema. If the
|
||||||
|
// JSON value is an object, it will be decoded as a Schema.
|
||||||
|
Items *Schema
|
||||||
|
|
||||||
|
// MinItems specifies the minimum number of items allowed in a list.
|
||||||
|
MinItems int
|
||||||
|
|
||||||
|
// MaxItems specifies the maximum number of items allowed in a list.
|
||||||
|
MaxItems int
|
||||||
|
|
||||||
|
// Properties is the schema for each property of an object.
|
||||||
|
Properties []*Schema
|
||||||
|
|
||||||
|
// Format is the format of the property. This is used to validate the
|
||||||
|
// property against a specific format.
|
||||||
|
//
|
||||||
|
// It is the callers responsibility to validate the property against
|
||||||
|
// the format.
|
||||||
|
Format string
|
||||||
|
|
||||||
|
// Minimum specifies the minimum value for numeric properties.
|
||||||
|
Minimum float64
|
||||||
|
|
||||||
|
// Maximum specifies the maximum value for numeric properties.
|
||||||
|
Maximum float64
|
||||||
|
|
||||||
|
// Enum is a list of valid values for the property.
|
||||||
|
Enum []json.RawMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Schema) UnmarshalJSON(data []byte) error {
|
||||||
|
type S Schema
|
||||||
|
w := struct {
|
||||||
|
Properties props
|
||||||
|
Items items
|
||||||
|
*S
|
||||||
|
}{
|
||||||
|
S: (*S)(s),
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &w); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if w.Items.set {
|
||||||
|
s.Items = &w.Items.Schema
|
||||||
|
}
|
||||||
|
s.Properties = w.Properties
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type items struct {
|
||||||
|
Schema
|
||||||
|
set bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *items) UnmarshalJSON(data []byte) error {
|
||||||
|
switch b := data[0]; b {
|
||||||
|
case 't':
|
||||||
|
*s = items{set: true}
|
||||||
|
case '{':
|
||||||
|
type I items
|
||||||
|
if err := json.Unmarshal(data, (*I)(s)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.set = true
|
||||||
|
case 'n', 'f':
|
||||||
|
default:
|
||||||
|
return errors.New("invalid Items")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EffectiveType returns the effective type of the schema. If the Type field is
|
||||||
|
// not empty, it is returned; otherwise:
|
||||||
|
//
|
||||||
|
// - If the schema has both Properties and Items, it returns an empty string.
|
||||||
|
// - If the schema has Properties, it returns "object".
|
||||||
|
// - If the schema has Items, it returns "array".
|
||||||
|
// - If the schema has neither Properties nor Items, it returns "value".
|
||||||
|
//
|
||||||
|
// The returned string is never empty.
|
||||||
|
func (d *Schema) EffectiveType() string {
|
||||||
|
if d.Type == "" {
|
||||||
|
if len(d.Properties) > 0 {
|
||||||
|
return "object"
|
||||||
|
}
|
||||||
|
if len(d.PrefixItems) > 0 || d.Items != nil {
|
||||||
|
return "array"
|
||||||
|
}
|
||||||
|
return "value"
|
||||||
|
}
|
||||||
|
return d.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
// props is an ordered list of properties. The order of the properties
|
||||||
|
// is the order in which they were defined in the schema.
|
||||||
|
type props []*Schema
|
||||||
|
|
||||||
|
var _ json.Unmarshaler = (*props)(nil)
|
||||||
|
|
||||||
|
func (v *props) UnmarshalJSON(data []byte) error {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if data[0] != '{' {
|
||||||
|
return errors.New("expected object")
|
||||||
|
}
|
||||||
|
|
||||||
|
d := json.NewDecoder(bytes.NewReader(data))
|
||||||
|
|
||||||
|
// TODO(bmizerany): Consider DisallowUnknownFields. Currently, we, like
|
||||||
|
// llama.cpp, ignore unknown fields, which could be lead to unexpected
|
||||||
|
// behavior for clients of this package, since they may not be aware
|
||||||
|
// that "additionalFields", "itemsPrefix", etc, are being ignored.
|
||||||
|
//
|
||||||
|
// For now, just do what llama.cpp does.
|
||||||
|
|
||||||
|
t, err := d.Token()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if t != json.Delim('{') {
|
||||||
|
return errors.New("expected object")
|
||||||
|
}
|
||||||
|
for d.More() {
|
||||||
|
// Use the first token (map key) as the property name, then
|
||||||
|
// decode the rest of the object fields into a Schema and
|
||||||
|
// append.
|
||||||
|
t, err := d.Token()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if t == json.Delim('}') {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s := &Schema{
|
||||||
|
Name: t.(string),
|
||||||
|
}
|
||||||
|
if err := d.Decode(s); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*v = append(*v, s)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
104
grammar/jsonschema/decode_test.go
Normal file
104
grammar/jsonschema/decode_test.go
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
package jsonschema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const testSchemaBasic = `
|
||||||
|
{
|
||||||
|
"properties": {
|
||||||
|
"tupleClosedEmpty": { "prefixItems": [] },
|
||||||
|
"tupleClosedMissing": { "prefixItems": [{}] },
|
||||||
|
"tupleClosedNull": { "prefixItems": [{}], "items": null },
|
||||||
|
"tupleClosedFalse": { "prefixItems": [{}], "items": false },
|
||||||
|
"tupleOpenTrue": { "prefixItems": [{}], "items": true },
|
||||||
|
"tupleOpenEmpty": { "prefixItems": [{}], "items": {} },
|
||||||
|
"tupleOpenTyped": { "prefixItems": [{}], "items": {"type": "boolean"} },
|
||||||
|
"tupleOpenMax": { "prefixItems": [{}], "items": true, "maxItems": 3},
|
||||||
|
|
||||||
|
"array": { "items": {"type": "number"} },
|
||||||
|
|
||||||
|
"null": { "type": "null" },
|
||||||
|
"string": { "type": "string" },
|
||||||
|
"boolean": { "type": "boolean" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
func TestSchemaUnmarshal(t *testing.T) {
|
||||||
|
var got *Schema
|
||||||
|
if err := json.Unmarshal([]byte(testSchemaBasic), &got); err != nil {
|
||||||
|
t.Fatalf("Unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
want := &Schema{
|
||||||
|
Properties: []*Schema{
|
||||||
|
{Name: "tupleClosedEmpty", PrefixItems: []*Schema{}, Items: nil},
|
||||||
|
{Name: "tupleClosedMissing", PrefixItems: []*Schema{{}}, Items: nil},
|
||||||
|
{Name: "tupleClosedNull", PrefixItems: []*Schema{{}}, Items: nil},
|
||||||
|
{Name: "tupleClosedFalse", PrefixItems: []*Schema{{}}, Items: nil},
|
||||||
|
|
||||||
|
{Name: "tupleOpenTrue", PrefixItems: []*Schema{{}}, Items: &Schema{}},
|
||||||
|
{Name: "tupleOpenEmpty", PrefixItems: []*Schema{{}}, Items: &Schema{}},
|
||||||
|
{Name: "tupleOpenTyped", PrefixItems: []*Schema{{}}, Items: &Schema{Type: "boolean"}},
|
||||||
|
{Name: "tupleOpenMax", PrefixItems: []*Schema{{}}, Items: &Schema{}, MaxItems: 3},
|
||||||
|
|
||||||
|
{Name: "array", Items: &Schema{Type: "number"}},
|
||||||
|
|
||||||
|
{Name: "null", Type: "null"},
|
||||||
|
{Name: "string", Type: "string"},
|
||||||
|
{Name: "boolean", Type: "boolean"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
|
t.Errorf("(-want, +got)\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEffectiveType(t *testing.T) {
|
||||||
|
const schema = `
|
||||||
|
{"properties": {
|
||||||
|
"o": {"type": "object"},
|
||||||
|
"a": {"type": "array"},
|
||||||
|
"n": {"type": "number"},
|
||||||
|
"s": {"type": "string"},
|
||||||
|
"z": {"type": "null"},
|
||||||
|
"b": {"type": "boolean"},
|
||||||
|
|
||||||
|
"t0": {"prefixItems": [{}], "items": {"type": "number"}},
|
||||||
|
"t1": {"items": {"type": "number"}, "maxItems": 3},
|
||||||
|
|
||||||
|
"v": {"maxItems": 3}
|
||||||
|
}}
|
||||||
|
`
|
||||||
|
|
||||||
|
var s *Schema
|
||||||
|
if err := json.Unmarshal([]byte(schema), &s); err != nil {
|
||||||
|
t.Fatalf("json.Unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var got []string
|
||||||
|
for _, p := range s.Properties {
|
||||||
|
got = append(got, p.EffectiveType())
|
||||||
|
}
|
||||||
|
|
||||||
|
want := strings.Fields(`
|
||||||
|
object
|
||||||
|
array
|
||||||
|
number
|
||||||
|
string
|
||||||
|
null
|
||||||
|
boolean
|
||||||
|
array
|
||||||
|
array
|
||||||
|
value
|
||||||
|
`)
|
||||||
|
if !reflect.DeepEqual(want, got) {
|
||||||
|
t.Errorf("\ngot:\n\t%v\nwant:\n\t%v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
76
grammar/testdata/schemas.txt
vendored
Normal file
76
grammar/testdata/schemas.txt
vendored
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
# This file holds tests for JSON schema to EBNF grammar conversions.
|
||||||
|
#
|
||||||
|
# The format is a JSON schema, followed by the expected EBNF grammar. Each test
|
||||||
|
# MAY be preceded by a comment that describes the test (e.g. the test name), followed by
|
||||||
|
# the JSON schema and the expected EBNF grammar. If no comment is present, the test
|
||||||
|
# name the tests number in the file (e.g. "#0", "#1", etc.)
|
||||||
|
#
|
||||||
|
# Blank lines signify the end or start of a new test. Comments can be added
|
||||||
|
# anywhere in the file, but they must be preceded by a '#' character and start at
|
||||||
|
# the beginning of the line.
|
||||||
|
|
||||||
|
# default
|
||||||
|
{}
|
||||||
|
root ::= value;
|
||||||
|
|
||||||
|
{"properties": {}}
|
||||||
|
root ::= value;
|
||||||
|
|
||||||
|
# array
|
||||||
|
{"properties": {"a": {"type": "array", "items": {"type": "string"}}}}
|
||||||
|
root_0_tuple_0 ::= string;
|
||||||
|
root_0 ::= "[" ( root_0_tuple_0 )* "]";
|
||||||
|
root ::= "{" "a" ":" root_0 "}";
|
||||||
|
|
||||||
|
# array with nested array
|
||||||
|
{"type": "array", "items": {"type": "array", "items": {"type": "string"}}}
|
||||||
|
root_tuple_0_tuple_0 ::= string;
|
||||||
|
root_tuple_0 ::= "[" ( root_tuple_0_tuple_0 )* "]";
|
||||||
|
root ::= "[" ( root_tuple_0 )* "]";
|
||||||
|
|
||||||
|
# object
|
||||||
|
{"properties": {"e": {}}}
|
||||||
|
root_0 ::= value;
|
||||||
|
root ::= "{" "e" ":" root_0 "}";
|
||||||
|
|
||||||
|
# object with nested object
|
||||||
|
{"properties": {"o": {"type": "object", "properties": {"e": {}}}}}
|
||||||
|
root_0_0 ::= value;
|
||||||
|
root_0 ::= "{" "e" ":" root_0_0 "}";
|
||||||
|
root ::= "{" "o" ":" root_0 "}";
|
||||||
|
|
||||||
|
# boolean
|
||||||
|
{"type": "boolean"}
|
||||||
|
root ::= boolean;
|
||||||
|
|
||||||
|
# number
|
||||||
|
{"properties": {"n": {"type": "number", "minimum": 123, "maximum": 4567}}}
|
||||||
|
root_0 ::= number;
|
||||||
|
root ::= "{" "n" ":" root_0 "}";
|
||||||
|
|
||||||
|
# string
|
||||||
|
{"type": "string"}
|
||||||
|
root ::= string;
|
||||||
|
|
||||||
|
# string with enum
|
||||||
|
{"type": "string", "enum": ["a", "b", "c"]}
|
||||||
|
root ::= ( "\"a\"" "|" "\"b\"" "|" "\"c\"" );
|
||||||
|
|
||||||
|
# spaces in key
|
||||||
|
{"properties": {"a b": {}}}
|
||||||
|
root_0 ::= value;
|
||||||
|
root ::= "{" "a b" ":" root_0 "}";
|
||||||
|
|
||||||
|
# issue7978
|
||||||
|
{ "type": "object", "properties": { "steps": { "type": "array", "items": { "type": "object", "properties": { "explanation": { "type": "string" }, "output": { "type": "string" } }, "required": [ "explanation", "output" ], "additionalProperties": false } }, "final_answer": { "type": "string" } }, "required": [ "steps", "final_answer" ], "additionalProperties": false }
|
||||||
|
root_0_tuple_0_0 ::= string;
|
||||||
|
root_0_tuple_0_1 ::= string;
|
||||||
|
root_0_tuple_0 ::= "{" "explanation" ":" root_0_tuple_0_0 "," "output" ":" root_0_tuple_0_1 "}";
|
||||||
|
root_0 ::= "[" ( root_0_tuple_0 )* "]";
|
||||||
|
root_1 ::= string;
|
||||||
|
root ::= "{" "steps" ":" root_0 "," "final_answer" ":" root_1 "}";
|
||||||
|
|
||||||
|
# !! # special characters in key
|
||||||
|
# !! {"properties": {"a!b": {}}}
|
||||||
|
# !! !invalid character '!' in key
|
||||||
|
# !!
|
||||||
@@ -43,8 +43,13 @@ type Cache interface {
|
|||||||
|
|
||||||
// ** cache management **
|
// ** cache management **
|
||||||
|
|
||||||
// Init sets up runtime parameters
|
// Init sets up runtime parameters.
|
||||||
Init(backend ml.Backend, dtype ml.DType, capacity int32)
|
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
|
||||||
|
// dtype: The data type for storing cache entries
|
||||||
|
// maxSequences: The maximum number of sequences stored in the cache - across all batches
|
||||||
|
// capacity: The number of cache entries to store, per sequence
|
||||||
|
// maxBatch: The maximum number of tokens that can occur in a single batch
|
||||||
|
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
|
||||||
|
|
||||||
// Close closes the cache and frees resources associated with it
|
// Close closes the cache and frees resources associated with it
|
||||||
Close()
|
Close()
|
||||||
@@ -52,7 +57,7 @@ type Cache interface {
|
|||||||
// StartForward is called before the start of the model's forward pass.
|
// StartForward is called before the start of the model's forward pass.
|
||||||
// For each token in the coming batch, there must be a corresponding
|
// For each token in the coming batch, there must be a corresponding
|
||||||
// entry in positions and seqs.
|
// entry in positions and seqs.
|
||||||
StartForward(ctx ml.Context, opts input.Options) error
|
StartForward(ctx ml.Context, batch input.Batch) error
|
||||||
|
|
||||||
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
||||||
CopyPrefix(srcSeq, dstSeq int, len int32)
|
CopyPrefix(srcSeq, dstSeq int, len int32)
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
|
|||||||
// The mask is of shape history size, batch size
|
// The mask is of shape history size, batch size
|
||||||
type Causal struct {
|
type Causal struct {
|
||||||
DType ml.DType
|
DType ml.DType
|
||||||
Capacity int32
|
|
||||||
windowSize int32
|
windowSize int32
|
||||||
|
|
||||||
opts CausalOptions
|
opts CausalOptions
|
||||||
@@ -98,7 +97,7 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||||
if c.config == nil {
|
if c.config == nil {
|
||||||
var config ml.CacheConfig
|
var config ml.CacheConfig
|
||||||
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||||
@@ -119,9 +118,16 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
|||||||
c.config.MaskDType = ml.DTypeF32
|
c.config.MaskDType = ml.DTypeF32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var cacheSize int
|
||||||
|
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
|
||||||
|
cacheSize = maxSequences * capacity
|
||||||
|
} else {
|
||||||
|
cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
|
||||||
|
}
|
||||||
|
cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||||
|
c.cells = make([]cacheCell, cacheSize)
|
||||||
|
|
||||||
c.DType = dtype
|
c.DType = dtype
|
||||||
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
|
|
||||||
c.cells = make([]cacheCell, c.Capacity)
|
|
||||||
c.cellRanges = make(map[int]cellRange)
|
c.cellRanges = make(map[int]cellRange)
|
||||||
c.backend = backend
|
c.backend = backend
|
||||||
}
|
}
|
||||||
@@ -140,12 +146,14 @@ func (c *Causal) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
|
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
|
||||||
c.curBatchSize = len(opts.Positions)
|
c.curBatchSize = len(batch.Positions)
|
||||||
c.curSequences = opts.Sequences
|
c.curSequences = batch.Sequences
|
||||||
c.curPositions = opts.Positions
|
c.curPositions = batch.Positions
|
||||||
c.opts.Except = nil
|
c.opts.Except = nil
|
||||||
|
|
||||||
|
c.updateSlidingWindow()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
c.curLoc, err = c.findStartLoc()
|
c.curLoc, err = c.findStartLoc()
|
||||||
if errors.Is(err, ErrKvCacheFull) {
|
if errors.Is(err, ErrKvCacheFull) {
|
||||||
@@ -157,8 +165,8 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.curCellRange = newRange()
|
c.curCellRange = newRange()
|
||||||
for i, pos := range opts.Positions {
|
for i, pos := range batch.Positions {
|
||||||
seq := opts.Sequences[i]
|
seq := batch.Sequences[i]
|
||||||
|
|
||||||
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
|
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||||
|
|
||||||
@@ -210,7 +218,51 @@ func (c *Causal) findStartLoc() (int, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
|
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) updateSlidingWindow() {
|
||||||
|
if c.windowSize == math.MaxInt32 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a map of unique sequences to the lowest position in that sequence
|
||||||
|
lowestPos := make(map[int]int32)
|
||||||
|
for i := range c.curPositions {
|
||||||
|
seq := c.curSequences[i]
|
||||||
|
|
||||||
|
pos, ok := lowestPos[seq]
|
||||||
|
if !ok {
|
||||||
|
pos = c.curPositions[i]
|
||||||
|
} else if c.curPositions[i] < pos {
|
||||||
|
pos = c.curPositions[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
lowestPos[seq] = pos
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete any entries that are beyond the window of the oldest position in the sequence
|
||||||
|
for seq, pos := range lowestPos {
|
||||||
|
oldRange, ok := c.cellRanges[seq]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
newRange := newRange()
|
||||||
|
|
||||||
|
for i := oldRange.min; i <= oldRange.max; i++ {
|
||||||
|
if slices.Contains(c.cells[i].sequences, seq) {
|
||||||
|
if c.cells[i].pos < pos-c.windowSize {
|
||||||
|
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||||
|
} else {
|
||||||
|
newRange.min = min(newRange.min, i)
|
||||||
|
newRange.max = max(newRange.max, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.cellRanges[seq] = newRange
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func roundDown(length, pad int) int {
|
func roundDown(length, pad int) int {
|
||||||
@@ -265,7 +317,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
|||||||
return maskTensor, nil
|
return maskTensor, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
|
||||||
for i, key := range c.keys {
|
for i, key := range c.keys {
|
||||||
if key == nil {
|
if key == nil {
|
||||||
continue
|
continue
|
||||||
@@ -275,8 +327,8 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
|||||||
numKVHeads := key.Dim(1)
|
numKVHeads := key.Dim(1)
|
||||||
rowSize := key.Stride(2)
|
rowSize := key.Stride(2)
|
||||||
|
|
||||||
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
|
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
|
||||||
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
|
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
|
||||||
|
|
||||||
value := c.values[i]
|
value := c.values[i]
|
||||||
var vSrcView, vDstView ml.Tensor
|
var vSrcView, vDstView ml.Tensor
|
||||||
@@ -284,14 +336,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
|||||||
vHeadDim := value.Dim(1)
|
vHeadDim := value.Dim(1)
|
||||||
elemSize := value.Stride(0)
|
elemSize := value.Stride(0)
|
||||||
|
|
||||||
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
|
||||||
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
|
||||||
} else {
|
} else {
|
||||||
vHeadDim := value.Dim(0)
|
vHeadDim := value.Dim(0)
|
||||||
rowSize := value.Stride(2)
|
rowSize := value.Stride(2)
|
||||||
|
|
||||||
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
|
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
|
||||||
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
|
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.Forward(
|
ctx.Forward(
|
||||||
@@ -321,7 +373,8 @@ func (c *Causal) defrag() {
|
|||||||
ctx := c.backend.NewContext()
|
ctx := c.backend.NewContext()
|
||||||
|
|
||||||
// For every move, 6 tensors are required per layer (2 views and a
|
// For every move, 6 tensors are required per layer (2 views and a
|
||||||
// copy for each of k and v).
|
// copy for each of k and v). We also need to refer to the original
|
||||||
|
// k and v cache tensors - once per layer, not per move.
|
||||||
layers := 0
|
layers := 0
|
||||||
for _, key := range c.keys {
|
for _, key := range c.keys {
|
||||||
if key == nil {
|
if key == nil {
|
||||||
@@ -330,7 +383,7 @@ func (c *Causal) defrag() {
|
|||||||
layers++
|
layers++
|
||||||
}
|
}
|
||||||
|
|
||||||
maxMoves := ctx.MaxGraphNodes() / (6 * layers)
|
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
|
||||||
moves := 0
|
moves := 0
|
||||||
|
|
||||||
var pendingSrc, pendingDst, pendingLen int
|
var pendingSrc, pendingDst, pendingLen int
|
||||||
@@ -479,14 +532,14 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := c.keys[c.curLayer]; !ok {
|
if _, ok := c.keys[c.curLayer]; !ok {
|
||||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
|
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := c.values[c.curLayer]; !ok {
|
if _, ok := c.values[c.curLayer]; !ok {
|
||||||
if c.config.PermutedV {
|
if c.config.PermutedV {
|
||||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
|
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
|
||||||
} else {
|
} else {
|
||||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
|
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -497,7 +550,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
|||||||
elemSize := c.values[c.curLayer].Stride(0)
|
elemSize := c.values[c.curLayer].Stride(0)
|
||||||
|
|
||||||
value = value.Permute(ctx, 1, 2, 0, 3)
|
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
|
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
|
||||||
} else {
|
} else {
|
||||||
rowSize := c.values[c.curLayer].Stride(2)
|
rowSize := c.values[c.curLayer].Stride(2)
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ func TestStore(t *testing.T) {
|
|||||||
cache := NewCausalCache(nil)
|
cache := NewCausalCache(nil)
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
cache.Init(backend, ml.DTypeF16, 16)
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
@@ -58,11 +58,11 @@ func TestSWA(t *testing.T) {
|
|||||||
cache := NewSWACache(1, nil)
|
cache := NewSWACache(1, nil)
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
cache.Init(backend, ml.DTypeF32, 16)
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
name: "SlidingWindow",
|
name: "FirstBatch",
|
||||||
in: []float32{1, 2, 3, 4},
|
in: []float32{1, 2, 3, 4},
|
||||||
inShape: []int{1, 1, 4},
|
inShape: []int{1, 1, 4},
|
||||||
seqs: []int{0, 0, 0, 0},
|
seqs: []int{0, 0, 0, 0},
|
||||||
@@ -71,6 +71,16 @@ func TestSWA(t *testing.T) {
|
|||||||
expectedShape: []int{1, 1, 4},
|
expectedShape: []int{1, 1, 4},
|
||||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "SecondBatch",
|
||||||
|
in: []float32{5, 6},
|
||||||
|
inShape: []int{1, 1, 2},
|
||||||
|
seqs: []int{0, 0},
|
||||||
|
pos: []int32{4, 5},
|
||||||
|
expected: []float32{5, 6, 3, 4},
|
||||||
|
expectedShape: []int{1, 1, 4},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
testCache(t, backend, cache, tests)
|
testCache(t, backend, cache, tests)
|
||||||
@@ -81,7 +91,7 @@ func TestSequences(t *testing.T) {
|
|||||||
cache := NewCausalCache(nil)
|
cache := NewCausalCache(nil)
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
cache.Init(backend, ml.DTypeF16, 16)
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
@@ -116,7 +126,7 @@ func TestRemove(t *testing.T) {
|
|||||||
})
|
})
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
cache.Init(backend, ml.DTypeF16, 16)
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
@@ -181,7 +191,7 @@ func TestDefrag(t *testing.T) {
|
|||||||
})
|
})
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
cache.Init(backend, ml.DTypeF16, 16)
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
@@ -229,7 +239,7 @@ func TestCopy(t *testing.T) {
|
|||||||
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
|
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
cache.Init(backend, ml.DTypeF16, 16)
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
@@ -270,7 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
|
|||||||
context := backend.NewContext()
|
context := backend.NewContext()
|
||||||
defer context.Close()
|
defer context.Close()
|
||||||
|
|
||||||
err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
|
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ func NewEncoderCache() *EncoderCache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||||
if c.config == nil {
|
if c.config == nil {
|
||||||
var config ml.CacheConfig
|
var config ml.CacheConfig
|
||||||
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||||
@@ -58,6 +58,10 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
|
|||||||
c.config = &config
|
c.config = &config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if maxSequences > 1 {
|
||||||
|
panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
|
||||||
|
}
|
||||||
|
|
||||||
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
|
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
|
||||||
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
|
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
|
||||||
}
|
}
|
||||||
@@ -79,10 +83,10 @@ func (c *EncoderCache) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
|
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
|
||||||
// We work with the most recent image
|
// We work with the most recent image
|
||||||
if len(opts.Multimodal) > 0 {
|
if len(batch.Multimodal) > 0 {
|
||||||
c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
|
c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -23,9 +23,9 @@ func NewWrapperCache(caches ...Cache) *WrapperCache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||||
for _, cache := range c.caches {
|
for _, cache := range c.caches {
|
||||||
cache.Init(backend, dtype, capacity)
|
cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,14 +41,14 @@ func (c *WrapperCache) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
|
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
|
||||||
for i, cache := range c.caches {
|
for i, cache := range c.caches {
|
||||||
err := cache.StartForward(ctx, opts)
|
err := cache.StartForward(ctx, batch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
|
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
|
||||||
for j := i - 1; j >= 0; j-- {
|
for j := i - 1; j >= 0; j-- {
|
||||||
for k := range opts.Positions {
|
for k := range batch.Positions {
|
||||||
_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
|
_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ import (
|
|||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
"github.com/ollama/ollama/grammar"
|
||||||
"github.com/ollama/ollama/llama"
|
"github.com/ollama/ollama/llama"
|
||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
)
|
)
|
||||||
@@ -700,9 +701,9 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
}
|
}
|
||||||
|
|
||||||
// User provided a JSON schema
|
// User provided a JSON schema
|
||||||
g := llama.SchemaToGrammar(req.Format)
|
g, err := grammar.FromSchema(nil, req.Format)
|
||||||
if g == nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid JSON schema in format")
|
return fmt.Errorf("invalid JSON schema in format: %w", err)
|
||||||
}
|
}
|
||||||
req.Grammar = string(g)
|
req.Grammar = string(g)
|
||||||
}
|
}
|
||||||
@@ -713,6 +714,11 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
req.Options = &opts
|
req.Options = &opts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if req.Options == nil {
|
||||||
|
opts := api.DefaultOptions()
|
||||||
|
req.Options = &opts
|
||||||
|
}
|
||||||
|
|
||||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
slog.Info("aborting completion request due to client closing the connection")
|
slog.Info("aborting completion request due to client closing the connection")
|
||||||
@@ -727,7 +733,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
|
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
|
||||||
req.Options.NumPredict = 10 * s.options.NumCtx
|
req.Options.NumPredict = 10 * s.options.NumCtx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure the server is ready
|
// Make sure the server is ready
|
||||||
status, err := s.getServerStatusRetry(ctx)
|
status, err := s.getServerStatusRetry(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package ml
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
@@ -60,6 +61,10 @@ type CacheConfig struct {
|
|||||||
|
|
||||||
// BackendParams controls how the backend loads and executes models
|
// BackendParams controls how the backend loads and executes models
|
||||||
type BackendParams struct {
|
type BackendParams struct {
|
||||||
|
// Progress is a callback function that allows reporting percentage completion
|
||||||
|
// of model loading
|
||||||
|
Progress func(float32)
|
||||||
|
|
||||||
// NumThreads sets the number of threads to use if running on the CPU
|
// NumThreads sets the number of threads to use if running on the CPU
|
||||||
NumThreads int
|
NumThreads int
|
||||||
|
|
||||||
@@ -76,9 +81,9 @@ type BackendParams struct {
|
|||||||
FlashAttention bool
|
FlashAttention bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
|
var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))
|
||||||
|
|
||||||
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
|
func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
|
||||||
if _, ok := backends[name]; ok {
|
if _, ok := backends[name]; ok {
|
||||||
panic("backend: backend already registered")
|
panic("backend: backend already registered")
|
||||||
}
|
}
|
||||||
@@ -86,9 +91,9 @@ func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, erro
|
|||||||
backends[name] = f
|
backends[name] = f
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBackend(f *os.File, params BackendParams) (Backend, error) {
|
func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
|
||||||
if backend, ok := backends["ggml"]; ok {
|
if backend, ok := backends["ggml"]; ok {
|
||||||
return backend(f, params)
|
return backend(ctx, f, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("unsupported backend")
|
return nil, fmt.Errorf("unsupported backend")
|
||||||
|
|||||||
@@ -9,15 +9,17 @@ package ggml
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"maps"
|
"maps"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"unicode"
|
"unicode"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
@@ -58,7 +60,7 @@ type Backend struct {
|
|||||||
maxGraphNodes int
|
maxGraphNodes int
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||||
meta, n, err := fs.Decode(r, -1)
|
meta, n, err := fs.Decode(r, -1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -297,12 +299,16 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// concurrently read in tensor data. uses a section reader which is safe for concurrent reads
|
var doneBytes atomic.Uint64
|
||||||
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
|
totalBytes := uint64(n) - meta.Tensors().Offset
|
||||||
var g errgroup.Group
|
|
||||||
|
g, ctx := errgroup.WithContext(ctx)
|
||||||
|
g.SetLimit(runtime.GOMAXPROCS(0))
|
||||||
for _, t := range meta.Tensors().Items() {
|
for _, t := range meta.Tensors().Items() {
|
||||||
for _, target := range targets[t.Name] {
|
g.Go(func() error {
|
||||||
g.Go(func() error {
|
tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
|
||||||
|
for i := range tts {
|
||||||
|
target := targets[t.Name][i]
|
||||||
if target == "" {
|
if target == "" {
|
||||||
target = t.Name
|
target = t.Name
|
||||||
}
|
}
|
||||||
@@ -312,25 +318,44 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
|||||||
return fmt.Errorf("unassigned tensor: %s", t.Name)
|
return fmt.Errorf("unassigned tensor: %s", t.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
bts := C.malloc(C.size_t(t.Size()))
|
tts[i] = tt
|
||||||
if bts == nil {
|
}
|
||||||
return errors.New("failed to allocate tensor buffer")
|
|
||||||
}
|
|
||||||
defer C.free(bts)
|
|
||||||
|
|
||||||
buf := unsafe.Slice((*byte)(bts), t.Size())
|
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
|
||||||
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf)
|
bts := make([]byte, 128*format.KibiByte)
|
||||||
if err != nil || n != len(buf) {
|
|
||||||
return errors.New("read failed")
|
var s uint64
|
||||||
|
for s < t.Size() {
|
||||||
|
n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size()))
|
for _, tt := range tts {
|
||||||
return nil
|
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
|
||||||
})
|
}
|
||||||
}
|
|
||||||
|
s += uint64(n)
|
||||||
|
|
||||||
|
if params.Progress != nil {
|
||||||
|
done := doneBytes.Add(uint64(n))
|
||||||
|
params.Progress(float32(done) / float32(totalBytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if g.Wait() != nil {
|
// start a goroutine to cancel the errgroup if the parent context is done
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
g.Go(func() error {
|
||||||
|
return ctx.Err()
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := g.Wait(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package input
|
package input
|
||||||
|
|
||||||
|
import "github.com/ollama/ollama/ml"
|
||||||
|
|
||||||
// Input represents one token in the input stream
|
// Input represents one token in the input stream
|
||||||
type Input struct {
|
type Input struct {
|
||||||
// Token is a single element of text.
|
// Token is a single element of text.
|
||||||
@@ -33,11 +35,24 @@ type MultimodalIndex struct {
|
|||||||
Multimodal any
|
Multimodal any
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options contains the inputs for a model forward pass
|
// Batch contains the inputs for a model forward pass
|
||||||
type Options struct {
|
type Batch struct {
|
||||||
Inputs []int32
|
// Inputs is the input tokens, including placeholders for multimodal inputs.
|
||||||
|
Inputs ml.Tensor
|
||||||
|
|
||||||
|
// Multimodal is a set of multimodal embeddings previously created by
|
||||||
|
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
|
||||||
|
// models or for batches without multimodal elements.
|
||||||
Multimodal []MultimodalIndex
|
Multimodal []MultimodalIndex
|
||||||
Positions []int32
|
|
||||||
Sequences []int
|
// Positions is the position for each Input, relative to its sequence. Equal
|
||||||
Outputs []int32
|
// in length to Inputs.
|
||||||
|
Positions []int32
|
||||||
|
|
||||||
|
// Sequences is the sequence for each Input. Equal in length to Inputs.
|
||||||
|
Sequences []int
|
||||||
|
|
||||||
|
// Outputs are the set of indicies into Inputs for which output data should
|
||||||
|
// be returned.
|
||||||
|
Outputs []int32
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
_ "image/jpeg"
|
_ "image/jpeg"
|
||||||
@@ -26,7 +27,7 @@ var ErrNoVisionModel = errors.New("this model is missing data required for image
|
|||||||
|
|
||||||
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
||||||
type Model interface {
|
type Model interface {
|
||||||
Forward(ml.Context, input.Options) (ml.Tensor, error)
|
Forward(ml.Context, input.Batch) (ml.Tensor, error)
|
||||||
|
|
||||||
Backend() ml.Backend
|
Backend() ml.Backend
|
||||||
Config() config
|
Config() config
|
||||||
@@ -94,14 +95,14 @@ func Register(name string, f func(ml.Config) (Model, error)) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// New initializes a new model instance with the provided configuration based on the metadata in the model file
|
// New initializes a new model instance with the provided configuration based on the metadata in the model file
|
||||||
func New(modelPath string, params ml.BackendParams) (Model, error) {
|
func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
|
||||||
r, err := os.Open(modelPath)
|
r, err := os.Open(modelPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer r.Close()
|
defer r.Close()
|
||||||
|
|
||||||
b, err := ml.NewBackend(r, params)
|
b, err := ml.NewBackend(ctx, r, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -280,24 +281,30 @@ func canNil(t reflect.Type) bool {
|
|||||||
t.Kind() == reflect.Slice
|
t.Kind() == reflect.Slice
|
||||||
}
|
}
|
||||||
|
|
||||||
func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
|
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
|
||||||
if len(opts.Positions) != len(opts.Sequences) {
|
if len(batch.Positions) != len(batch.Sequences) {
|
||||||
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
|
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(opts.Positions) < 1 {
|
if len(batch.Positions) < 1 {
|
||||||
return nil, errors.New("batch size cannot be less than 1")
|
return nil, errors.New("batch size cannot be less than 1")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
cache := m.Config().Cache
|
cache := m.Config().Cache
|
||||||
if cache != nil {
|
if cache != nil {
|
||||||
err := cache.StartForward(ctx, opts)
|
err := cache.StartForward(ctx, batch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t, err := m.Forward(ctx, opts)
|
t, err := m.Forward(ctx, batch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -163,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {
|
|||||||
|
|
||||||
type notTextProcessorModel struct{}
|
type notTextProcessorModel struct{}
|
||||||
|
|
||||||
func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
|
func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
|
||||||
panic("unimplemented")
|
panic("unimplemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -168,23 +168,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
|||||||
return hiddenState.Add(ctx, residual)
|
return hiddenState.Add(ctx, residual)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
|
||||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
|
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
|
||||||
|
|
||||||
if len(m.Layers) == gemma27BLayerCount {
|
if len(m.Layers) == gemma27BLayerCount {
|
||||||
@@ -211,8 +206,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
|||||||
// final logit softcap
|
// final logit softcap
|
||||||
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
|
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
|
||||||
hiddenState = hiddenState.Tanh(ctx)
|
hiddenState = hiddenState.Tanh(ctx)
|
||||||
hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
|
return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil
|
||||||
return hiddenState.Rows(ctx, outputs), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
@@ -139,23 +139,18 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
@@ -171,13 +171,13 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
|
|||||||
return hiddenState.Add(ctx, residual)
|
return hiddenState.Add(ctx, residual)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
|
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
|
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
|
||||||
|
|
||||||
// set image embeddings
|
// set image embeddings
|
||||||
var except []int
|
var except []int
|
||||||
for _, image := range opts.Multimodal {
|
for _, image := range batch.Multimodal {
|
||||||
visionOutputs := image.Multimodal.(ml.Tensor)
|
visionOutputs := image.Multimodal.(ml.Tensor)
|
||||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
||||||
|
|
||||||
|
|||||||
@@ -139,23 +139,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
|||||||
return hiddenState.Add(ctx, residual)
|
return hiddenState.Add(ctx, residual)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
|
||||||
|
|
||||||
for i, layer := range m.Layers {
|
for i, layer := range m.Layers {
|
||||||
m.Cache.SetLayer(i)
|
m.Cache.SetLayer(i)
|
||||||
|
|||||||
@@ -135,32 +135,27 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||||||
return inputs, nil
|
return inputs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
var crossAttentionStates ml.Tensor
|
var crossAttentionStates ml.Tensor
|
||||||
if len(opts.Multimodal) > 0 {
|
if len(batch.Multimodal) > 0 {
|
||||||
images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
|
images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
|
||||||
if len(images) > 0 {
|
if len(images) > 0 {
|
||||||
crossAttentionStates = images[len(images)-1]
|
crossAttentionStates = images[len(images)-1]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: attention mask, cross attention mask
|
// TODO: attention mask, cross attention mask
|
||||||
return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ type TextProcessor interface {
|
|||||||
Encode(s string, addSpecial bool) ([]int32, error)
|
Encode(s string, addSpecial bool) ([]int32, error)
|
||||||
Decode([]int32) (string, error)
|
Decode([]int32) (string, error)
|
||||||
Is(int32, Special) bool
|
Is(int32, Special) bool
|
||||||
|
Vocab() *Vocabulary
|
||||||
}
|
}
|
||||||
|
|
||||||
type Vocabulary struct {
|
type Vocabulary struct {
|
||||||
|
|||||||
@@ -53,6 +53,10 @@ func (spm SentencePieceModel) Is(id int32, special Special) bool {
|
|||||||
return spm.vocab.Is(id, special)
|
return spm.vocab.Is(id, special)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (spm SentencePieceModel) Vocab() *Vocabulary {
|
||||||
|
return spm.vocab
|
||||||
|
}
|
||||||
|
|
||||||
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
|
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
|
||||||
return func(yield func(string) bool) {
|
return func(yield func(string) bool) {
|
||||||
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
|
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
|
||||||
|
|||||||
@@ -31,8 +31,10 @@ type InputCache struct {
|
|||||||
cache kvcache.Cache
|
cache kvcache.Cache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) {
|
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
|
||||||
if kvSize/int32(numSlots) < 1 {
|
numCtx := kvSize / int32(numSlots)
|
||||||
|
|
||||||
|
if numCtx < 1 {
|
||||||
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
|
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,11 +46,11 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
|
|||||||
|
|
||||||
cache := model.Config().Cache
|
cache := model.Config().Cache
|
||||||
if cache != nil {
|
if cache != nil {
|
||||||
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize)
|
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &InputCache{
|
return &InputCache{
|
||||||
numCtx: kvSize / int32(numSlots),
|
numCtx: numCtx,
|
||||||
enabled: cache != nil,
|
enabled: cache != nil,
|
||||||
slots: slots,
|
slots: slots,
|
||||||
multiUserCache: multiUserCache,
|
multiUserCache: multiUserCache,
|
||||||
|
|||||||
@@ -348,7 +348,8 @@ func (s *Server) processBatch() error {
|
|||||||
}
|
}
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
var options input.Options
|
var batchInputs []int32
|
||||||
|
var batch input.Batch
|
||||||
|
|
||||||
for i, seq := range s.seqs {
|
for i, seq := range s.seqs {
|
||||||
if seq == nil {
|
if seq == nil {
|
||||||
@@ -395,17 +396,17 @@ func (s *Server) processBatch() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
options.Inputs = append(options.Inputs, inp.Token)
|
batchInputs = append(batchInputs, inp.Token)
|
||||||
if inp.Multimodal != nil {
|
if inp.Multimodal != nil {
|
||||||
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
|
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
|
||||||
}
|
}
|
||||||
|
|
||||||
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
||||||
options.Sequences = append(options.Sequences, seq.cache.Id)
|
batch.Sequences = append(batch.Sequences, seq.cache.Id)
|
||||||
|
|
||||||
seq.iBatch = len(options.Outputs)
|
seq.iBatch = len(batch.Outputs)
|
||||||
if j+1 == len(seq.inputs) {
|
if j+1 == len(seq.inputs) {
|
||||||
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
|
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
||||||
}
|
}
|
||||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
seq.pendingInputs = append(seq.pendingInputs, inp)
|
||||||
}
|
}
|
||||||
@@ -413,14 +414,14 @@ func (s *Server) processBatch() error {
|
|||||||
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(options.Inputs) == 0 {
|
if len(batchInputs) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := s.model.Backend().NewContext()
|
ctx := s.model.Backend().NewContext()
|
||||||
defer ctx.Close()
|
defer ctx.Close()
|
||||||
|
|
||||||
modelOutput, err := model.Forward(ctx, s.model, options)
|
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to decode batch: %w", err)
|
return fmt.Errorf("failed to decode batch: %w", err)
|
||||||
}
|
}
|
||||||
@@ -460,13 +461,27 @@ func (s *Server) processBatch() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sample a token
|
// sample a token
|
||||||
vocabSize := len(logits) / len(options.Outputs)
|
vocabSize := len(logits) / len(batch.Outputs)
|
||||||
|
|
||||||
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to sample token: %w", err)
|
return fmt.Errorf("failed to sample token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if seq.sampler.JSONSampler != nil {
|
||||||
|
_, err = seq.sampler.JSONSampler.UpdateState([]int32{token})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update state: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if seq.sampler.PythonSampler != nil {
|
||||||
|
err = seq.sampler.PythonSampler.UpdateState(token)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update state: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// if it's an end of sequence token, break
|
// if it's an end of sequence token, break
|
||||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||||
// TODO (jmorganca): we should send this back
|
// TODO (jmorganca): we should send this back
|
||||||
@@ -561,6 +576,21 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// jsonSampler, err := sample.NewJSONSampler(s.model.(model.TextProcessor), nil)
|
||||||
|
// if err != nil {
|
||||||
|
// http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// jsonSampler = nil
|
||||||
|
pythonSampler := &sample.PythonSampler{}
|
||||||
|
functions := []sample.PythonFunction{
|
||||||
|
{
|
||||||
|
Name: "add_two_strings",
|
||||||
|
Args: []string{"s1", "s2"},
|
||||||
|
Types: []string{"string", "string"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
pythonSampler.Init(functions, s.model.(model.TextProcessor))
|
||||||
sampler := sample.NewSampler(
|
sampler := sample.NewSampler(
|
||||||
req.Options.Temperature,
|
req.Options.Temperature,
|
||||||
req.Options.TopK,
|
req.Options.TopK,
|
||||||
@@ -568,6 +598,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
req.Options.MinP,
|
req.Options.MinP,
|
||||||
req.Options.Seed,
|
req.Options.Seed,
|
||||||
grammar,
|
grammar,
|
||||||
|
nil,
|
||||||
|
pythonSampler,
|
||||||
|
// nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
@@ -677,6 +710,7 @@ func (m *multiLPath) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) loadModel(
|
func (s *Server) loadModel(
|
||||||
|
ctx context.Context,
|
||||||
mpath string,
|
mpath string,
|
||||||
params ml.BackendParams,
|
params ml.BackendParams,
|
||||||
lpath multiLPath,
|
lpath multiLPath,
|
||||||
@@ -686,7 +720,7 @@ func (s *Server) loadModel(
|
|||||||
multiUserCache bool,
|
multiUserCache bool,
|
||||||
) {
|
) {
|
||||||
var err error
|
var err error
|
||||||
s.model, err = model.New(mpath, params)
|
s.model, err = model.New(ctx, mpath, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -698,7 +732,7 @@ func (s *Server) loadModel(
|
|||||||
panic("loras are not yet implemented")
|
panic("loras are not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache)
|
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -782,6 +816,9 @@ func Execute(args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
params := ml.BackendParams{
|
params := ml.BackendParams{
|
||||||
|
Progress: func(progress float32) {
|
||||||
|
server.progress = progress
|
||||||
|
},
|
||||||
NumThreads: *threads,
|
NumThreads: *threads,
|
||||||
NumGPULayers: *numGPULayers,
|
NumGPULayers: *numGPULayers,
|
||||||
MainGPU: *mainGPU,
|
MainGPU: *mainGPU,
|
||||||
@@ -790,13 +827,13 @@ func Execute(args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
server.ready.Add(1)
|
server.ready.Add(1)
|
||||||
go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
|
|
||||||
|
|
||||||
server.cond = sync.NewCond(&server.mu)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
|
||||||
|
|
||||||
|
server.cond = sync.NewCond(&server.mu)
|
||||||
|
|
||||||
go server.run(ctx)
|
go server.run(ctx)
|
||||||
|
|
||||||
addr := "127.0.0.1:" + strconv.Itoa(*port)
|
addr := "127.0.0.1:" + strconv.Itoa(*port)
|
||||||
|
|||||||
53
sample/gtf.go
Normal file
53
sample/gtf.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
var DefaultGrammar = map[string]string{
|
||||||
|
"unicode": `\x{hex}{2} | \u{hex}{4} | \U{hex}{8}`,
|
||||||
|
"null": `"null"`,
|
||||||
|
"object": `"{" (kv ("," kv)*)? "}"`,
|
||||||
|
"array": `"[" (value ("," value)*)? "]"`,
|
||||||
|
"kv": `string ":" value`,
|
||||||
|
"integer": `"0" | [1-9] [0-9]*`,
|
||||||
|
"number": `"-"? integer frac? exp?`,
|
||||||
|
"frac": `"." [0-9]+`,
|
||||||
|
"exp": `("e" | "E") ("+" | "-") [0-9]+`,
|
||||||
|
"string": `"\"" char* "\""`,
|
||||||
|
"escape": `["/" | "b" | "f" | "n" | "r" | "t" | unicode]`,
|
||||||
|
"char": `[^"\\] | escape`,
|
||||||
|
"space": `(" " | "\t" | "\n" | "\r")*`,
|
||||||
|
"hex": `[0-9] | [a-f] | [A-F]`,
|
||||||
|
"boolean": `"true" | "false"`,
|
||||||
|
"value": `object | array | string | number | boolean | "null"`,
|
||||||
|
}
|
||||||
|
|
||||||
|
const jsonString = `object | array`
|
||||||
|
|
||||||
|
type StateMachine struct {
|
||||||
|
states map[rune]State
|
||||||
|
}
|
||||||
|
|
||||||
|
type State struct {
|
||||||
|
NextStates []string
|
||||||
|
// bitmask?
|
||||||
|
Mask []bool
|
||||||
|
IsTerminal bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStateMachine(grammar map[string]string, startRule string) *StateMachine {
|
||||||
|
states := make(map[rune]State)
|
||||||
|
|
||||||
|
var cumu string
|
||||||
|
flag := false
|
||||||
|
for _, r := range startRule {
|
||||||
|
if r == '"' {
|
||||||
|
flag = !flag
|
||||||
|
}
|
||||||
|
if flag {
|
||||||
|
cumu += string(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sm := &StateMachine{
|
||||||
|
states: states,
|
||||||
|
}
|
||||||
|
return sm
|
||||||
|
}
|
||||||
138
sample/gtf_test.go
Normal file
138
sample/gtf_test.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGrammarParsing(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
grammar map[string]string
|
||||||
|
startRule string
|
||||||
|
input string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple object",
|
||||||
|
grammar: map[string]string{
|
||||||
|
"object": `"{" "}"`,
|
||||||
|
},
|
||||||
|
startRule: "object",
|
||||||
|
input: "{}",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple array",
|
||||||
|
grammar: map[string]string{
|
||||||
|
"array": `"[" "]"`,
|
||||||
|
},
|
||||||
|
startRule: "array",
|
||||||
|
input: "[]",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "character class",
|
||||||
|
grammar: map[string]string{
|
||||||
|
"digit": `[0-9]`,
|
||||||
|
},
|
||||||
|
startRule: "digit",
|
||||||
|
input: "5",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "alternation",
|
||||||
|
grammar: map[string]string{
|
||||||
|
"bool": `"true" | "false"`,
|
||||||
|
},
|
||||||
|
startRule: "bool",
|
||||||
|
input: "true",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "repetition",
|
||||||
|
grammar: map[string]string{
|
||||||
|
"digits": `[0-9]+`,
|
||||||
|
},
|
||||||
|
startRule: "digits",
|
||||||
|
input: "123",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested rules",
|
||||||
|
grammar: map[string]string{
|
||||||
|
"value": `object | array`,
|
||||||
|
"object": `"{" "}"`,
|
||||||
|
"array": `"[" "]"`,
|
||||||
|
},
|
||||||
|
startRule: "value",
|
||||||
|
input: "{}",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
parser := NewParser(tt.grammar)
|
||||||
|
machine, err := parser.Parse(tt.startRule)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Parse() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
matcher := NewMatcher(machine)
|
||||||
|
got, err := matcher.Match(tt.input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Match() error = %v", err)
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("Match() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONGrammar(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"empty object", "{}", true},
|
||||||
|
{"empty array", "[]", true},
|
||||||
|
{"simple string", `"hello"`, true},
|
||||||
|
{"simple number", "123", true},
|
||||||
|
{"simple boolean", "true", true},
|
||||||
|
{"simple null", "null", true},
|
||||||
|
{"object with string", `{"key": "value"}`, true},
|
||||||
|
{"array with numbers", "[1, 2, 3]", true},
|
||||||
|
{"nested object", `{"obj": {"key": "value"}}`, true},
|
||||||
|
{"nested array", `[1, [2, 3], 4]`, true},
|
||||||
|
{"invalid object", "{", false},
|
||||||
|
{"invalid array", "[1, 2", false},
|
||||||
|
{"invalid string", `"hello`, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
parser := NewParser(DefaultGrammar)
|
||||||
|
machine, err := parser.Parse("value")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Parse() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
matcher := NewMatcher(machine)
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := matcher.Match(tt.input)
|
||||||
|
if tt.want {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Match() error = %v", err)
|
||||||
|
}
|
||||||
|
if !got {
|
||||||
|
t.Errorf("Match() = false, want true")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err == nil && got {
|
||||||
|
t.Errorf("Match() = true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
160
sample/json_types.go
Normal file
160
sample/json_types.go
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type JSONState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StateStart JSONState = iota
|
||||||
|
StateInObject
|
||||||
|
StateInObjectKey
|
||||||
|
StateInStructuredKey
|
||||||
|
StateInStructuredValue
|
||||||
|
StateNewline
|
||||||
|
StateTab
|
||||||
|
StateSpace
|
||||||
|
StateInString
|
||||||
|
StateInInt
|
||||||
|
StateInFloat
|
||||||
|
StateInBool
|
||||||
|
StateInNull
|
||||||
|
StateInColon
|
||||||
|
StateInComma
|
||||||
|
StateInTab
|
||||||
|
StateInSpaceToValue
|
||||||
|
StateInSpaceEndValue
|
||||||
|
StateInNewlineEndValue
|
||||||
|
StateInObjSpace
|
||||||
|
StateInList
|
||||||
|
StateInListComma
|
||||||
|
StateInValue
|
||||||
|
StateInValueEnd
|
||||||
|
StateInListEnd
|
||||||
|
StateInListObjectEnd
|
||||||
|
StateInNewline
|
||||||
|
StateInNumber
|
||||||
|
StateInNumberEnd
|
||||||
|
StateInStringEnd
|
||||||
|
StateInObjectKeyEnd
|
||||||
|
StateTerminate
|
||||||
|
StateInObjectEnd
|
||||||
|
StateTransitioningToTerminate
|
||||||
|
StateInListStartJSON
|
||||||
|
)
|
||||||
|
|
||||||
|
var JSONStates = []JSONState{
|
||||||
|
StateStart,
|
||||||
|
StateInObject,
|
||||||
|
StateInObjectKey,
|
||||||
|
StateInStructuredKey,
|
||||||
|
StateInStructuredValue,
|
||||||
|
StateNewline,
|
||||||
|
StateTab,
|
||||||
|
StateSpace,
|
||||||
|
StateInString,
|
||||||
|
StateInInt,
|
||||||
|
StateInFloat,
|
||||||
|
StateInBool,
|
||||||
|
StateInNull,
|
||||||
|
StateInColon,
|
||||||
|
StateInComma,
|
||||||
|
StateInTab,
|
||||||
|
StateInSpaceToValue,
|
||||||
|
StateInSpaceEndValue,
|
||||||
|
StateInNewlineEndValue,
|
||||||
|
StateInObjSpace,
|
||||||
|
StateInListStartJSON,
|
||||||
|
StateInList,
|
||||||
|
StateInListComma,
|
||||||
|
StateInValue,
|
||||||
|
StateInValueEnd,
|
||||||
|
StateInListEnd,
|
||||||
|
StateInListObjectEnd,
|
||||||
|
StateInNewline,
|
||||||
|
StateInNumber,
|
||||||
|
StateInNumberEnd,
|
||||||
|
StateInStringEnd,
|
||||||
|
StateInObjectKeyEnd,
|
||||||
|
StateTerminate,
|
||||||
|
StateInObjectEnd,
|
||||||
|
StateTransitioningToTerminate,
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s JSONState) String() string {
|
||||||
|
switch s {
|
||||||
|
case StateStart:
|
||||||
|
return "StateStart"
|
||||||
|
case StateInObject:
|
||||||
|
return "StateInObject"
|
||||||
|
case StateInObjectKey:
|
||||||
|
return "StateInObjectKey"
|
||||||
|
case StateInStructuredKey:
|
||||||
|
return "StateInStructuredKey"
|
||||||
|
case StateInStructuredValue:
|
||||||
|
return "StateInStructuredValue"
|
||||||
|
case StateNewline:
|
||||||
|
return "StateNewline"
|
||||||
|
case StateTab:
|
||||||
|
return "StateTab"
|
||||||
|
case StateSpace:
|
||||||
|
return "StateSpace"
|
||||||
|
case StateInString:
|
||||||
|
return "StateInString"
|
||||||
|
case StateInInt:
|
||||||
|
return "StateInInt"
|
||||||
|
case StateInFloat:
|
||||||
|
return "StateInFloat"
|
||||||
|
case StateInBool:
|
||||||
|
return "StateInBool"
|
||||||
|
case StateInNull:
|
||||||
|
return "StateInNull"
|
||||||
|
case StateInColon:
|
||||||
|
return "StateInColon"
|
||||||
|
case StateInComma:
|
||||||
|
return "StateInComma"
|
||||||
|
case StateInTab:
|
||||||
|
return "StateInTab"
|
||||||
|
case StateInSpaceToValue:
|
||||||
|
return "StateInSpaceToValue"
|
||||||
|
case StateInSpaceEndValue:
|
||||||
|
return "StateInSpaceEndValue"
|
||||||
|
case StateInNewlineEndValue:
|
||||||
|
return "StateInNewlineEndValue"
|
||||||
|
case StateInObjSpace:
|
||||||
|
return "StateInObjSpace"
|
||||||
|
case StateInList:
|
||||||
|
return "StateInList"
|
||||||
|
case StateInListComma:
|
||||||
|
return "StateInListComma"
|
||||||
|
case StateInValue:
|
||||||
|
return "StateInValue"
|
||||||
|
case StateInValueEnd:
|
||||||
|
return "StateInValueEnd"
|
||||||
|
case StateInListEnd:
|
||||||
|
return "StateInListEnd"
|
||||||
|
case StateInListObjectEnd:
|
||||||
|
return "StateInListObjectEnd"
|
||||||
|
case StateInNewline:
|
||||||
|
return "StateInNewline"
|
||||||
|
case StateInNumber:
|
||||||
|
return "StateInNumber"
|
||||||
|
case StateInNumberEnd:
|
||||||
|
return "StateInNumberEnd"
|
||||||
|
case StateInStringEnd:
|
||||||
|
return "StateInStringEnd"
|
||||||
|
case StateInObjectKeyEnd:
|
||||||
|
return "StateInObjectKeyEnd"
|
||||||
|
case StateTerminate:
|
||||||
|
return "StateTerminate"
|
||||||
|
case StateInObjectEnd:
|
||||||
|
return "StateInObjectEnd"
|
||||||
|
case StateTransitioningToTerminate:
|
||||||
|
return "StateTransitioningToTerminate"
|
||||||
|
case StateInListStartJSON:
|
||||||
|
return "StateInListStartJSON"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("Unknown state: %d", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
327
sample/pushdown_automata.go
Normal file
327
sample/pushdown_automata.go
Normal file
@@ -0,0 +1,327 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
Key JSON rules to consider:
|
||||||
|
|
||||||
|
1. Whitespace handling:
|
||||||
|
- Need to handle all valid JSON whitespace characters (\r, spaces between tokens)
|
||||||
|
- Current code only handles some whitespace cases
|
||||||
|
|
||||||
|
2. Number validation:
|
||||||
|
- Need proper validation for special number cases like -0
|
||||||
|
- Should handle .5 style decimals
|
||||||
|
- Need limits on scientific notation (e, E)
|
||||||
|
|
||||||
|
3. String escaping:
|
||||||
|
- Currently marks \ as invalid but should allow escaped sequences:
|
||||||
|
- \"
|
||||||
|
- \n
|
||||||
|
- \u1234 unicode escapes
|
||||||
|
|
||||||
|
4. Empty object/array transitions:
|
||||||
|
- Direct {} and [] cases could be more explicit
|
||||||
|
- Need clear transitions for these edge cases
|
||||||
|
|
||||||
|
5. Nested depth limits:
|
||||||
|
- No protection against excessive nesting
|
||||||
|
- Could cause stack overflow with deeply nested structures
|
||||||
|
*/
|
||||||
|
|
||||||
|
// TODO: / should be valid but an escape character
|
||||||
|
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
|
||||||
|
|
||||||
|
var (
|
||||||
|
intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
||||||
|
validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
||||||
|
)
|
||||||
|
|
||||||
|
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
|
||||||
|
|
||||||
|
var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
|
||||||
|
|
||||||
|
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
|
||||||
|
|
||||||
|
type PDA struct {
|
||||||
|
State JSONState
|
||||||
|
TransitionEdges map[rune]*PDA
|
||||||
|
MaskTokenIDToNode map[int32]*PDA
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPDANode(state JSONState) *PDA {
|
||||||
|
return &PDA{
|
||||||
|
State: state,
|
||||||
|
TransitionEdges: make(map[rune]*PDA),
|
||||||
|
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type PDAGraphBuilder struct {
|
||||||
|
proc model.TextProcessor
|
||||||
|
decodedToks []string
|
||||||
|
stateToNodeMap map[JSONState]*PDA
|
||||||
|
tokenToStatesMap map[int32][]JSONState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *PDAGraphBuilder) BuildGraph() error {
|
||||||
|
stateToNodeMap := make(map[JSONState]*PDA)
|
||||||
|
for _, state := range JSONStates {
|
||||||
|
stateToNodeMap[state] = NewPDANode(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
|
||||||
|
|
||||||
|
// TODO: update naming here - and revisit values
|
||||||
|
stateToNodeMap[StateInListStartJSON].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
stateToNodeMap[StateInListStartJSON].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
|
||||||
|
|
||||||
|
stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||||
|
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||||
|
stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||||
|
stateToNodeMap[StateInObject].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
|
||||||
|
// new line
|
||||||
|
stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||||
|
stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
|
||||||
|
stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||||
|
// stateToNodeMap[StateInNewline].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
|
||||||
|
// new line end value
|
||||||
|
// stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
|
stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
|
|
||||||
|
stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||||
|
stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||||
|
// TODO: see if this is needed for formatting
|
||||||
|
stateToNodeMap[StateInObjSpace].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||||
|
|
||||||
|
stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||||
|
stateToNodeMap[StateInTab].TransitionEdges['\t'] = stateToNodeMap[StateInNewline]
|
||||||
|
|
||||||
|
stateToNodeMap[StateInObjectKey].TransitionEdges[rune(-1)] = stateToNodeMap[StateInObjectKey]
|
||||||
|
stateToNodeMap[StateInObjectKey].TransitionEdges['"'] = stateToNodeMap[StateInObjectKeyEnd]
|
||||||
|
|
||||||
|
stateToNodeMap[StateInObjectKeyEnd].TransitionEdges[':'] = stateToNodeMap[StateInColon]
|
||||||
|
|
||||||
|
stateToNodeMap[StateInObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||||
|
stateToNodeMap[StateInObjectEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
|
||||||
|
// where values should be
|
||||||
|
// this could be combined but the probl might change, we're alr doing a skip ahead
|
||||||
|
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||||
|
stateToNodeMap[StateInColon].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
|
||||||
|
|
||||||
|
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||||
|
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
|
||||||
|
|
||||||
|
// Leads to a value
|
||||||
|
stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||||
|
stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap)
|
||||||
|
stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
stateToNodeMap[StateInSpaceToValue].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
|
||||||
|
|
||||||
|
// Values
|
||||||
|
// string node
|
||||||
|
stateToNodeMap[StateInString].TransitionEdges[rune(-1)] = stateToNodeMap[StateInString]
|
||||||
|
stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
|
||||||
|
|
||||||
|
// String end node
|
||||||
|
addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
|
||||||
|
// stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
|
stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
|
// TODO: add counters for allowable number of decimals, e, E, etc
|
||||||
|
// number node
|
||||||
|
for _, r := range validNumberRunes {
|
||||||
|
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
||||||
|
}
|
||||||
|
addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
|
||||||
|
// stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
|
stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
|
// list node
|
||||||
|
stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||||
|
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
|
||||||
|
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
||||||
|
// early end
|
||||||
|
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
|
|
||||||
|
// list end node
|
||||||
|
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
// stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
|
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||||
|
stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
|
// empty list
|
||||||
|
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
|
addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
|
||||||
|
|
||||||
|
// null node
|
||||||
|
for _, r := range validNullRunes {
|
||||||
|
stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
|
||||||
|
}
|
||||||
|
addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
|
||||||
|
stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||||
|
stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
|
// list comma
|
||||||
|
// should point to values
|
||||||
|
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
|
||||||
|
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
||||||
|
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInList]
|
||||||
|
stateToNodeMap[StateInListComma].TransitionEdges['\t'] = stateToNodeMap[StateInList]
|
||||||
|
|
||||||
|
addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
|
||||||
|
|
||||||
|
// list object end
|
||||||
|
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
|
||||||
|
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
|
// TODO: not sure if this is needed
|
||||||
|
stateToNodeMap[StateInListObjectEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
|
// bool node
|
||||||
|
for _, r := range validBoolRunes {
|
||||||
|
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
|
||||||
|
}
|
||||||
|
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||||
|
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
|
||||||
|
// stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
|
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
|
// comma node
|
||||||
|
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||||
|
stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||||
|
stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||||
|
// todo: review this space transition
|
||||||
|
// stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||||
|
|
||||||
|
// space end value
|
||||||
|
// stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||||
|
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
|
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||||
|
|
||||||
|
b.stateToNodeMap = stateToNodeMap
|
||||||
|
if err := b.preComputeValidStates(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addEnds wires the transitions that can legally terminate a JSON value
// from node: ',' continues with the next member, '}' closes the enclosing
// object, and ']' closes the enclosing list. Shared by every
// value-producing state.
func addEnds(node *PDA, stateToNodeMap map[JSONState]*PDA) {
	node.TransitionEdges[','] = stateToNodeMap[StateInComma]
	node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
	node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
}
|
||||||
|
|
||||||
|
// addValueConnections wires node to the start state of every JSON value
// kind: '"' begins a string, any rune in validNumberRunes begins a number,
// 't'/'f' begin a boolean literal, and 'n' begins null.
func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) {
	node.TransitionEdges['"'] = stateToNodeMap[StateInString]
	for _, r := range validNumberRunes {
		node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
	}
	// TODO(parthsareen): force the output and shift similar to structured outputs
	node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
	node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
	node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
}
|
||||||
|
|
||||||
|
func (b *PDAGraphBuilder) preComputeValidStates() error {
|
||||||
|
for _, node := range b.stateToNodeMap {
|
||||||
|
// if node.State == StateInObjectKey {
|
||||||
|
// if len(b.stateToNodeMap[StateInString].MaskTokenIDToNode) > 0 {
|
||||||
|
// b.stateToNodeMap[StateInObjectKey].MaskTokenIDToNode = b.stateToNodeMap[StateInString].MaskTokenIDToNode
|
||||||
|
// fmt.Println("copying string mask to object key mask")
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
if err := b.CreateMask(node); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *PDAGraphBuilder) preComputeTokenToStatesMap() error {
|
||||||
|
// TODO: make can be somewhere else too
|
||||||
|
b.tokenToStatesMap = make(map[int32][]JSONState)
|
||||||
|
for i, t := range b.decodedToks {
|
||||||
|
for _, r := range t {
|
||||||
|
if r == '"' {
|
||||||
|
b.tokenToStatesMap[int32(i)] = append(b.tokenToStatesMap[int32(i)], StateInString)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: the mask for obj key and string should be the same?
|
||||||
|
// TODO: the mask for obj key and string should be the same?

// CreateMask populates node.MaskTokenIDToNode with every vocabulary token
// that can legally be emitted from node, mapping each token ID to the PDA
// node reached after consuming the whole token. A token is valid only if
// every one of its runes has a legal transition (see isRuneValid) from the
// node walked so far. EOS/BOS and empty tokens are skipped up front.
func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
	if node == nil {
		return fmt.Errorf("node cannot be nil")
	}
	for i := range b.decodedToks {
		token := b.decodedToks[i]
		// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
		if b.proc.Is(int32(i), model.SpecialEOS) || b.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
			continue
		}
		curNode := node
		valid := true
		// Special runes (see stringInvalidRunes) may be consumed at most
		// once per token; tracked per-walk in this map.
		consumedSpecialRunes := make(map[rune]bool)
		for _, r := range token {
			curNode, valid = isRuneValid(r, curNode, consumedSpecialRunes)
			if curNode == nil || !valid {
				break
			}
		}
		if valid {
			// Token is fully consumable from node; remember where it lands.
			node.MaskTokenIDToNode[int32(i)] = curNode
		}
	}
	return nil
}
|
||||||
|
|
||||||
|
// isRuneValid advances the PDA by a single rune from curNode, returning the
// node reached and whether the step is legal. Runes in stringInvalidRunes
// ("special" runes) are rejected inside strings/object keys and may be
// consumed at most once per token (tracked in consumedSpecialRunes). A
// transition keyed by rune(-1) is a wildcard that accepts any rune.
func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA, bool) {
	if consumedSpecialRunes[r] {
		return nil, false
	}

	specialRune := slices.Contains(stringInvalidRunes, r)
	if specialRune {
		// Special runes never appear verbatim inside a string or key body.
		if curNode.State == StateInString || curNode.State == StateInObjectKey {
			return nil, false
		}
	}

	// Check for specific rune transition
	if nextNode, ok := curNode.TransitionEdges[r]; ok {
		if specialRune {
			// A special rune must actually change state; a self-loop on a
			// special rune is rejected.
			if curNode.State == nextNode.State {
				return nil, false
			}
			consumedSpecialRunes[r] = true
		}
		return nextNode, true
	}

	// Check for sentinel value - if present, any rune is valid
	if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
		return nextNode, true
	}

	return nil, false
}
|
||||||
264
sample/pushdown_runner.go
Normal file
264
sample/pushdown_runner.go
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"runtime"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: safety in case of invalid json
|
||||||
|
// TODO: partial JSON matching?
|
||||||
|
// TODO: interfaces to cleanup with return values
|
||||||
|
// TODO this interface shouldn't be the sampler - should just use Sampler
|
||||||
|
// TODO: add penalties for string \n stuff
|
||||||
|
// TODO: minimize number of fwd passes if there is only one match
|
||||||
|
// TODO: greedy sample initially and then backtrack if no match
|
||||||
|
|
||||||
|
// PushdownSampler constrains token sampling to syntactically valid JSON
// using a pushdown automaton: the embedded PDAGraphBuilder supplies
// per-state token masks, curNode tracks the automaton position, and
// braceStack records open '{' / '[' runes so closing tokens can be matched.
type PushdownSampler struct {
	PDAGraphBuilder
	curNode      *PDA   // current automaton node
	braceStack   []rune // open '{' and '[' runes awaiting a matching close
	stateCounter uint32 // consecutive steps spent in the same state (for future penalties)
}
|
||||||
|
|
||||||
|
// NewPushdownSampler builds a PushdownSampler for proc: it decodes the whole
// vocabulary once, builds the PDA graph (including precomputed token masks),
// and starts the automaton at StateStart.
// The graph should be built once and reused per tokenizer.
// NOTE(review): the fmt.Print* calls and MemStats reads below are debug
// instrumentation writing to stdout — presumably temporary; confirm before
// shipping.
func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) {
	start := time.Now()

	fmt.Println("--------------------------------")
	fmt.Println("PDA sampler")
	fmt.Println("--------------------------------")
	var m runtime.MemStats
	runtime.ReadMemStats(&m)
	before := m.Alloc
	fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))

	// Decode every vocabulary entry to its string form; the graph builder
	// walks these strings rune by rune to compute token masks.
	vocab := proc.Vocab()
	decodedToks := make([]string, len(vocab.Values))
	for i := range vocab.Values {
		token, err := proc.Decode([]int32{int32(i)})
		if err != nil {
			return nil, err
		}
		decodedToks[i] = token
	}

	gb := &PDAGraphBuilder{
		proc:        proc,
		decodedToks: decodedToks,
	}

	if err := gb.BuildGraph(); err != nil {
		return nil, err
	}

	runtime.ReadMemStats(&m)
	after := m.Alloc
	fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
	fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
	fmt.Printf("Graph build time = %v\n", time.Since(start))

	// TODO: this can be simplified
	return &PushdownSampler{
		curNode:         gb.stateToNodeMap[StateStart],
		PDAGraphBuilder: *gb,
		braceStack:      []rune{},
		stateCounter:    0,
	}, nil
}
|
||||||
|
|
||||||
|
// TODO: need to add resampling logic if the first sample was not good
// greedy sample + backtrack?

// Apply masks logits so that only tokens valid from the current PDA state
// keep their scores (all other positions become -Inf). In end states with
// an empty brace stack it instead forces EOS so generation terminates.
func (s *PushdownSampler) Apply(logits []float32) ([]float32, error) {
	switch s.curNode.State {
	case StateInString:
		return s.maskLogits(logits, s.curNode)

	case StateInListEnd:
		// force finish if no braces left
		if len(s.braceStack) == 0 {
			s.curNode = NewPDANode(StateTerminate)
			return forceFinish(s, logits)
		}

		logits, err := s.maskLogits(logits, s.curNode)
		if err != nil {
			return nil, err
		}
		return logits, nil

	case StateTerminate:
		return forceFinish(s, logits)

	case StateInObjectEnd:
		// force finish if no braces left
		if len(s.braceStack) == 0 {
			s.curNode = NewPDANode(StateTerminate)
			return forceFinish(s, logits)
		}

		// If the enclosing scope is a list, this object end closes a list
		// element; switch to the list-specific end state first.
		peek := s.braceStack[len(s.braceStack)-1]
		if peek == rune('[') {
			s.curNode = s.stateToNodeMap[StateInListObjectEnd]
		}

		logits, err := s.maskLogits(logits, s.curNode)
		if err != nil {
			return nil, err
		}
		return logits, nil

	case StateInComma:
		// A comma inside a list continues list elements, not object keys.
		// NOTE(review): peeking assumes a non-empty brace stack here —
		// confirm UpdateState guarantees that before this state is reached.
		peek := s.braceStack[len(s.braceStack)-1]
		if peek == rune('[') {
			s.curNode = s.stateToNodeMap[StateInListComma]
		}

		logits, err := s.maskLogits(logits, s.curNode)
		if err != nil {
			return nil, err
		}
		return logits, nil

	default:
		fmt.Println("masking logits current state", s.curNode.State)
		logits, err := s.maskLogits(logits, s.curNode)
		if err != nil {
			return nil, err
		}
		return logits, nil
	}
}
|
||||||
|
|
||||||
|
func forceFinish(s *PushdownSampler, logits []float32) ([]float32, error) {
|
||||||
|
for i := range logits {
|
||||||
|
if s.proc.Is(int32(i), model.SpecialEOS) {
|
||||||
|
logits[i] = 1.0
|
||||||
|
} else {
|
||||||
|
logits[i] = float32(math.Inf(-1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PushdownSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
|
||||||
|
fmt.Println("current state - updating", s.curNode.State)
|
||||||
|
mappedString, err := s.proc.Decode(tokenSlice)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
fmt.Printf(">>> mappedString: %q\n", mappedString)
|
||||||
|
|
||||||
|
// Special handling for EOS token in terminate state
|
||||||
|
if s.curNode.State == StateTerminate {
|
||||||
|
for _, tokenID := range tokenSlice {
|
||||||
|
if s.proc.Is(tokenID, model.SpecialEOS) {
|
||||||
|
return tokenSlice, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// flag := -1
|
||||||
|
// endBraceRunes := []rune{'}', ']'}
|
||||||
|
for _, r := range mappedString {
|
||||||
|
// TODO: if this is enabled again, make sure to appropriately handle the state transitions
|
||||||
|
// if slices.Contains(endBraceRunes, r) && len(s.braceStack) == 0 {
|
||||||
|
// fmt.Printf("stack is empty, extra closing brace %c\n", r)
|
||||||
|
// // flag = i
|
||||||
|
// break
|
||||||
|
|
||||||
|
// }
|
||||||
|
if r == rune('{') {
|
||||||
|
s.braceStack = append(s.braceStack, r)
|
||||||
|
}
|
||||||
|
if r == rune('[') {
|
||||||
|
s.braceStack = append(s.braceStack, r)
|
||||||
|
}
|
||||||
|
if r == rune('}') {
|
||||||
|
if len(s.braceStack) == 0 {
|
||||||
|
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
|
||||||
|
}
|
||||||
|
top := s.braceStack[len(s.braceStack)-1]
|
||||||
|
if top != rune('{') {
|
||||||
|
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
|
||||||
|
}
|
||||||
|
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if r == rune(']') {
|
||||||
|
if len(s.braceStack) == 0 {
|
||||||
|
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
|
||||||
|
}
|
||||||
|
top := s.braceStack[len(s.braceStack)-1]
|
||||||
|
if top != rune('[') {
|
||||||
|
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
|
||||||
|
}
|
||||||
|
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if flag != -1 {
|
||||||
|
// tokenSlice = tokenSlice[:flag]
|
||||||
|
// }
|
||||||
|
// fmt.Println("flag!", flag)
|
||||||
|
for _, tokenID := range tokenSlice {
|
||||||
|
// transition to the next node
|
||||||
|
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid token: %q", mappedString)
|
||||||
|
}
|
||||||
|
fmt.Println("transitioning to", nextNode.State)
|
||||||
|
|
||||||
|
// TODO: add a penalty for staying in the same state too long
|
||||||
|
if nextNode.State == s.curNode.State {
|
||||||
|
s.stateCounter++
|
||||||
|
} else {
|
||||||
|
s.stateCounter = 0
|
||||||
|
}
|
||||||
|
s.curNode = nextNode
|
||||||
|
fmt.Println("updated curNode state", s.curNode.State)
|
||||||
|
}
|
||||||
|
return tokenSlice, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// greedy sample + backtrack?
|
||||||
|
func (s *PushdownSampler) maskLogits(logits []float32, node *PDA) ([]float32, error) {
|
||||||
|
// Create a new slice with same length as logits, initialized to -Inf
|
||||||
|
maskedLogits := make([]float32, len(logits))
|
||||||
|
for i := range maskedLogits {
|
||||||
|
maskedLogits[i] = float32(math.Inf(-1))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only update values for valid token IDs from the mask map
|
||||||
|
for tokenID := range node.MaskTokenIDToNode {
|
||||||
|
if int(tokenID) < len(logits) {
|
||||||
|
maskedLogits[tokenID] = logits[tokenID]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return maskedLogits, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fastMaskLogits is a greedy variant of maskLogits: it finds the single
// highest-scoring valid token and encodes its *index* as a float32 into
// logits[0], returning the same slice.
// NOTE(review): returning the argmax index through logits[0] is a sentinel
// hack the caller must know about; the commented-out `return maxIndex, nil`
// suggests the API was in flux — confirm intended contract before use.
func (s *PushdownSampler) fastMaskLogits(logits []float32, node *PDA) ([]float32, error) {
	maxLogit := float32(math.Inf(-1))
	maxIndex := -1

	// Find the maximum logit value among valid tokens
	for tokenID := range node.MaskTokenIDToNode {
		if int(tokenID) < len(logits) && logits[tokenID] > maxLogit {
			maxLogit = logits[tokenID]
			maxIndex = int(tokenID)
		}
	}

	if maxIndex == -1 {
		return nil, fmt.Errorf("no valid tokens found in mask")
	}

	logits[0] = float32(maxIndex)
	return logits, nil
	// return maxIndex, nil
}
|
||||||
@@ -17,15 +17,34 @@ type token struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Sampler struct {
|
type Sampler struct {
|
||||||
rng *rand.Rand
|
rng *rand.Rand
|
||||||
topK int
|
topK int
|
||||||
topP float32
|
topP float32
|
||||||
minP float32
|
minP float32
|
||||||
temperature float32
|
temperature float32
|
||||||
grammar *Grammar
|
grammar *Grammar
|
||||||
|
JSONSampler *JSONSampler
|
||||||
|
PythonSampler *PythonSampler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||||
|
if len(logits) == 0 {
|
||||||
|
return -1, errors.New("sample: no logits provided to sample")
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if s.JSONSampler != nil {
|
||||||
|
logits, err = s.JSONSampler.Apply(logits)
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.PythonSampler != nil {
|
||||||
|
logits, err = s.PythonSampler.ApplyMask(logits)
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
}
|
||||||
tokens := make([]token, len(logits))
|
tokens := make([]token, len(logits))
|
||||||
for i := range logits {
|
for i := range logits {
|
||||||
tokens[i].id = int32(i)
|
tokens[i].id = int32(i)
|
||||||
@@ -94,13 +113,6 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
|||||||
tokens = topP(tokens, s.topP)
|
tokens = topP(tokens, s.topP)
|
||||||
tokens = minP(tokens, s.minP)
|
tokens = minP(tokens, s.minP)
|
||||||
|
|
||||||
// TODO: this should fall back to greedy sampling
|
|
||||||
// or topP, topK values etc should be such that
|
|
||||||
// there are always tokens to sample from
|
|
||||||
if len(tokens) == 0 {
|
|
||||||
return token{}, errors.New("no tokens to sample from")
|
|
||||||
}
|
|
||||||
|
|
||||||
var r float32
|
var r float32
|
||||||
if s.rng != nil {
|
if s.rng != nil {
|
||||||
r = s.rng.Float32()
|
r = s.rng.Float32()
|
||||||
@@ -123,11 +135,14 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
|||||||
return 1
|
return 1
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if math.IsNaN(float64(sum)) {
|
||||||
|
return token{}, errors.New("sample: logits sum to NaN, check model output")
|
||||||
|
}
|
||||||
return tokens[idx], nil
|
return tokens[idx], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
|
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar, jsonSampler *JSONSampler, pythonSampler *PythonSampler) Sampler {
|
||||||
var rng *rand.Rand
|
var rng *rand.Rand
|
||||||
if seed != -1 {
|
if seed != -1 {
|
||||||
// PCG requires two parameters: sequence and stream
|
// PCG requires two parameters: sequence and stream
|
||||||
@@ -155,12 +170,14 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
|
|||||||
}
|
}
|
||||||
|
|
||||||
return Sampler{
|
return Sampler{
|
||||||
rng: rng,
|
rng: rng,
|
||||||
topK: topK,
|
topK: topK,
|
||||||
topP: topP,
|
topP: topP,
|
||||||
minP: minP,
|
minP: minP,
|
||||||
temperature: temperature,
|
temperature: temperature,
|
||||||
grammar: grammar,
|
grammar: grammar,
|
||||||
|
JSONSampler: jsonSampler,
|
||||||
|
PythonSampler: pythonSampler,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package sample
|
package sample
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"math"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@@ -29,6 +30,29 @@ func TestWeighted(t *testing.T) {
|
|||||||
if want != got {
|
if want != got {
|
||||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test very high p
|
||||||
|
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
||||||
|
// Use extremely small topP to filter out all tokens
|
||||||
|
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
|
||||||
|
got, err = sampler.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Should get the token with the highest logit
|
||||||
|
want = int32(0)
|
||||||
|
if want != got {
|
||||||
|
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
||||||
|
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
|
||||||
|
got, err = sampler.Sample(logits)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error, got %d", got)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkSample(b *testing.B) {
|
func BenchmarkSample(b *testing.B) {
|
||||||
|
|||||||
299
sample/structured_outputs.go
Normal file
299
sample/structured_outputs.go
Normal file
@@ -0,0 +1,299 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"runtime"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/grammar/jsonschema"
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JSONSampler constrains sampling to JSON that matches an optional schema.
// With a nil schema it defers entirely to the wrapped PushdownSampler; with
// a schema it additionally walks the object's properties in declaration
// order, forcing each key via propToNodeMap.
type JSONSampler struct {
	schema        *jsonschema.Schema // nil means unconstrained (but still valid) JSON
	propIdx       int                // index of the property currently being emitted; -1 before the first key
	propToNodeMap map[string]*PDA    // property name -> start node of its key subgraph
	pdaSampler    *PushdownSampler
	decodedToks   []string // vocabulary decoded to strings, index == token ID
}
|
||||||
|
|
||||||
|
// NewJSONSampler builds a JSONSampler over proc. With a nil schema the
// sampler only enforces generic JSON validity. With a schema it converts
// the schema into per-property key subgraphs (schemaToGraph) and
// precomputes token masks for every node along each key path.
// NOTE(review): the fmt.Print* calls and MemStats reads are debug/benchmark
// instrumentation writing to stdout — confirm before shipping.
func NewJSONSampler(proc model.TextProcessor, schema *jsonschema.Schema) (*JSONSampler, error) {
	slog.Info("NewJSONSampler", "schema", schema)
	if proc == nil {
		return nil, fmt.Errorf("TextProcessor cannot be nil")
	}

	pdaSampler, err := NewPushdownSampler(proc)
	if err != nil {
		return nil, fmt.Errorf("failed to create PushdownSampler: %w", err)
	}

	// Unconstrained case: no schema bookkeeping needed.
	if schema == nil {
		return &JSONSampler{
			schema:        nil,
			propIdx:       -1,
			propToNodeMap: nil,
			pdaSampler:    pdaSampler,
		}, nil
	}

	so := &JSONSampler{
		schema:        schema,
		propIdx:       -1,
		propToNodeMap: make(map[string]*PDA),
		pdaSampler:    pdaSampler,
	}

	so.schemaToGraph()

	// Benchmark token decoding
	start := time.Now()
	var m runtime.MemStats
	runtime.ReadMemStats(&m)
	before := m.Alloc

	vocab := proc.Vocab()
	decodedToks := make([]string, len(vocab.Values))
	for i := range vocab.Values {
		token, err := proc.Decode([]int32{int32(i)})
		if err != nil {
			return nil, err
		}
		decodedToks[i] = token
	}
	so.decodedToks = decodedToks

	runtime.ReadMemStats(&m)
	after := m.Alloc
	fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
	fmt.Printf("Token decode time = %v\n", time.Since(start))

	fmt.Println("--------------------------------")
	fmt.Println("SOSampler")
	fmt.Println("--------------------------------")
	// Benchmark this section
	start = time.Now()
	runtime.ReadMemStats(&m)
	before = m.Alloc

	// TODO: still messed up
	// TODO: recursion use case
	// Precompute key masks: walk each property's key subgraph and build a
	// token mask for every node along it, ending at the colon node.
	for _, prop := range so.schema.Properties {
		node := so.propToNodeMap[prop.Name]
		// propName -> node
		curState := node.State
		fromNode := node
		so.pdaSampler.CreateMask(fromNode)
		for curState == StateInStructuredKey {
			// there is only one edge per key-chain node
			for r, toNode := range fromNode.TransitionEdges {
				fmt.Println("rune", r, "edge", toNode.State)
				so.pdaSampler.CreateMask(toNode)
				fmt.Printf("created mask for %c\n", r)
				curState = toNode.State
				fmt.Println("next state", curState)
				// TODO: theres an extra gen for " right now
				fromNode = toNode
			}
		}

		if curState != StateInColon {
			return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState)
		}

		// Step over the space that follows the colon and mask it too.
		fromNode = fromNode.TransitionEdges[' ']

		so.pdaSampler.CreateMask(fromNode)
		curState = fromNode.State
		for _, toNode := range fromNode.TransitionEdges {
			fmt.Println("toNode", toNode.State)
		}
	}

	return so, nil
}
||||||
|
|
||||||
|
// schemaToGraph converts the JSON schema into PDA subgraphs, one per
// top-level object property. Each property name becomes a rune-by-rune
// chain of StateInStructuredKey nodes followed by closing-quote, colon and
// space nodes, finally linking into the shared value states of the generic
// JSON graph according to the property's declared type. The start node of
// each chain is stored in propToNodeMap under the property name.
// Only the "object" schema type is handled; nested values and recursion
// are TODO.
func (s *JSONSampler) schemaToGraph() {
	schemaType := s.schema.EffectiveType()
	switch schemaType {
	case "object":
		// TODO: see if we need to connect these to the JSON graph

		// each prop is a key
		for _, prop := range s.schema.Properties {
			// name of key
			name := prop.Name
			keyNode := &PDA{
				State:             StateInStructuredKey, // this is unchanging, will impact sampling
				TransitionEdges:   make(map[rune]*PDA),
				MaskTokenIDToNode: make(map[int32]*PDA),
			}

			// Chain one node per rune of the property name so only the
			// exact key can be generated.
			prevNode := keyNode
			for _, r := range name {
				runeNode := &PDA{
					State:             StateInStructuredKey, // this is unchanging, will impact sampling
					TransitionEdges:   make(map[rune]*PDA),
					MaskTokenIDToNode: make(map[int32]*PDA),
				}
				// since alloc on heap connections wil still map
				prevNode.TransitionEdges[r] = runeNode
				prevNode = runeNode
			}

			// Create a node for the end of the key (after the closing quote)
			stringEndNode := &PDA{
				State:             StateInStructuredKey,
				TransitionEdges:   make(map[rune]*PDA),
				MaskTokenIDToNode: make(map[int32]*PDA),
			}
			prevNode.TransitionEdges['"'] = stringEndNode
			prevNode = stringEndNode

			// Add transition for colon after key
			colonNode := &PDA{
				State:             StateInColon,
				TransitionEdges:   make(map[rune]*PDA),
				MaskTokenIDToNode: make(map[int32]*PDA),
			}
			prevNode.TransitionEdges[':'] = colonNode
			prevNode = colonNode

			// Add transition for space after colon
			spaceNode := &PDA{
				State:             StateInSpaceToValue,
				TransitionEdges:   make(map[rune]*PDA),
				MaskTokenIDToNode: make(map[int32]*PDA),
			}
			prevNode.TransitionEdges[' '] = spaceNode
			prevNode = spaceNode

			// Link the key chain into the shared value states of the
			// generic JSON graph, depending on the declared value type.
			// NOTE(review): "object" and "array" values are logged but not
			// wired — nested values appear unsupported yet.
			value := prop.Type
			switch value {
			case "object":
				fmt.Println("object under key: ", name)
			case "array":
				fmt.Println("array under key: ", name)
			case "string":
				fmt.Println("string under key: ", name)
				prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString]
			case "number":
				fmt.Println("number under key: ", name)
				for _, r := range validNumberRunes {
					prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber]
				}
			case "boolean":
				fmt.Println("boolean under key: ", name)
				prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool]
				prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool]
				prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull]
			}

			// points to start of the key
			s.propToNodeMap[name] = keyNode
			fmt.Println("name", name, "keyNode", keyNode.State)
		}
	}
	// TODO: do values + recursion
}
|
||||||
|
|
||||||
|
func (s *JSONSampler) Apply(logits []float32) ([]float32, error) {
|
||||||
|
if s.schema == nil {
|
||||||
|
return s.pdaSampler.Apply(logits)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch s.pdaSampler.curNode.State {
|
||||||
|
// TODO: doesnt account for multi rune case
|
||||||
|
case StateInObjectKey:
|
||||||
|
if s.propIdx > len(s.schema.Properties)-1 {
|
||||||
|
return nil, fmt.Errorf("propIdx out of bounds")
|
||||||
|
}
|
||||||
|
// fmt.Println("in object key - structured outputs")
|
||||||
|
// TODO: this tracking should probably be coming from a stack to track nested objects
|
||||||
|
// simple case
|
||||||
|
s.propIdx++
|
||||||
|
fmt.Println("propIdx", s.propIdx)
|
||||||
|
prop := s.schema.Properties[s.propIdx]
|
||||||
|
fmt.Println("prop", prop.Name)
|
||||||
|
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
|
||||||
|
fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
|
||||||
|
logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
|
||||||
|
// Will only happen for the last prop - can also be precomputed.
|
||||||
|
if s.propIdx == len(s.schema.Properties)-1 {
|
||||||
|
// todo: if i incremenet propidx then i know im in last value as well
|
||||||
|
switch s.pdaSampler.curNode.State {
|
||||||
|
case StateInObjectEnd:
|
||||||
|
fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
|
||||||
|
s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
|
||||||
|
s.pdaSampler.curNode = NewPDANode(StateTerminate)
|
||||||
|
s.propIdx++
|
||||||
|
|
||||||
|
// TODO: this needs to be optimized in some way, computing mask on the fly is expensive
|
||||||
|
case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
|
||||||
|
fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
|
||||||
|
delete(s.pdaSampler.curNode.TransitionEdges, ',')
|
||||||
|
s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
|
||||||
|
|
||||||
|
s.pdaSampler.CreateMask(s.pdaSampler.curNode)
|
||||||
|
s.propIdx++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.pdaSampler.Apply(logits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
|
||||||
|
tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.schema == nil {
|
||||||
|
// Don't need to update state for unconstrained JSON sampling
|
||||||
|
return tokenSlice, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch s.pdaSampler.curNode.State {
|
||||||
|
case StateInObjectKey:
|
||||||
|
s.propIdx++
|
||||||
|
fmt.Println("propIdx", s.propIdx)
|
||||||
|
prop := s.schema.Properties[s.propIdx]
|
||||||
|
fmt.Println("prop", prop.Name)
|
||||||
|
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
|
||||||
|
// TODO: this does not work - mike
|
||||||
|
// str, err := s.pdaSampler.proc.Decode(tokenSlice)
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
// fmt.Println("str", str)
|
||||||
|
|
||||||
|
return tokenSlice, nil
|
||||||
|
default:
|
||||||
|
return tokenSlice, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
352
sample/structured_python.go
Normal file
352
sample/structured_python.go
Normal file
@@ -0,0 +1,352 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PythonState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
PythonStateStart PythonState = iota
|
||||||
|
StateInFunction
|
||||||
|
StateInFunctionArgs
|
||||||
|
StateInFunctionArgsType
|
||||||
|
StateInFunctionEnd
|
||||||
|
PStateInString
|
||||||
|
PStateInStringEnd
|
||||||
|
PStateInNumber
|
||||||
|
PStateInList
|
||||||
|
PStateInListEnd
|
||||||
|
PStateInDict
|
||||||
|
PStateInDictEnd
|
||||||
|
PStateInTuple
|
||||||
|
PStateInTupleEnd
|
||||||
|
PStateTerminate
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s PythonState) String() string {
|
||||||
|
switch s {
|
||||||
|
case PythonStateStart:
|
||||||
|
return "PythonStateStart"
|
||||||
|
case StateInFunction:
|
||||||
|
return "StateInFunction"
|
||||||
|
case StateInFunctionArgs:
|
||||||
|
return "StateInFunctionArgs"
|
||||||
|
case StateInFunctionArgsType:
|
||||||
|
return "StateInFunctionArgsType"
|
||||||
|
case StateInFunctionEnd:
|
||||||
|
return "StateInFunctionEnd"
|
||||||
|
case PStateInString:
|
||||||
|
return "PStateInString"
|
||||||
|
case PStateInStringEnd:
|
||||||
|
return "PStateInStringEnd"
|
||||||
|
case PStateInNumber:
|
||||||
|
return "PStateInNumber"
|
||||||
|
case PStateInList:
|
||||||
|
return "PStateInList"
|
||||||
|
case PStateInListEnd:
|
||||||
|
return "PStateInListEnd"
|
||||||
|
case PStateInDict:
|
||||||
|
return "PStateInDict"
|
||||||
|
case PStateInDictEnd:
|
||||||
|
return "PStateInDictEnd"
|
||||||
|
case PStateInTuple:
|
||||||
|
return "PStateInTuple"
|
||||||
|
case PStateInTupleEnd:
|
||||||
|
return "PStateInTupleEnd"
|
||||||
|
case PStateTerminate:
|
||||||
|
return "PStateTerminate"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("PythonState(%d)", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var PythonStates = []PythonState{
|
||||||
|
PythonStateStart,
|
||||||
|
StateInFunction,
|
||||||
|
StateInFunctionArgs,
|
||||||
|
StateInFunctionArgsType,
|
||||||
|
StateInFunctionEnd,
|
||||||
|
PStateInString,
|
||||||
|
PStateInStringEnd,
|
||||||
|
PStateInNumber,
|
||||||
|
PStateInList,
|
||||||
|
PStateInListEnd,
|
||||||
|
PStateInDict,
|
||||||
|
PStateInDictEnd,
|
||||||
|
PStateInTuple,
|
||||||
|
PStateInTupleEnd,
|
||||||
|
PStateTerminate,
|
||||||
|
}
|
||||||
|
|
||||||
|
type Node struct {
|
||||||
|
State PythonState
|
||||||
|
TransitionEdges map[rune]*Node
|
||||||
|
MaskTokenIDToNode map[int32]*Node
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNode(state PythonState) *Node {
|
||||||
|
return &Node{
|
||||||
|
State: state,
|
||||||
|
TransitionEdges: make(map[rune]*Node),
|
||||||
|
MaskTokenIDToNode: make(map[int32]*Node),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type PythonFunction struct {
|
||||||
|
Name string
|
||||||
|
Args []string
|
||||||
|
Types []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type PythonSampler struct {
|
||||||
|
stateToNodes map[PythonState]*Node
|
||||||
|
proc model.TextProcessor
|
||||||
|
decodedToks []string
|
||||||
|
curNode *Node
|
||||||
|
completed int
|
||||||
|
functions []PythonFunction
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
|
||||||
|
s.proc = proc
|
||||||
|
s.functions = functions
|
||||||
|
decodedToks := make([]string, len(proc.Vocab().Values))
|
||||||
|
for i := range proc.Vocab().Values {
|
||||||
|
token, err := proc.Decode([]int32{int32(i)})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
decodedToks[i] = token
|
||||||
|
}
|
||||||
|
s.decodedToks = decodedToks
|
||||||
|
s.BuildGraph()
|
||||||
|
for _, function := range functions {
|
||||||
|
prevNode := s.stateToNodes[PythonStateStart]
|
||||||
|
|
||||||
|
for _, r := range function.Name {
|
||||||
|
nextNode := NewNode(StateInFunction)
|
||||||
|
prevNode.TransitionEdges[r] = nextNode
|
||||||
|
if err := s.CreateMask(nextNode); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Println("prevNode", prevNode.State)
|
||||||
|
fmt.Printf("transition edge: %q\n", r)
|
||||||
|
fmt.Println("nextNode", nextNode.State)
|
||||||
|
prevNode = nextNode
|
||||||
|
}
|
||||||
|
prevNode.TransitionEdges['('] = s.stateToNodes[StateInFunctionArgs]
|
||||||
|
s.CreateMask(prevNode)
|
||||||
|
prevNode = s.stateToNodes[StateInFunctionArgs]
|
||||||
|
for i, arg := range function.Args {
|
||||||
|
for _, r := range arg {
|
||||||
|
nextNode := NewNode(StateInFunctionArgs)
|
||||||
|
prevNode.TransitionEdges[r] = nextNode
|
||||||
|
s.CreateMask(prevNode)
|
||||||
|
prevNode = nextNode
|
||||||
|
}
|
||||||
|
prevNode.TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
|
||||||
|
// prevNode = s.stateToNodes[StateInFunctionArgs]
|
||||||
|
prevNode.TransitionEdges['='] = NewNode(StateInFunctionArgsType)
|
||||||
|
s.CreateMask(prevNode)
|
||||||
|
prevNode = prevNode.TransitionEdges['=']
|
||||||
|
switch function.Types[i] {
|
||||||
|
case "string":
|
||||||
|
prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInString]
|
||||||
|
s.CreateMask(prevNode.TransitionEdges['"'])
|
||||||
|
case "number":
|
||||||
|
prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInNumber]
|
||||||
|
s.CreateMask(prevNode.TransitionEdges['"'])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
s.curNode = s.stateToNodes[PythonStateStart]
|
||||||
|
fmt.Println("curNode", s.curNode.State)
|
||||||
|
fmt.Println("transition edges", s.curNode.TransitionEdges)
|
||||||
|
if err := s.CreateMask(s.curNode); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Println("maskTokenIDToNode", s.curNode.MaskTokenIDToNode)
|
||||||
|
for tokenID, node := range s.curNode.MaskTokenIDToNode {
|
||||||
|
fmt.Printf("tokenID: %d, node: %v\n", s.decodedToks[tokenID], node.State)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PythonSampler) BuildGraph() error {
|
||||||
|
s.stateToNodes = make(map[PythonState]*Node)
|
||||||
|
for _, state := range PythonStates {
|
||||||
|
s.stateToNodes[state] = NewNode(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, state := range s.stateToNodes {
|
||||||
|
if err := s.CreateMask(state); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// String
|
||||||
|
s.stateToNodes[PStateInString].TransitionEdges[rune(-1)] = s.stateToNodes[PStateInString]
|
||||||
|
s.stateToNodes[PStateInString].TransitionEdges['"'] = s.stateToNodes[PStateInStringEnd]
|
||||||
|
|
||||||
|
// String end
|
||||||
|
s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
|
||||||
|
// s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
||||||
|
// Number
|
||||||
|
for _, r := range validNumberRunes {
|
||||||
|
s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
|
||||||
|
}
|
||||||
|
s.stateToNodes[PStateInNumber].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
||||||
|
s.stateToNodes[PStateInNumber].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
|
||||||
|
s.stateToNodes[PStateInNumber].TransitionEdges[' '] = s.stateToNodes[StateInFunctionArgs]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PythonSampler) ApplyMask(logits []float32) ([]float32, error) {
|
||||||
|
if s.curNode.State == PStateTerminate {
|
||||||
|
logits, err := finish(s, logits)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
}
|
||||||
|
logits, err := s.maskLogits(logits, s.curNode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PythonSampler) UpdateState(token int32) error {
|
||||||
|
mappedString, err := s.proc.Decode([]int32{token})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Printf(">>> mappedString: %q\n", mappedString)
|
||||||
|
|
||||||
|
if s.curNode.State == PStateTerminate {
|
||||||
|
if s.proc.Is(token, model.SpecialEOS) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nextNode, ok := s.curNode.MaskTokenIDToNode[token]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("invalid token: %q", mappedString)
|
||||||
|
}
|
||||||
|
|
||||||
|
if mappedString == "\"" {
|
||||||
|
if s.curNode.State == PStateInStringEnd {
|
||||||
|
s.completed++
|
||||||
|
}
|
||||||
|
if s.completed == len(s.functions) {
|
||||||
|
s.curNode.TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
||||||
|
s.CreateMask(s.curNode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.curNode = nextNode
|
||||||
|
fmt.Println("curNode", s.curNode.State)
|
||||||
|
for r, node := range s.curNode.TransitionEdges {
|
||||||
|
fmt.Printf("transition edge: %q -> %v\n", r, node.State)
|
||||||
|
}
|
||||||
|
if err := s.CreateMask(s.curNode); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PythonSampler) CreateMask(node *Node) error {
|
||||||
|
if node == nil {
|
||||||
|
return fmt.Errorf("node cannot be nil")
|
||||||
|
}
|
||||||
|
for i := range s.decodedToks {
|
||||||
|
token := s.decodedToks[i]
|
||||||
|
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||||
|
if s.proc.Is(int32(i), model.SpecialEOS) || s.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
curNode := node
|
||||||
|
valid := true
|
||||||
|
consumedSpecialRunes := make(map[rune]bool)
|
||||||
|
for _, r := range token {
|
||||||
|
curNode, valid = isRValid(r, curNode, consumedSpecialRunes)
|
||||||
|
if curNode == nil || !valid {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if valid {
|
||||||
|
if curNode.State == StateInFunction {
|
||||||
|
// fmt.Println("cm curNode", curNode.State)
|
||||||
|
// fmt.Println("cm token", s.decodedToks[i])
|
||||||
|
}
|
||||||
|
node.MaskTokenIDToNode[int32(i)] = curNode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRValid(r rune, curNode *Node, consumedSpecialRunes map[rune]bool) (*Node, bool) {
|
||||||
|
if consumedSpecialRunes[r] {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
specialRune := slices.Contains(stringInvalidRunes, r)
|
||||||
|
if specialRune {
|
||||||
|
if curNode.State == PStateInString || curNode.State == PStateInStringEnd {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for specific rune transition
|
||||||
|
if nextNode, ok := curNode.TransitionEdges[r]; ok {
|
||||||
|
// fmt.Println("next node", nextNode)
|
||||||
|
if specialRune {
|
||||||
|
if curNode.State == nextNode.State {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
consumedSpecialRunes[r] = true
|
||||||
|
}
|
||||||
|
return nextNode, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for sentinel value - if present, any rune is valid
|
||||||
|
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
|
||||||
|
return nextNode, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PythonSampler) maskLogits(logits []float32, node *Node) ([]float32, error) {
|
||||||
|
// Create a new slice with same length as logits, initialized to -Inf
|
||||||
|
maskedLogits := make([]float32, len(logits))
|
||||||
|
for i := range maskedLogits {
|
||||||
|
maskedLogits[i] = float32(math.Inf(-1))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only update values for valid token IDs from the mask map
|
||||||
|
for tokenID := range node.MaskTokenIDToNode {
|
||||||
|
if int(tokenID) < len(logits) {
|
||||||
|
maskedLogits[tokenID] = logits[tokenID]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return maskedLogits, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func finish(s *PythonSampler, logits []float32) ([]float32, error) {
|
||||||
|
for i := range logits {
|
||||||
|
if s.proc.Is(int32(i), model.SpecialEOS) {
|
||||||
|
logits[i] = 1.0
|
||||||
|
} else {
|
||||||
|
logits[i] = float32(math.Inf(-1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
}
|
||||||
@@ -168,27 +168,53 @@ func TestTopP(t *testing.T) {
|
|||||||
softmax(tokens)
|
softmax(tokens)
|
||||||
tokens = topK(tokens, 20)
|
tokens = topK(tokens, 20)
|
||||||
|
|
||||||
// Then apply topP
|
// Test with very high p value
|
||||||
tokens = topP(tokens, 0.95)
|
got := topP(tokens, 1.0)
|
||||||
|
|
||||||
// Should keep tokens until cumsum > 0.95
|
// Should keep all tokens since p is 1
|
||||||
if len(tokens) > 3 {
|
if len(got) != len(input) {
|
||||||
|
t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with normal p value
|
||||||
|
got = topP(tokens, 0.95)
|
||||||
|
|
||||||
|
if len(got) > 3 {
|
||||||
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
|
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
|
||||||
t.Logf("got: %v", tokens)
|
t.Logf("got: %v", got)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test edge case - ensure at least one token remains
|
// Test edge case - ensure at least one token remains
|
||||||
input = []float32{-1e6, -1e6, -1e6} // One dominant token
|
input = []float32{-1e6, -1e6, -1e7}
|
||||||
tokens = toTokens(input)
|
tokens = toTokens(input)
|
||||||
|
tokens = topK(tokens, 20)
|
||||||
softmax(tokens)
|
softmax(tokens)
|
||||||
tokens = topP(tokens, 0.0) // Very small p
|
got = topP(tokens, 0.0)
|
||||||
if len(tokens) < 1 {
|
if len(got) < 1 {
|
||||||
t.Error("topP should keep at least one token")
|
t.Error("topP should keep at least one token")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test with zero p value
|
||||||
|
got = topP(tokens, 0.0)
|
||||||
|
|
||||||
|
// Should keep only the highest probability token
|
||||||
|
if len(got) != 1 {
|
||||||
|
t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
|
||||||
|
t.Logf("got: %v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens = toTokens(input)
|
||||||
|
tokens = topK(tokens, 20)
|
||||||
|
softmax(tokens)
|
||||||
|
got = topP(tokens, 1e-10)
|
||||||
|
if len(got) == 0 {
|
||||||
|
t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
|
||||||
|
t.Logf("got: %v", got)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMinP(t *testing.T) {
|
func TestMinP(t *testing.T) {
|
||||||
input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
|
input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
|
||||||
tokens := toTokens(input)
|
tokens := toTokens(input)
|
||||||
|
|
||||||
// First apply temperature and softmax
|
// First apply temperature and softmax
|
||||||
@@ -225,30 +251,48 @@ func TestMinP(t *testing.T) {
|
|||||||
t.Logf("got: %v", tokens)
|
t.Logf("got: %v", tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test with single token
|
||||||
|
tokens = toTokens(input[:1])
|
||||||
|
tokens = topK(tokens, 20)
|
||||||
|
softmax(tokens)
|
||||||
|
tokens = minP(tokens, 0.1)
|
||||||
|
|
||||||
|
// Should keep only the highest probability token
|
||||||
|
if len(tokens) != 1 {
|
||||||
|
t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
|
||||||
|
t.Logf("got: %v", tokens)
|
||||||
|
}
|
||||||
|
|
||||||
input = []float32{1e-10, 1e-10, 1e-10}
|
input = []float32{1e-10, 1e-10, 1e-10}
|
||||||
tokens = toTokens(input)
|
tokens = toTokens(input)
|
||||||
softmax(tokens)
|
softmax(tokens)
|
||||||
tokens = minP(tokens, 1.0)
|
tokens = minP(tokens, 1.0)
|
||||||
if len(tokens) < 1 {
|
if len(tokens) < 1 {
|
||||||
t.Error("minP should keep at least one token even with extreme probabilities")
|
t.Error("minP should keep at least one token even with extreme probabilities")
|
||||||
}
|
got := minP(tokens, 1.0)
|
||||||
}
|
|
||||||
|
|
||||||
func TestSortLogits(t *testing.T) {
|
if len(got) != 1 {
|
||||||
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
|
||||||
tokens := toTokens(input)
|
}
|
||||||
|
|
||||||
tokens = topK(tokens, 20)
|
// Test with normal p value
|
||||||
|
got = minP(tokens, 0.2)
|
||||||
|
|
||||||
for i := 1; i < len(tokens); i++ {
|
// Should keep tokens with prob >= 0.2 * max_prob
|
||||||
if tokens[i].value > tokens[i-1].value {
|
if len(got) > 3 {
|
||||||
t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
|
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
|
||||||
i, tokens[i].value, tokens[i-1].value)
|
t.Logf("got: %v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with zero p value
|
||||||
|
got = minP(tokens, 0.0)
|
||||||
|
|
||||||
|
// Should keep only the highest probability token
|
||||||
|
if len(got) != len(tokens) {
|
||||||
|
t.Errorf("minP(0.0): should keep only one token, got %d", len(got))
|
||||||
|
t.Logf("got: %v", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
|
||||||
compareLogits(t, "sortLogits", want, tokens)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkTransforms(b *testing.B) {
|
func BenchmarkTransforms(b *testing.B) {
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ import (
|
|||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||||
"github.com/ollama/ollama/server/internal/internal/backoff"
|
|
||||||
"github.com/ollama/ollama/server/internal/internal/names"
|
"github.com/ollama/ollama/server/internal/internal/names"
|
||||||
|
|
||||||
_ "embed"
|
_ "embed"
|
||||||
@@ -60,6 +59,11 @@ var (
|
|||||||
// ErrCached is passed to [Trace.PushUpdate] when a layer already
|
// ErrCached is passed to [Trace.PushUpdate] when a layer already
|
||||||
// exists. It is a non-fatal error and is never returned by [Registry.Push].
|
// exists. It is a non-fatal error and is never returned by [Registry.Push].
|
||||||
ErrCached = errors.New("cached")
|
ErrCached = errors.New("cached")
|
||||||
|
|
||||||
|
// ErrIncomplete is returned by [Registry.Pull] when a model pull was
|
||||||
|
// incomplete due to one or more layer download failures. Users that
|
||||||
|
// want specific errors should use [WithTrace].
|
||||||
|
ErrIncomplete = errors.New("incomplete")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defaults
|
// Defaults
|
||||||
@@ -213,12 +217,6 @@ type Registry struct {
|
|||||||
// request. If zero, [DefaultChunkingThreshold] is used.
|
// request. If zero, [DefaultChunkingThreshold] is used.
|
||||||
ChunkingThreshold int64
|
ChunkingThreshold int64
|
||||||
|
|
||||||
// MaxChunkSize is the maximum size of a chunk to download. If zero,
|
|
||||||
// the default is [DefaultMaxChunkSize].
|
|
||||||
//
|
|
||||||
// It is only used when a layer is larger than [MaxChunkingThreshold].
|
|
||||||
MaxChunkSize int64
|
|
||||||
|
|
||||||
// Mask, if set, is the name used to convert non-fully qualified names
|
// Mask, if set, is the name used to convert non-fully qualified names
|
||||||
// to fully qualified names. If empty, [DefaultMask] is used.
|
// to fully qualified names. If empty, [DefaultMask] is used.
|
||||||
Mask string
|
Mask string
|
||||||
@@ -278,8 +276,19 @@ func DefaultRegistry() (*Registry, error) {
|
|||||||
|
|
||||||
func UserAgent() string {
|
func UserAgent() string {
|
||||||
buildinfo, _ := debug.ReadBuildInfo()
|
buildinfo, _ := debug.ReadBuildInfo()
|
||||||
|
|
||||||
|
version := buildinfo.Main.Version
|
||||||
|
if version == "(devel)" {
|
||||||
|
// When using `go run .` the version is "(devel)". This is seen
|
||||||
|
// as an invalid version by ollama.com and so it defaults to
|
||||||
|
// "needs upgrade" for some requests, such as pulls. These
|
||||||
|
// checks can be skipped by using the special version "v0.0.0",
|
||||||
|
// so we set it to that here.
|
||||||
|
version = "v0.0.0"
|
||||||
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
|
return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
|
||||||
buildinfo.Main.Version,
|
version,
|
||||||
runtime.GOARCH,
|
runtime.GOARCH,
|
||||||
runtime.GOOS,
|
runtime.GOOS,
|
||||||
runtime.Version(),
|
runtime.Version(),
|
||||||
@@ -425,13 +434,14 @@ func canRetry(err error) bool {
|
|||||||
//
|
//
|
||||||
// It always calls update with a nil error.
|
// It always calls update with a nil error.
|
||||||
type trackingReader struct {
|
type trackingReader struct {
|
||||||
r io.Reader
|
l *Layer
|
||||||
n *atomic.Int64
|
r io.Reader
|
||||||
|
update func(l *Layer, n int64, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *trackingReader) Read(p []byte) (n int, err error) {
|
func (r *trackingReader) Read(p []byte) (n int, err error) {
|
||||||
n, err = r.r.Read(p)
|
n, err = r.r.Read(p)
|
||||||
r.n.Add(int64(n))
|
r.update(r.l, int64(n), nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -447,6 +457,11 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(bmizerany): decide if this should be considered valid. Maybe
|
||||||
|
// server-side we special case '{}' to have some special meaning? Maybe
|
||||||
|
// "archiving" a tag (which is how we reason about it in the registry
|
||||||
|
// already, just with a different twist).
|
||||||
if len(m.Layers) == 0 {
|
if len(m.Layers) == 0 {
|
||||||
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
|
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
|
||||||
}
|
}
|
||||||
@@ -456,11 +471,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
exists := func(l *Layer) bool {
|
// TODO(bmizerany): work to remove the need to do this
|
||||||
info, err := c.Get(l.Digest)
|
|
||||||
return err == nil && info.Size == l.Size
|
|
||||||
}
|
|
||||||
|
|
||||||
layers := m.Layers
|
layers := m.Layers
|
||||||
if m.Config != nil && m.Config.Digest.IsValid() {
|
if m.Config != nil && m.Config.Digest.IsValid() {
|
||||||
layers = append(layers, m.Config)
|
layers = append(layers, m.Config)
|
||||||
@@ -468,99 +479,97 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
|
|
||||||
// Send initial layer trace events to allow clients to have an
|
// Send initial layer trace events to allow clients to have an
|
||||||
// understanding of work to be done before work starts.
|
// understanding of work to be done before work starts.
|
||||||
|
var expected int64
|
||||||
t := traceFromContext(ctx)
|
t := traceFromContext(ctx)
|
||||||
skip := make([]bool, len(layers))
|
for _, l := range layers {
|
||||||
for i, l := range layers {
|
|
||||||
t.update(l, 0, nil)
|
t.update(l, 0, nil)
|
||||||
if exists(l) {
|
expected += l.Size
|
||||||
skip[i] = true
|
|
||||||
t.update(l, l.Size, ErrCached)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
g, ctx := errgroup.WithContext(ctx)
|
var received atomic.Int64
|
||||||
|
var g errgroup.Group
|
||||||
g.SetLimit(r.maxStreams())
|
g.SetLimit(r.maxStreams())
|
||||||
for i, l := range layers {
|
for _, l := range layers {
|
||||||
if skip[i] {
|
info, err := c.Get(l.Digest)
|
||||||
|
if err == nil && info.Size == l.Size {
|
||||||
|
received.Add(l.Size)
|
||||||
|
t.update(l, l.Size, ErrCached)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
chunked, err := c.Chunked(l.Digest, l.Size)
|
chunked, err := c.Chunked(l.Digest, l.Size)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.update(l, 0, err)
|
t.update(l, 0, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
defer chunked.Close()
|
|
||||||
|
|
||||||
var progress atomic.Int64
|
|
||||||
for cs, err := range r.chunksums(ctx, name, l) {
|
for cs, err := range r.chunksums(ctx, name, l) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.update(l, progress.Load(), err)
|
// Chunksum stream interrupted. Note in trace
|
||||||
|
// log and let in-flight downloads complete.
|
||||||
|
// This will naturally trigger ErrIncomplete
|
||||||
|
// since received < expected bytes.
|
||||||
|
t.update(l, 0, err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
g.Go(func() (err error) {
|
g.Go(func() (err error) {
|
||||||
defer func() { t.update(l, progress.Load(), err) }()
|
defer func() {
|
||||||
|
if err == nil {
|
||||||
for _, err := range backoff.Loop(ctx, 3*time.Second) {
|
received.Add(cs.Chunk.Size())
|
||||||
if err != nil {
|
} else {
|
||||||
return err
|
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
|
||||||
}
|
}
|
||||||
err := func() error {
|
wg.Done()
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
}()
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
|
|
||||||
res, err := sendRequest(r.client(), req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer res.Body.Close()
|
|
||||||
|
|
||||||
// Count bytes towards
|
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
||||||
// progress, as they arrive, so
|
if err != nil {
|
||||||
// that our bytes piggyback
|
return err
|
||||||
// other chunk updates on
|
|
||||||
// completion.
|
|
||||||
//
|
|
||||||
// This tactic is enough to
|
|
||||||
// show "smooth" progress given
|
|
||||||
// the current CLI client. In
|
|
||||||
// the near future, the server
|
|
||||||
// should report download rate
|
|
||||||
// since it knows better than
|
|
||||||
// a client that is measuring
|
|
||||||
// rate based on wall-clock
|
|
||||||
// time-since-last-update.
|
|
||||||
body := &trackingReader{r: res.Body, n: &progress}
|
|
||||||
|
|
||||||
err = chunked.Put(cs.Chunk, cs.Digest, body)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}()
|
|
||||||
if !canRetry(err) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
|
||||||
|
res, err := sendRequest(r.client(), req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
body := &trackingReader{l: l, r: res.Body, update: t.update}
|
||||||
|
return chunked.Put(cs.Chunk, cs.Digest, body)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close writer immediately after downloads finish, not at Pull
|
||||||
|
// exit. Using defer would keep file descriptors open until all
|
||||||
|
// layers complete, potentially exhausting system limits with
|
||||||
|
// many layers.
|
||||||
|
//
|
||||||
|
// The WaitGroup tracks when all chunks finish downloading,
|
||||||
|
// allowing precise writer closure in a background goroutine.
|
||||||
|
// Each layer briefly uses one extra goroutine while at most
|
||||||
|
// maxStreams()-1 chunks download in parallel.
|
||||||
|
//
|
||||||
|
// This caps file descriptors at maxStreams() instead of
|
||||||
|
// growing with layer count.
|
||||||
|
g.Go(func() error {
|
||||||
|
wg.Wait()
|
||||||
|
chunked.Close()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
if err := g.Wait(); err != nil {
|
if err := g.Wait(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if received.Load() != expected {
|
||||||
|
return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected)
|
||||||
|
}
|
||||||
|
|
||||||
// store the manifest blob
|
|
||||||
md := blob.DigestFromBytes(m.Data)
|
md := blob.DigestFromBytes(m.Data)
|
||||||
if err := blob.PutBytes(c, md, m.Data); err != nil {
|
if err := blob.PutBytes(c, md, m.Data); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// commit the manifest with a link
|
|
||||||
return c.Link(m.Name, md)
|
return c.Link(m.Name, md)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -24,6 +25,28 @@ import (
|
|||||||
"github.com/ollama/ollama/server/internal/testutil"
|
"github.com/ollama/ollama/server/internal/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func ExampleRegistry_cancelOnFirstError() {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
ctx = WithTrace(ctx, &Trace{
|
||||||
|
Update: func(l *Layer, n int64, err error) {
|
||||||
|
if err != nil {
|
||||||
|
// Discontinue pulling layers if there is an
|
||||||
|
// error instead of continuing to pull more
|
||||||
|
// data.
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
var r Registry
|
||||||
|
if err := r.Pull(ctx, "model"); err != nil {
|
||||||
|
// panic for demo purposes
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestManifestMarshalJSON(t *testing.T) {
|
func TestManifestMarshalJSON(t *testing.T) {
|
||||||
// All manifests should contain an "empty" config object.
|
// All manifests should contain an "empty" config object.
|
||||||
var m Manifest
|
var m Manifest
|
||||||
@@ -56,21 +79,21 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
|
|||||||
|
|
||||||
// newClient constructs a cache with predefined manifests for testing. The manifests are:
|
// newClient constructs a cache with predefined manifests for testing. The manifests are:
|
||||||
//
|
//
|
||||||
// empty: no data
|
// empty: no data
|
||||||
// zero: no layers
|
// zero: no layers
|
||||||
// single: one layer with the contents "exists"
|
// single: one layer with the contents "exists"
|
||||||
// multiple: two layers with the contents "exists" and "here"
|
// multiple: two layers with the contents "exists" and "here"
|
||||||
// notfound: a layer that does not exist in the cache
|
// notfound: a layer that does not exist in the cache
|
||||||
// null: one null layer (e.g. [null])
|
// null: one null layer (e.g. [null])
|
||||||
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
|
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
|
||||||
// invalid: a layer with invalid JSON data
|
// invalid: a layer with invalid JSON data
|
||||||
//
|
//
|
||||||
// Tests that want to ensure the client does not communicate with the upstream
|
// Tests that want to ensure the client does not communicate with the upstream
|
||||||
// registry should pass a nil handler, which will cause a panic if
|
// registry should pass a nil handler, which will cause a panic if
|
||||||
// communication is attempted.
|
// communication is attempted.
|
||||||
//
|
//
|
||||||
// To simulate a network error, pass a handler that returns a 499 status code.
|
// To simulate a network error, pass a handler that returns a 499 status code.
|
||||||
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
c, err := blob.Open(t.TempDir())
|
c, err := blob.Open(t.TempDir())
|
||||||
@@ -88,7 +111,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
|||||||
r := &Registry{
|
r := &Registry{
|
||||||
Cache: c,
|
Cache: c,
|
||||||
HTTPClient: &http.Client{
|
HTTPClient: &http.Client{
|
||||||
Transport: recordRoundTripper(h),
|
Transport: recordRoundTripper(upstreamRegistry),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -767,3 +790,79 @@ func TestUnlink(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPullChunksums(t *testing.T) {
|
||||||
|
check := testutil.Checker(t)
|
||||||
|
|
||||||
|
content := "hello"
|
||||||
|
var chunksums string
|
||||||
|
contentDigest := func() blob.Digest {
|
||||||
|
return blob.DigestFromBytes(content)
|
||||||
|
}
|
||||||
|
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case strings.Contains(r.URL.Path, "/manifests/latest"):
|
||||||
|
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
|
||||||
|
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
|
||||||
|
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
|
||||||
|
w.Header().Set("Content-Location", loc)
|
||||||
|
io.WriteString(w, chunksums)
|
||||||
|
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
|
||||||
|
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
|
||||||
|
default:
|
||||||
|
t.Errorf("unexpected request: %v", r)
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
rc.MaxStreams = 1 // prevent concurrent chunk downloads
|
||||||
|
rc.ChunkingThreshold = 1 // for all blobs to be chunked
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
var reads []int64
|
||||||
|
ctx := WithTrace(t.Context(), &Trace{
|
||||||
|
Update: func(l *Layer, n int64, err error) {
|
||||||
|
t.Logf("Update: %v %d %v", l, n, err)
|
||||||
|
mu.Lock()
|
||||||
|
reads = append(reads, n)
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
|
||||||
|
blob.DigestFromBytes("hel"),
|
||||||
|
blob.DigestFromBytes("lo"),
|
||||||
|
)
|
||||||
|
err := rc.Pull(ctx, "test")
|
||||||
|
check(err)
|
||||||
|
wantReads := []int64{
|
||||||
|
0, // initial signaling of layer pull starting
|
||||||
|
3, // first chunk read
|
||||||
|
2, // second chunk read
|
||||||
|
}
|
||||||
|
if !slices.Equal(reads, wantReads) {
|
||||||
|
t.Errorf("reads = %v; want %v", reads, wantReads)
|
||||||
|
}
|
||||||
|
|
||||||
|
mw, err := rc.Resolve(t.Context(), "test")
|
||||||
|
check(err)
|
||||||
|
mg, err := rc.ResolveLocal("test")
|
||||||
|
check(err)
|
||||||
|
if !reflect.DeepEqual(mw, mg) {
|
||||||
|
t.Errorf("mw = %v; mg = %v", mw, mg)
|
||||||
|
}
|
||||||
|
for i := range mg.Layers {
|
||||||
|
_, err = c.Get(mg.Layers[i].Digest)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// missing chunks
|
||||||
|
content = "llama"
|
||||||
|
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
|
||||||
|
err = rc.Pull(ctx, "missingchunks")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error because of missing chunks")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -200,7 +200,7 @@ type params struct {
|
|||||||
//
|
//
|
||||||
// Unfortunately, this API was designed to be a bit awkward. Stream is
|
// Unfortunately, this API was designed to be a bit awkward. Stream is
|
||||||
// defined to default to true if not present, so we need a way to check
|
// defined to default to true if not present, so we need a way to check
|
||||||
// if the client decisively it to false. So, we use a pointer to a
|
// if the client decisively set it to false. So, we use a pointer to a
|
||||||
// bool. Gross.
|
// bool. Gross.
|
||||||
//
|
//
|
||||||
// Use [stream()] to get the correct value for this field.
|
// Use [stream()] to get the correct value for this field.
|
||||||
@@ -280,17 +280,17 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
|||||||
progress := make(map[*ollama.Layer]int64)
|
progress := make(map[*ollama.Layer]int64)
|
||||||
|
|
||||||
progressCopy := make(map[*ollama.Layer]int64, len(progress))
|
progressCopy := make(map[*ollama.Layer]int64, len(progress))
|
||||||
pushUpdate := func() {
|
flushProgress := func() {
|
||||||
defer maybeFlush()
|
defer maybeFlush()
|
||||||
|
|
||||||
// TODO(bmizerany): This scales poorly with more layers due to
|
// TODO(bmizerany): Flushing every layer in one update doesn't
|
||||||
// needing to flush out them all in one big update. We _could_
|
// scale well. We could flush only the modified layers or track
|
||||||
// just flush on the changed ones, or just track the whole
|
// the full download. Needs further consideration, though it's
|
||||||
// download. Needs more thought. This is fine for now.
|
// fine for now.
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
maps.Copy(progressCopy, progress)
|
maps.Copy(progressCopy, progress)
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
for l, n := range progress {
|
for l, n := range progressCopy {
|
||||||
enc.Encode(progressUpdateJSON{
|
enc.Encode(progressUpdateJSON{
|
||||||
Digest: l.Digest,
|
Digest: l.Digest,
|
||||||
Total: l.Size,
|
Total: l.Size,
|
||||||
@@ -298,19 +298,26 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
defer flushProgress()
|
||||||
|
|
||||||
t := time.NewTicker(time.Hour) // "unstarted" timer
|
t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
|
||||||
start := sync.OnceFunc(func() {
|
start := sync.OnceFunc(func() {
|
||||||
pushUpdate()
|
flushProgress() // flush initial state
|
||||||
t.Reset(100 * time.Millisecond)
|
t.Reset(100 * time.Millisecond)
|
||||||
})
|
})
|
||||||
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
|
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
|
||||||
Update: func(l *ollama.Layer, n int64, err error) {
|
Update: func(l *ollama.Layer, n int64, err error) {
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
start() // flush initial state
|
// Block flushing progress updates until every
|
||||||
|
// layer is accounted for. Clients depend on a
|
||||||
|
// complete model size to calculate progress
|
||||||
|
// correctly; if they use an incomplete total,
|
||||||
|
// progress indicators would erratically jump
|
||||||
|
// as new layers are registered.
|
||||||
|
start()
|
||||||
}
|
}
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
progress[l] = n
|
progress[l] += n
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -323,9 +330,9 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-t.C:
|
case <-t.C:
|
||||||
pushUpdate()
|
flushProgress()
|
||||||
case err := <-done:
|
case err := <-done:
|
||||||
pushUpdate()
|
flushProgress()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var status string
|
var status string
|
||||||
if errors.Is(err, ollama.ErrModelNotFound) {
|
if errors.Is(err, ollama.ErrModelNotFound) {
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
|||||||
for _, layer := range layers {
|
for _, layer := range layers {
|
||||||
if s := layer.GGML.KV().ChatTemplate(); s != "" {
|
if s := layer.GGML.KV().ChatTemplate(); s != "" {
|
||||||
if t, err := template.Named(s); err != nil {
|
if t, err := template.Named(s); err != nil {
|
||||||
slog.Debug("template detection", "error", err)
|
slog.Debug("template detection", "error", err, "template", s)
|
||||||
} else {
|
} else {
|
||||||
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
13
template/gemma3-instruct.gotmpl
Normal file
13
template/gemma3-instruct.gotmpl
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
{{- range $i, $_ := .Messages }}
|
||||||
|
{{- $last := eq (len (slice $.Messages $i)) 1 }}
|
||||||
|
{{- if eq .Role "user" }}<start_of_turn>user
|
||||||
|
{{- if and (eq $i 1) $.System }}
|
||||||
|
{{ $.System }}
|
||||||
|
{{ end }}
|
||||||
|
{{ .Content }}<end_of_turn>
|
||||||
|
{{ else if eq .Role "assistant" }}<start_of_turn>model
|
||||||
|
{{ .Content }}<end_of_turn>
|
||||||
|
{{ end }}
|
||||||
|
{{- if $last }}<start_of_turn>model
|
||||||
|
{{ end }}
|
||||||
|
{{- end }}
|
||||||
6
template/gemma3-instruct.json
Normal file
6
template/gemma3-instruct.json
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"stop": [
|
||||||
|
"<end_of_turn>"
|
||||||
|
],
|
||||||
|
"temperature": 0.1
|
||||||
|
}
|
||||||
@@ -87,6 +87,10 @@
|
|||||||
"template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
"template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||||
"name": "gemma-instruct"
|
"name": "gemma-instruct"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n",
|
||||||
|
"name": "gemma3-instruct"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
|
"template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
|
||||||
"name": "llama3-instruct"
|
"name": "llama3-instruct"
|
||||||
|
|||||||
10
template/testdata/gemma3-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
10
template/testdata/gemma3-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
<start_of_turn>user
|
||||||
|
You are a helpful assistant.
|
||||||
|
|
||||||
|
Hello, how are you?<end_of_turn>
|
||||||
|
<start_of_turn>model
|
||||||
|
I'm doing great. How can I help you today?<end_of_turn>
|
||||||
|
<start_of_turn>user
|
||||||
|
I'd like to show off how chat templating works!<end_of_turn>
|
||||||
|
<start_of_turn>model
|
||||||
|
|
||||||
4
template/testdata/gemma3-instruct.gotmpl/user
vendored
Normal file
4
template/testdata/gemma3-instruct.gotmpl/user
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
<start_of_turn>user
|
||||||
|
Hello, how are you?<end_of_turn>
|
||||||
|
<start_of_turn>model
|
||||||
|
|
||||||
8
template/testdata/gemma3-instruct.gotmpl/user-assistant-user
vendored
Normal file
8
template/testdata/gemma3-instruct.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<start_of_turn>user
|
||||||
|
Hello, how are you?<end_of_turn>
|
||||||
|
<start_of_turn>model
|
||||||
|
I'm doing great. How can I help you today?<end_of_turn>
|
||||||
|
<start_of_turn>user
|
||||||
|
I'd like to show off how chat templating works!<end_of_turn>
|
||||||
|
<start_of_turn>model
|
||||||
|
|
||||||
Reference in New Issue
Block a user