Compare commits

...

52 Commits

Author SHA1 Message Date
ParthSareen
4450f871db wip 2025-03-25 16:45:27 -07:00
ParthSareen
5ec6bb52a0 prototyping 2025-03-25 15:00:14 -07:00
Blake Mizerany
1fd9967558 grammar: introduce new grammar package
This package provides a way to convert JSON schemas to equivalent EBNF.
It is intended to be a replacement for llama.cpp's schema_to_grammar.

This is still an early version and does not yet support all JSON schema
features. The to-do list includes:

- minimum/maximum constraints on integer types
- minLength/maxLength constraints on string types
- defs and refs
2025-03-24 11:58:06 -07:00
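A hedged usage sketch of the new package, based on the FromSchema signature and the expected output in grammar/testdata/schemas.txt further down this compare:

```go
package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/grammar"
)

func main() {
	// Schema taken from the "# object" case in grammar/testdata/schemas.txt.
	g, err := grammar.FromSchema(nil, []byte(`{"properties": {"e": {}}}`))
	if err != nil {
		log.Fatal(err)
	}
	// Prints the shared JSON terminals followed by rules like:
	//   root_0 ::= value;
	//   root ::= "{" "e" ":" root_0 "}";
	fmt.Printf("%s", g)
}
```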
Matheus C. França
131f0355a5 readme: add ollama-d library (#9907) 2025-03-24 09:25:58 -07:00
Blake Mizerany
ce929984a3 server/internal/client/ollama: fix file descriptor management in Pull (#9931)
Close chunked writers as soon as downloads complete, rather than
deferring closure until Pull exits. This prevents exhausting file
descriptors when pulling many layers.

Instead of unbounded defers, use a WaitGroup and background goroutine
to close each chunked writer as soon as its downloads finish.

Also rename 'total' to 'received' for clarity.
2025-03-21 16:16:38 -07:00
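A minimal sketch of the pattern described above, with hypothetical helpers standing in for the real chunked-writer machinery (not the actual Pull code):

```go
package pull

import (
	"os"
	"sync"
)

// downloadLayer is a hypothetical stand-in: it starts this layer's chunk
// downloads and returns a WaitGroup that completes when they finish.
func downloadLayer(f *os.File) *sync.WaitGroup { var wg sync.WaitGroup; return &wg }

func pull(paths []string) error {
	var wg sync.WaitGroup
	for _, p := range paths {
		f, err := os.Create(p)
		if err != nil {
			return err
		}
		dl := downloadLayer(f)
		wg.Add(1)
		go func(f *os.File, dl *sync.WaitGroup) {
			defer wg.Done()
			dl.Wait() // this layer's downloads are done
			f.Close() // release the descriptor now, not when pull returns
		}(f, dl)
	}
	wg.Wait() // every writer is closed before exiting
	return nil
}
```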
Michael Yang
4b34930a31 Merge pull request #9897 from ollama/mxyng/chunk-load
ml/backend/ggml: load tensors in 128KiB chunks
2025-03-21 14:47:13 -07:00
Michael Yang
74bd09652d ml/backend/ggml: load tensors in 32KiB chunks 2025-03-21 14:43:52 -07:00
Bruce MacDonald
fb6252d786 benchmark: performance of running ollama server (#8643) 2025-03-21 13:08:20 -07:00
Blake Mizerany
c794fef2f2 server/internal/client/ollama: persist through chunk download errors (#9923) 2025-03-21 13:03:43 -07:00
Parth Sareen
00ebda8cc4 Revert "parser: remove role validation from Modelfile parser" (#9917)
This reverts commit ffbfe833da.
2025-03-21 12:38:09 -07:00
Parth Sareen
d14ce75b95 docs: update final response for /api/chat stream (#9919) 2025-03-21 12:35:47 -07:00
Jesse Gross
2d6eac9084 kvcache: Optimize sliding window attention
Currently sliding window attention allocates and uses the full
context size and just masks out any tokens that are outside of the
window. However, we really only need (roughly) the sliding window
size.

At large context sizes this improves two things:
 - Memory allocated - since the full context size was allocated up front,
   memory requirements drop substantially. On Gemma3:4b with a 32k
   context window, total memory usage (including weights and non-sliding
   layers) drops from ~20GB to ~8GB.
 - Computation - ranges that are completely outside of the sliding
   window are now removed from the tensors that are returned from the
   cache rather than simply being masked out. This results in more
   efficient processing, scaling with the size of the context that
   has actually been used.

Notably, this does not update the scheduler for any model to be aware of
the smaller memory requirements. This is difficult for Gemma3 because
the layers are heterogeneous between sliding and non-sliding attention.
As a result, while actual memory consumption will be reduced, the
scheduler will over-estimate the requirements of the model. This means
that splitting between GPUs or GPUs and CPUs will still be suboptimal.

Bug #9730
2025-03-21 11:20:19 -07:00
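The sizing rule lands in the Causal.Init diff near the end of this compare; in sketch form:

```go
package kvsketch

import "math"

// cacheCells mirrors the sizing logic in the Causal.Init diff below:
// a sliding-window cache only needs the window plus one in-flight batch.
func cacheCells(windowSize int32, maxSequences, capacity, maxBatch int) int {
	if windowSize == math.MaxInt32 || capacity < int(windowSize)+maxBatch {
		return maxSequences * capacity // no window: keep the full context
	}
	return maxSequences * (int(windowSize) + maxBatch)
}
```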
Jesse Gross
3ed7ad3ab3 kvcache: Pass granular cache size into implementations
Currently the runner computes the kv size needed and creates a
cache of that size. This is the context size times number of
parallel sequences.

Cache implementations can make better decisions about their memory
usage, so instead pass in the required capacity, number of sequences
and maximum batch size. For now, the causal cache just uses this to
compute the size in the same way as before.
2025-03-21 11:20:19 -07:00
Patrick Devine
6d1103048e fix: show correct bool value for kv in verbose show information (#9928) 2025-03-21 11:13:54 -07:00
Jesse Gross
0ff28758b3 ollamarunner: Provide mechanism for backends to report loading progress
This enables the runner to report progress back to the Ollama server,
both for showing status to the user and also to prevent the server
from killing the runner if it thinks things have stalled.

Most of the infrastructure was already there, this extends it to
be available to the backends.
2025-03-21 10:44:26 -07:00
Jesse Gross
d3e9ca3eda kvcache: Account for source tensors in defrag operation count
Defragging the KV cache can generate a lot of operations, so we
need to be careful that we don't overflow the number that the graph
can support. We currently account for all of the nodes that we add
to the graph for each move but we also need to include the original
cache tensors as well.

Fixes #9904
2025-03-21 10:42:19 -07:00
Jesse Gross
0fbfcf3c9c model: Pass input tensor instead of raw data to models
Rather than directly giving the input data to models, we can
pass a tensor instead. In the short term, this saves some duplicated
code.

Longer term, we will want to overlap setting up the next batch with
processing of the current one. In this case, we will only have the
shape of tensor but it will not be loaded with data at the time of
graph generation. By passing only a tensor to models now, we set up
this possibility and prevent them from relying on data that they won't
have in the future.

Although the same could be done for Positions and Outputs, in some
cases we either need the raw input data or don't use them at all.
Therefore, for now we leave them as they are and allow models to
convert them to tensors as needed.
2025-03-20 13:28:13 -07:00
Jesse Gross
0c220935bd input: Rename Options to Batch
Options is no longer very descriptive of this struct.
2025-03-20 13:28:13 -07:00
rylativity
ffbfe833da parser: remove role validation from Modelfile parser (#9874)
* updates parser/parser.go to allow arbitrary roles in Modelfile MESSAGE blocks
2025-03-20 13:11:17 -07:00
Parth Sareen
42a14f7f63 sample: add error handling for empty logits (#9740) 2025-03-20 11:11:18 -07:00
Patrick Devine
f8c3dbe5b5 templates: add autotemplate for gemma3 (#9880)
This change allows the gemma3 template to be autodetected during `ollama
create`.
2025-03-20 00:15:30 -07:00
Jesse Gross
b078dd157c gemma2: Remove second call to Rows
Looks like a merge conflict that broke the model.
2025-03-19 17:28:49 -07:00
Blake Mizerany
2ddacd7516 server/internal/client/ollama: confirm all chunksums were received (#9893)
If the chunksums response is missing a chunk, the client should fail
the download. This changes the client to check that all bytes are
accounted for in the chunksums response.

It is possible there are overlaps or gaps in the chunksums response and
so the size is not the only thing left to check, but this provides
enough coverage for now. We may want to check that chunks are contiguous
later.
2025-03-19 14:59:57 -07:00
Jeffrey Morgan
da0e345200 ml: use input context for extracting outputs (#9875) 2025-03-18 18:08:19 -07:00
Bruce MacDonald
df94175a0f ggml: return error on failure to read tensor data (#9872)
When converting a ggml model, a failure to read tensor data was returning a nil error value instead of the actual read error.
2025-03-18 16:51:33 -07:00
Bruce MacDonald
61a8825216 convert: return name of unsupported architecture (#9862)
When a model's architecture cannot be converted, return the name of the unsupported arch in the error message.
2025-03-18 10:38:28 -07:00
Michael Yang
021dcf089d Merge pull request #9824 from ollama/mxyng/sched
conditionally enable parallel pipelines
2025-03-17 15:41:37 -07:00
Jesse Gross
bf24498b1e ollamarunner: Check for minBatch of context space when shifting
Models can specify that a group of inputs need to be handled as a single
batch. However, context shifting didn't respect this and could trigger
a break anyway. In this case, we should instead trigger a context
shift earlier so that it occurs before the grouped batch.

Note that there are still some corner cases:
 - A long prompt that exceeds the context window can get truncated
   in the middle of an image. With the current models, this will
   result in the model not recognizing the image at all, which is
   pretty much the expected result with truncation.
 - The context window is set smaller than the minimum batch size. The
   only solution to this is to refuse to load the model with these
   settings. However, this can never occur with current models and
   default settings.

Since users are unlikely to run into these scenarios, fixing them is
left as a follow up.
2025-03-17 15:33:16 -07:00
Bruce MacDonald
95e271d98f runner: remove cache prompt flag from ollama runner (#9826)
We do not need to bypass prompt caching in the ollama runner yet, as
only embedding models need to do so. When embedding models are
implemented, they can skip initializing this cache completely.
2025-03-17 15:11:15 -07:00
Jeffrey Morgan
364629b8d6 ml/backend/ggml: allocate memory with malloc when loading model (#9822) 2025-03-17 13:32:40 -07:00
Parth Sareen
108fe02165 sample: make mutations in transforms explicit (#9743)
* updated minP to use early exit making use of sorted tokens
2025-03-17 11:24:18 -07:00
Michael Yang
4561fff36e conditionally enable parallel pipelines 2025-03-17 09:46:07 -07:00
Daniel Hiltgen
50b5962042 Add support for ROCm gfx1151 (#9773) 2025-03-17 09:33:57 -07:00
Louis Beaumont
e27e4a3c1b readme: add screenpipe to community integrations (#9786) 2025-03-16 21:56:42 -04:00
zeo
088514bbd4 readme: add Ellama to list of community integrations (#9800) 2025-03-16 21:54:43 -04:00
Patrick Devine
2c8b484643 fix: correctly save in interactive mode (#9788)
This fixes the case where a FROM line in a previous modelfile points to
a file which may or may not be present in a different ollama instance.
We shouldn't rely on the filename; instead, check whether the FROM line
is a valid model name and point to that.
2025-03-15 12:09:02 -07:00
Blake Mizerany
8294676150 server/internal/client/ollama: set User-Agent for registry client (#9775)
This sets the agent header in DefaultRegistry to include the version of
the client, OS, and architecture in the previous format, with a minor
twist.

Note: The version is obtained from the build info, instead of the
version in version.Version, which should no longer be necessary; we
can remove it in a future commit. Using the build info is more accurate
and also provides extra build information if the build is not tagged or
is "dirty". Previously, the version was just "0.0.0" with no other
helpful information. The ollama.com registry and others handle this
swimmingly.
2025-03-14 18:33:07 -07:00
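A hedged sketch of deriving the version from Go build info; the header format here is illustrative, not the exact DefaultRegistry string:

```go
package uasketch

import (
	"fmt"
	"runtime"
	"runtime/debug"
)

// userAgent falls back to "0.0.0" only when build info is unavailable.
func userAgent() string {
	v := "0.0.0"
	if bi, ok := debug.ReadBuildInfo(); ok && bi.Main.Version != "" {
		v = bi.Main.Version // version stamped by the Go toolchain
	}
	return fmt.Sprintf("ollama/%s (%s %s) Go/%s", v, runtime.GOARCH, runtime.GOOS, runtime.Version())
}
```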
Patrick Devine
ef378ad673 gemma3 quantization (#9776) 2025-03-14 17:41:07 -07:00
Daniel Hiltgen
2d2247e59e Align versions for local builds (#9635)
Darwin was using a different pattern for the version string
than linux or windows.
2025-03-14 15:44:08 -07:00
Jesse Gross
7bf793a600 gemma3: Allow multiple images in a single input
Previously, processing multiple images in a batch would trigger
segfaults, so sending images together was disabled as a mitigation.
The trigger was processing one image on the CPU
and one on the GPU.

This can no longer happen:
 - The vision encoder is now on the GPU so both images would be
   processed on the GPU.
 - We require images to be fully contained in a batch and each
   image including its special tokens is over half the batch size.
   As a result, we will never get two images in the same batch.

Fixes #9731
2025-03-14 15:38:54 -07:00
Jesse Gross
282bfaaa95 ollamarunner: Use a separate context per multimodal input
Currently there is a single context per sequence, shared by
all multimodal inputs. Since we build a vision encoder graph per
image, with a large number of inputs we can eventually hit the
maximum number of graph nodes per context.

This changes to use a separate context for each image, ensuring
that available resource limits are consistent.
2025-03-14 15:38:54 -07:00
Jesse Gross
9679f40146 ml: Allow models to constrain inputs to a single batch
Models may require that a set of inputs all be processed as part
of the same batch. For example, if an image has multiple patches
with fully connected attention between them, we should not split
the batch in the middle of an image.

Fixes #9697
2025-03-14 15:38:54 -07:00
Bruce MacDonald
3892c3a703 llm: remove internal subprocess req and resp types (#9324)
This commit refactors the LLM subsystem by removing internal subprocess
request and response types. It consolidates duplicate type definitions
across the codebase, moving them to centralized locations. The change also
standardizes interfaces between components, simplifies the ServerStatusResp
struct, and moves the ParseDurationMs function to a common package. This
cleanup reduces code duplication between different runner implementations
(llamarunner and ollamarunner).
2025-03-14 15:21:53 -07:00
Blake Mizerany
4e320b8b90 server/internal/chunks: remove chunks package (#9755) 2025-03-14 08:57:59 -07:00
Blake Mizerany
eb2b22b042 server/internal/client: use chunksums for concurrent blob verification (#9746)
Replace large-chunk blob downloads with parallel small-chunk
verification to solve timeout and performance issues. Registry users
experienced progressively slowing download speeds as large-chunk
transfers aged, often timing out completely.

The previous approach downloaded blobs in a few large chunks but
required a separate, single-threaded pass to read the entire blob back
from disk for verification after download completion.

This change uses the new chunksums API to fetch many smaller
chunk+digest pairs, allowing concurrent downloads and immediate
verification as each chunk arrives. Chunks are written directly to their
final positions, eliminating the entire separate verification pass.

The result is more reliable downloads that maintain speed throughout the
transfer process and significantly faster overall completion, especially
over unstable connections or with large blobs.
2025-03-13 22:18:29 -07:00
Michael Yang
4ea4d2b189 Merge pull request #9703 from ollama/mxyng/gemma3-memory
count gemma3 vision tensors
2025-03-13 16:56:34 -07:00
Michael Yang
8d76fa23ef count non-repeating vision layers 2025-03-13 16:53:29 -07:00
Bradley Erickson
74b44fdf8f docs: Add OLLAMA_ORIGINS for browser extension support (#9643) 2025-03-13 16:35:20 -07:00
Michael Yang
65b88c544f fix divide by zero 2025-03-13 16:35:00 -07:00
Michael Yang
a422ba39c9 roughly count gemma3 graph
the largest operation by far is (q @ k), so just count that for
simplicity
2025-03-13 16:35:00 -07:00
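The approximation appears in the VisionGraphSize diff below; as a sketch (4 bytes per float, with the q @ k attention scores as the dominant term):

```go
package graphsketch

// gemma3VisionGraphSize mirrors the rough count in the gemma3 case of
// the VisionGraphSize diff below.
func gemma3VisionGraphSize(imageSize, patchSize, numChannels, embeddingLength, headCount uint64) uint64 {
	numPatches := (imageSize / patchSize) * (imageSize / patchSize)
	return 4 * (imageSize*imageSize*numChannels + // raw image input
		embeddingLength*patchSize + // patch embeddings
		numPatches*numPatches*headCount) // q @ k attention scores
}
```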
Michael Yang
d2ec22371e count all vision tensors 2025-03-13 16:35:00 -07:00
Michael Yang
033cec232a count gemma3 vision tensors 2025-03-13 16:34:42 -07:00
91 changed files with 5026 additions and 1360 deletions


@@ -56,7 +56,7 @@
"name": "ROCm 6",
"inherits": [ "ROCm" ],
"cacheVariables": {
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
}
}
],


@@ -392,6 +392,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
### Cloud
@@ -510,6 +512,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
- [Ollama for D](https://github.com/kassane/ollama-d)
### Mobile


@@ -0,0 +1,178 @@
package benchmark
import (
"context"
"flag"
"fmt"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// Command line flags
var modelFlag string
func init() {
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
flag.Lookup("m").DefValue = "model"
}
// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
if modelFlag == "" {
b.Fatal("Error: -m flag is required for benchmark tests")
}
return modelFlag
}
type TestCase struct {
name string
prompt string
maxTokens int
}
// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
start := time.Now()
var ttft time.Duration
var metrics api.Metrics
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if ttft == 0 && resp.Response != "" {
ttft = time.Since(start)
}
if resp.Done {
metrics = resp.Metrics
}
return nil
})
// Report custom metrics as part of the benchmark results
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
// Token throughput metrics
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
b.ReportMetric(promptThroughput, "prompt_tok/s")
b.ReportMetric(genThroughput, "gen_tok/s")
// Token counts
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
if err != nil {
b.Fatal(err)
}
}
// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
b.StopTimer()
// Ensure model is unloaded before each iteration
unload(client, m, b)
b.StartTimer()
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Pre-warm the model
warmup(client, m, tt.prompt, b)
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
client, err := api.ClientFromEnvironment()
if err != nil {
b.Fatal(err)
}
if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
b.Fatalf("Model unavailable: %v", err)
}
return client
}
// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
for range 3 {
err := client.Generate(
context.Background(),
&api.GenerateRequest{
Model: model,
Prompt: prompt,
Options: map[string]interface{}{"num_predict": 50, "temperature": 0.1},
},
func(api.GenerateResponse) error { return nil },
)
if err != nil {
b.Logf("Error during model warm-up: %v", err)
}
}
}
// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
req := &api.GenerateRequest{
Model: model,
KeepAlive: &api.Duration{Duration: 0},
}
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
b.Logf("Unload error: %v", err)
}
time.Sleep(1 * time.Second)
}


@@ -703,6 +703,8 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
for _, k := range keys {
var v string
switch vData := resp.ModelInfo[k].(type) {
case bool:
v = fmt.Sprintf("%t", vData)
case string:
v = vData
case float64:


@@ -87,6 +87,8 @@ func TestShowInfo(t *testing.T) {
ModelInfo: map[string]any{
"general.architecture": "test",
"general.parameter_count": float64(8_000_000_000),
"some.true_bool": true,
"some.false_bool": false,
"test.context_length": float64(1000),
"test.embedding_length": float64(11434),
},
@@ -111,6 +113,8 @@ func TestShowInfo(t *testing.T) {
Metadata
general.architecture test
general.parameter_count 8e+09
some.false_bool false
some.true_bool true
test.context_length 1000
test.embedding_length 11434
@@ -757,3 +761,132 @@ func TestCreateHandler(t *testing.T) {
})
}
}
func TestNewCreateRequest(t *testing.T) {
tests := []struct {
name string
from string
opts runOptions
expected *api.CreateRequest
}{
{
"basic test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "",
Prompt: "You are a fun AI agent",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"parent model test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
},
},
{
"parent model as filepath test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "/some/file/like/etc/passwd",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"parent model as windows filepath test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "D:\\some\\file\\like\\etc\\passwd",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"options test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
Options: map[string]any{
"temperature": 1.0,
},
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
Parameters: map[string]any{
"temperature": 1.0,
},
},
},
{
"messages test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
System: "You are a fun AI agent",
Messages: []api.Message{
{
Role: "user",
Content: "hello there!",
},
{
Role: "assistant",
Content: "hello to you!",
},
},
WordWrap: true,
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
System: "You are a fun AI agent",
Messages: []api.Message{
{
Role: "user",
Content: "hello there!",
},
{
Role: "assistant",
Content: "hello to you!",
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := NewCreateRequest(tt.from, tt.opts)
if !cmp.Equal(actual, tt.expected) {
t.Errorf("expected output %#v, got %#v", tt.expected, actual)
}
})
}
}


@@ -18,6 +18,7 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
)
type MultilineState int
@@ -459,9 +460,16 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}
func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
parentModel := opts.ParentModel
modelName := model.ParseName(parentModel)
if !modelName.IsValid() {
parentModel = ""
}
req := &api.CreateRequest{
Name: name,
From: cmp.Or(opts.ParentModel, opts.Model),
Model: name,
From: cmp.Or(parentModel, opts.Model),
}
if opts.System != "" {


@@ -201,7 +201,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
case "CohereForCausalLM":
conv = &commandrModel{}
default:
return errors.New("unsupported architecture")
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
if err := json.Unmarshal(bts, conv); err != nil {


@@ -558,6 +558,10 @@ Final response:
{
"model": "llama3.2",
"created_at": "2023-08-04T19:22:45.499127Z",
"message": {
"role": "assistant",
"content": ""
},
"done": true,
"total_duration": 4883583458,
"load_duration": 1334875,

docs/benchmark.md

@@ -0,0 +1,59 @@
# Benchmark
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
## When to use
Run these benchmarks when:
- Making changes to the model inference engine
- Modifying model loading/unloading logic
- Changing prompt processing or token generation code
- Implementing a new model architecture
- Testing performance across different hardware setups
## Prerequisites
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
## Usage and Examples
>[!NOTE]
>All commands must be run from the root directory of the Ollama project.
Basic syntax:
```bash
go test -bench=. ./benchmark/... -m $MODEL_NAME
```
Required flags:
- `-bench=.`: Run all benchmarks
- `-m`: Model name to benchmark
Optional flags:
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
Common usage patterns:
Single benchmark run with a model specified:
```bash
go test -bench=. ./benchmark/... -m llama3.3
```
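Repeated runs with the optional flags (values here are illustrative):
```bash
go test -bench=. ./benchmark/... -m llama3.3 -count 6 -timeout 30m
```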
## Output metrics
The benchmark reports several key metrics:
- `gen_tok/s`: Generated tokens per second
- `prompt_tok/s`: Prompt processing tokens per second
- `ttft_ms`: Time to first token in milliseconds
- `load_ms`: Model load time in milliseconds
- `gen_tokens`: Total tokens generated
- `prompt_tokens`: Total prompt tokens processed
Each benchmark runs two scenarios:
- Cold start: Model is loaded from disk for each test
- Warm start: Model is pre-loaded in memory
Three prompt lengths are tested for each scenario:
- Short prompt (100 tokens)
- Medium prompt (500 tokens)
- Long prompt (1000 tokens)


@@ -187,6 +187,13 @@ cloudflared tunnel --url http://localhost:11434 --http-host-header="localhost:11
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
For browser extensions, you'll need to explicitly allow the extension's origin pattern. Set `OLLAMA_ORIGINS` to include `chrome-extension://*`, `moz-extension://*`, and `safari-web-extension://*` if you wish to allow all browser extensions access, or specific extensions as needed:
```
# Allow all Chrome, Firefox, and Safari extensions
OLLAMA_ORIGINS=chrome-extension://*,moz-extension://*,safari-web-extension://* ollama serve
```
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
## Where are models stored?


@@ -583,39 +583,52 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
}
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
if llm.KV().Uint("vision.block_count") == 0 {
return
}
for name, layer := range llm.Tensors().GroupLayers() {
if name == "v" || strings.HasPrefix(name, "v.") {
for _, tensor := range layer {
weights += tensor.Size()
}
}
}
imageSize := uint64(llm.KV().Uint("vision.image_size"))
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
if patchSize == 0 {
slog.Warn("unknown patch size for vision model")
return
}
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
switch llm.KV().Architecture() {
case "mllama":
for _, layer := range llm.Tensors().GroupLayers()["v"] {
weights += layer.Size()
}
kv := func(n string) uint64 {
if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
return uint64(v)
}
return 0
}
imageSize := kv("image_size")
maxNumTiles := kv("max_num_tiles")
embeddingLength := kv("embedding_length")
headCount := kv("attention.head_count")
numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
graphSize = 4 * (8 +
imageSize*imageSize*kv("num_channels")*maxNumTiles +
imageSize*imageSize*numChannels*maxNumTiles +
embeddingLength*numPatches*maxNumTiles +
9*embeddingLength*numPaddedPatches*maxNumTiles +
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
case "gemma3":
graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize +
numPatches*numPatches*headCount)
}
return weights, graphSize
}

grammar/bench_test.go

@@ -0,0 +1,22 @@
//go:build go1.24
package grammar
import "testing"
func BenchmarkFromSchema(b *testing.B) {
for tt := range testCases(b) {
b.Run("", func(b *testing.B) {
s := []byte(tt.schema)
b.ReportAllocs()
for b.Loop() {
_, err := FromSchema(nil, s)
if err != nil {
b.Fatalf("GrammarFromSchema: %v", err)
}
}
})
return
}
}

grammar/grammar.go

@@ -0,0 +1,227 @@
package grammar
import (
"bytes"
"encoding/json"
"fmt"
"iter"
"strconv"
"github.com/ollama/ollama/grammar/jsonschema"
)
const jsonTerms = `
# Unicode
#
# Unicode characters can be specified directly in the grammar, for example
# hiragana ::= [ぁ-ゟ], or with escapes: 8-bit (\xXX), 16-bit (\uXXXX) or 32-bit
# (\UXXXXXXXX).
unicode ::= \x{hex}{2} | \u{hex}{4} | \U{hex}{8}
# JSON grammar from RFC 7159
null ::= "null"
object ::= "{" (kv ("," kv)*)? "}"
array ::= "[" (value ("," value)*)? "]"
kv ::= string ":" value
integer ::= "0" | [1-9] [0-9]*
number ::= "-"? integer frac? exp?
frac ::= "." [0-9]+
exp ::= ("e" | "E") ("+" | "-") [0-9]+
string ::= "\"" char* "\""
escape ::= ["/" | "b" | "f" | "n" | "r" | "t" | unicode]
char ::= [^"\\] | escape
space ::= (" " | "\t" | "\n" | "\r")*
hex ::= [0-9] | [a-f] | [A-F]
boolean ::= "true" | "false"
value ::= object | array | string | number | boolean | "null"
# User-defined
`
// FromSchema generates a grammar from a JSON schema.
func FromSchema(buf []byte, jsonSchema []byte) ([]byte, error) {
var s *jsonschema.Schema
if err := json.Unmarshal(jsonSchema, &s); err != nil {
return nil, err
}
var g builder
// "root" is the only rule that is guaranteed to exist, so we start
// with its length for padding, and then adjust it as we go.
g.pad = len("root")
for id := range dependencies("root", s) {
g.pad = max(g.pad, len(id))
}
g.b.WriteString(jsonTerms)
ids := make(map[*jsonschema.Schema]string)
for id, s := range dependencies("root", s) {
ids[s] = id
g.define(id)
if err := fromSchema(&g, ids, s); err != nil {
return nil, err
}
}
g.define("root")
if err := fromSchema(&g, ids, s); err != nil {
return nil, err
}
g.define("") // finalize the last rule
return g.b.Bytes(), nil
}
func fromSchema(g *builder, ids map[*jsonschema.Schema]string, s *jsonschema.Schema) error {
switch typ := s.EffectiveType(); typ {
case "array":
if len(s.PrefixItems) == 0 && s.Items == nil {
g.u("array")
} else {
g.q("[")
for i, s := range s.PrefixItems {
if i > 0 {
g.q(",")
}
g.u(ids[s])
}
if s.Items != nil {
g.u("(")
if len(s.PrefixItems) > 0 {
g.q(",")
}
g.u(ids[s.Items])
g.u(")*")
}
g.q("]")
}
case "object":
if len(s.Properties) == 0 {
g.u("object")
} else {
g.q("{")
for i, p := range s.Properties {
name := ids[p]
if i > 0 {
g.q(",")
}
g.q(p.Name)
g.q(":")
g.u(name)
}
g.q("}")
}
case "number":
buildConstrainedNumber(g, s)
case "string":
if len(s.Enum) == 0 {
g.u("string")
} else {
g.u("(")
for i, e := range s.Enum {
if i > 0 {
g.q("|")
}
g.q(string(e))
}
g.u(")")
}
case "boolean", "value", "null", "integer":
g.u(typ)
default:
return fmt.Errorf("%s: unsupported type %q", s.Name, typ)
}
return nil
}
// dependencies returns a sequence of all child dependencies of the schema in
// post-order.
//
// The first value is the id/pointer to the dependency, and the second value
// is the schema.
func dependencies(id string, s *jsonschema.Schema) iter.Seq2[string, *jsonschema.Schema] {
return func(yield func(string, *jsonschema.Schema) bool) {
for i, p := range s.Properties {
id := fmt.Sprintf("%s_%d", id, i)
for did, d := range dependencies(id, p) {
if !yield(did, d) {
return
}
}
if !yield(id, p) {
return
}
}
for i, p := range s.PrefixItems {
id := fmt.Sprintf("tuple_%d", i)
for did, d := range dependencies(id, p) {
id := fmt.Sprintf("%s_%s", id, did)
if !yield(id, d) {
return
}
}
if !yield(id, p) {
return
}
}
if s.Items != nil {
id := fmt.Sprintf("%s_tuple_%d", id, len(s.PrefixItems))
for did, d := range dependencies(id, s.Items) {
if !yield(did, d) {
return
}
}
if !yield(id, s.Items) {
return
}
}
}
}
type builder struct {
b bytes.Buffer
pad int
rules int
items int
}
// define terminates the current rule, if any, and then either starts a new
// rule or does nothing else if the name is empty.
func (b *builder) define(name string) {
if b.rules > 0 {
b.b.WriteString(";\n")
}
if name == "" {
return
}
fmt.Fprintf(&b.b, "% -*s", b.pad, name)
b.b.WriteString(" ::=")
b.rules++
b.items = 0
}
// quote appends a terminal to the current rule.
func (b *builder) q(s string) {
if b.items > 0 {
b.b.WriteString(" ")
}
b.b.WriteString(" ")
b.b.WriteString(strconv.Quote(s))
}
// u appends a non-terminal to the current rule.
func (b *builder) u(s string) {
if b.items > 0 {
b.b.WriteString(" ")
}
b.b.WriteString(" ")
b.b.WriteString(s)
}
func buildConstrainedNumber(b *builder, s *jsonschema.Schema) {
if s.Minimum == 0 && s.Maximum == 0 {
b.u("TODO")
} else {
b.u("number")
}
}

grammar/grammar_test.go

@@ -0,0 +1,75 @@
package grammar
import (
"bufio"
"cmp"
"iter"
"strings"
"testing"
_ "embed"
"github.com/ollama/ollama/grammar/internal/diff"
)
func TestFromSchema(t *testing.T) {
for tt := range testCases(t) {
t.Run(tt.name, func(t *testing.T) {
g, err := FromSchema(nil, []byte(tt.schema))
if err != nil {
t.Fatalf("FromSchema: %v", err)
}
got := string(g)
got = strings.TrimPrefix(got, jsonTerms)
if got != tt.want {
t.Logf("schema:\n%s", tt.schema)
t.Fatal(string(diff.Diff("got", []byte(got), "want", []byte(tt.want))))
}
})
}
}
type testCase struct {
name string
schema string
want string
}
//go:embed testdata/schemas.txt
var tests string
func testCases(t testing.TB) iter.Seq[testCase] {
t.Helper()
return func(yield func(testCase) bool) {
t.Helper()
sc := bufio.NewScanner(strings.NewReader(tests))
name := ""
for sc.Scan() {
line := strings.TrimSpace(sc.Text())
if line == "" {
name = ""
continue
}
if line[0] == '#' {
name = cmp.Or(name, strings.TrimSpace(line[1:]))
continue
}
s := sc.Text()
g := ""
for sc.Scan() {
line = strings.TrimSpace(sc.Text())
if line == "" || line[0] == '#' {
break
}
g += sc.Text() + "\n"
}
if !yield(testCase{name, s, g}) {
return
}
name = strings.TrimSpace(strings.TrimPrefix(line, "#"))
}
if err := sc.Err(); err != nil {
t.Fatalf("error reading tests: %v", err)
}
}
}


@@ -0,0 +1,261 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package diff
import (
"bytes"
"fmt"
"sort"
"strings"
)
// A pair is a pair of values tracked for both the x and y side of a diff.
// It is typically a pair of line indexes.
type pair struct{ x, y int }
// Diff returns an anchored diff of the two texts old and new
// in the “unified diff” format. If old and new are identical,
// Diff returns a nil slice (no output).
//
// Unix diff implementations typically look for a diff with
// the smallest number of lines inserted and removed,
// which can in the worst case take time quadratic in the
// number of lines in the texts. As a result, many implementations
// either can be made to run for a long time or cut off the search
// after a predetermined amount of work.
//
// In contrast, this implementation looks for a diff with the
// smallest number of “unique” lines inserted and removed,
// where unique means a line that appears just once in both old and new.
// We call this an “anchored diff” because the unique lines anchor
// the chosen matching regions. An anchored diff is usually clearer
// than a standard diff, because the algorithm does not try to
// reuse unrelated blank lines or closing braces.
// The algorithm also guarantees to run in O(n log n) time
// instead of the standard O(n²) time.
//
// Some systems call this approach a “patience diff,” named for
// the “patience sorting” algorithm, itself named for a solitaire card game.
// We avoid that name for two reasons. First, the name has been used
// for a few different variants of the algorithm, so it is imprecise.
// Second, the name is frequently interpreted as meaning that you have
// to wait longer (to be patient) for the diff, meaning that it is a slower algorithm,
// when in fact the algorithm is faster than the standard one.
func Diff(oldName string, old []byte, newName string, new []byte) []byte {
if bytes.Equal(old, new) {
return nil
}
x := lines(old)
y := lines(new)
// Print diff header.
var out bytes.Buffer
fmt.Fprintf(&out, "diff %s %s\n", oldName, newName)
fmt.Fprintf(&out, "--- %s\n", oldName)
fmt.Fprintf(&out, "+++ %s\n", newName)
// Loop over matches to consider,
// expanding each match to include surrounding lines,
// and then printing diff chunks.
// To avoid setup/teardown cases outside the loop,
// tgs returns a leading {0,0} and trailing {len(x), len(y)} pair
// in the sequence of matches.
var (
done pair // printed up to x[:done.x] and y[:done.y]
chunk pair // start lines of current chunk
count pair // number of lines from each side in current chunk
ctext []string // lines for current chunk
)
for _, m := range tgs(x, y) {
if m.x < done.x {
// Already handled scanning forward from earlier match.
continue
}
// Expand matching lines as far as possible,
// establishing that x[start.x:end.x] == y[start.y:end.y].
// Note that on the first (or last) iteration we may (or definitely do)
// have an empty match: start.x==end.x and start.y==end.y.
start := m
for start.x > done.x && start.y > done.y && x[start.x-1] == y[start.y-1] {
start.x--
start.y--
}
end := m
for end.x < len(x) && end.y < len(y) && x[end.x] == y[end.y] {
end.x++
end.y++
}
// Emit the mismatched lines before start into this chunk.
// (No effect on first sentinel iteration, when start = {0,0}.)
for _, s := range x[done.x:start.x] {
ctext = append(ctext, "-"+s)
count.x++
}
for _, s := range y[done.y:start.y] {
ctext = append(ctext, "+"+s)
count.y++
}
// If we're not at EOF and have too few common lines,
// the chunk includes all the common lines and continues.
const C = 3 // number of context lines
if (end.x < len(x) || end.y < len(y)) &&
(end.x-start.x < C || (len(ctext) > 0 && end.x-start.x < 2*C)) {
for _, s := range x[start.x:end.x] {
ctext = append(ctext, " "+s)
count.x++
count.y++
}
done = end
continue
}
// End chunk with common lines for context.
if len(ctext) > 0 {
n := end.x - start.x
if n > C {
n = C
}
for _, s := range x[start.x : start.x+n] {
ctext = append(ctext, " "+s)
count.x++
count.y++
}
done = pair{start.x + n, start.y + n}
// Format and emit chunk.
// Convert line numbers to 1-indexed.
// Special case: empty file shows up as 0,0 not 1,0.
if count.x > 0 {
chunk.x++
}
if count.y > 0 {
chunk.y++
}
fmt.Fprintf(&out, "@@ -%d,%d +%d,%d @@\n", chunk.x, count.x, chunk.y, count.y)
for _, s := range ctext {
out.WriteString(s)
}
count.x = 0
count.y = 0
ctext = ctext[:0]
}
// If we reached EOF, we're done.
if end.x >= len(x) && end.y >= len(y) {
break
}
// Otherwise start a new chunk.
chunk = pair{end.x - C, end.y - C}
for _, s := range x[chunk.x:end.x] {
ctext = append(ctext, " "+s)
count.x++
count.y++
}
done = end
}
return out.Bytes()
}
// lines returns the lines in the file x, including newlines.
// If the file does not end in a newline, one is supplied
// along with a warning about the missing newline.
func lines(x []byte) []string {
l := strings.SplitAfter(string(x), "\n")
if l[len(l)-1] == "" {
l = l[:len(l)-1]
} else {
// Treat last line as having a message about the missing newline attached,
// using the same text as BSD/GNU diff (including the leading backslash).
l[len(l)-1] += "\n\\ No newline at end of file\n"
}
return l
}
// tgs returns the pairs of indexes of the longest common subsequence
// of unique lines in x and y, where a unique line is one that appears
// once in x and once in y.
//
// The longest common subsequence algorithm is as described in
// Thomas G. Szymanski, “A Special Case of the Maximal Common
// Subsequence Problem,” Princeton TR #170 (January 1975),
// available at https://research.swtch.com/tgs170.pdf.
func tgs(x, y []string) []pair {
// Count the number of times each string appears in a and b.
// We only care about 0, 1, many, counted as 0, -1, -2
// for the x side and 0, -4, -8 for the y side.
// Using negative numbers now lets us distinguish positive line numbers later.
m := make(map[string]int)
for _, s := range x {
if c := m[s]; c > -2 {
m[s] = c - 1
}
}
for _, s := range y {
if c := m[s]; c > -8 {
m[s] = c - 4
}
}
// Now unique strings can be identified by m[s] = -1+-4.
//
// Gather the indexes of those strings in x and y, building:
// xi[i] = increasing indexes of unique strings in x.
// yi[i] = increasing indexes of unique strings in y.
// inv[i] = index j such that x[xi[i]] = y[yi[j]].
var xi, yi, inv []int
for i, s := range y {
if m[s] == -1+-4 {
m[s] = len(yi)
yi = append(yi, i)
}
}
for i, s := range x {
if j, ok := m[s]; ok && j >= 0 {
xi = append(xi, i)
inv = append(inv, j)
}
}
// Apply Algorithm A from Szymanski's paper.
// In those terms, A = J = inv and B = [0, n).
// We add sentinel pairs {0,0}, and {len(x),len(y)}
// to the returned sequence, to help the processing loop.
J := inv
n := len(xi)
T := make([]int, n)
L := make([]int, n)
for i := range T {
T[i] = n + 1
}
for i := range n {
k := sort.Search(n, func(k int) bool {
return T[k] >= J[i]
})
T[k] = J[i]
L[i] = k + 1
}
k := 0
for _, v := range L {
if k < v {
k = v
}
}
seq := make([]pair, 2+k)
seq[1+k] = pair{len(x), len(y)} // sentinel at end
lastj := n
for i := n - 1; i >= 0; i-- {
if L[i] == k && J[i] < lastj {
seq[k] = pair{xi[i], yi[J[i]]}
k--
}
}
seq[0] = pair{0, 0} // sentinel at start
return seq
}


@@ -0,0 +1,44 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package diff
import (
"bytes"
"path/filepath"
"testing"
"golang.org/x/tools/txtar"
)
func clean(text []byte) []byte {
text = bytes.ReplaceAll(text, []byte("$\n"), []byte("\n"))
text = bytes.TrimSuffix(text, []byte("^D\n"))
return text
}
func Test(t *testing.T) {
files, _ := filepath.Glob("testdata/*.txt")
if len(files) == 0 {
t.Fatalf("no testdata")
}
for _, file := range files {
t.Run(filepath.Base(file), func(t *testing.T) {
a, err := txtar.ParseFile(file)
if err != nil {
t.Fatal(err)
}
if len(a.Files) != 3 || a.Files[2].Name != "diff" {
t.Fatalf("%s: want three files, third named \"diff\"", file)
}
diffs := Diff(a.Files[0].Name, clean(a.Files[0].Data), a.Files[1].Name, clean(a.Files[1].Data))
want := clean(a.Files[2].Data)
if !bytes.Equal(diffs, want) {
t.Fatalf("%s: have:\n%s\nwant:\n%s\n%s", file,
diffs, want, Diff("have", diffs, "want", want))
}
})
}
}


@@ -0,0 +1,13 @@
-- old --
-- new --
a
b
c
-- diff --
diff old new
--- old
+++ new
@@ -0,0 +1,3 @@
+a
+b
+c


@@ -0,0 +1,13 @@
-- old --
a
b
c
-- new --
-- diff --
diff old new
--- old
+++ new
@@ -1,3 +0,0 @@
-a
-b
-c


@@ -0,0 +1,35 @@
Example from Hunt and McIlroy, “An Algorithm for Differential File Comparison.”
https://www.cs.dartmouth.edu/~doug/diff.pdf
-- old --
a
b
c
d
e
f
g
-- new --
w
a
b
x
y
z
e
-- diff --
diff old new
--- old
+++ new
@@ -1,7 +1,7 @@
+w
a
b
-c
-d
+x
+y
+z
e
-f
-g

grammar/internal/diff/testdata/dups.txt

@@ -0,0 +1,40 @@
-- old --
a
b
c
d
e
f
-- new --
a
B
C
d
e
f
-- diff --
diff old new
--- old
+++ new
@@ -1,8 +1,8 @@
a
$
-b
-
-c
+B
+
+C
$
d
$

grammar/internal/diff/testdata/end.txt

@@ -0,0 +1,38 @@
-- old --
1
2
3
4
5
6
7
eight
nine
ten
eleven
-- new --
1
2
3
4
5
6
7
8
9
10
-- diff --
diff old new
--- old
+++ new
@@ -5,7 +5,6 @@
5
6
7
-eight
-nine
-ten
-eleven
+8
+9
+10


@@ -0,0 +1,9 @@
-- old --
a
b
c^D
-- new --
a
b
c^D
-- diff --

grammar/internal/diff/testdata/eof1.txt

@@ -0,0 +1,18 @@
-- old --
a
b
c
-- new --
a
b
c^D
-- diff --
diff old new
--- old
+++ new
@@ -1,3 +1,3 @@
a
b
-c
+c
\ No newline at end of file

grammar/internal/diff/testdata/eof2.txt

@@ -0,0 +1,18 @@
-- old --
a
b
c^D
-- new --
a
b
c
-- diff --
diff old new
--- old
+++ new
@@ -1,3 +1,3 @@
a
b
-c
\ No newline at end of file
+c

grammar/internal/diff/testdata/long.txt

@@ -0,0 +1,62 @@
-- old --
1
2
3
4
5
6
7
8
9
10
11
12
13
14
14½
15
16
17
18
19
20
-- new --
1
2
3
4
5
6
8
9
10
11
12
13
14
17
18
19
20
-- diff --
diff old new
--- old
+++ new
@@ -4,7 +4,6 @@
4
5
6
-7
8
9
10
@@ -12,9 +11,6 @@
12
13
14
-14½
-15
-16
17
18
19


@@ -0,0 +1,5 @@
-- old --
hello world
-- new --
hello world
-- diff --


@@ -0,0 +1,34 @@
-- old --
e
pi
4
5
6
7
8
9
10
-- new --
1
2
3
4
5
6
7
8
9
10
-- diff --
diff old new
--- old
+++ new
@@ -1,5 +1,6 @@
-e
-pi
+1
+2
+3
4
5
6

grammar/internal/diff/testdata/triv.txt

@@ -0,0 +1,40 @@
Another example from Hunt and McIlroy,
“An Algorithm for Differential File Comparison.”
https://www.cs.dartmouth.edu/~doug/diff.pdf
Anchored diff gives up on finding anything,
since there are no unique lines.
-- old --
a
b
c
a
b
b
a
-- new --
c
a
b
a
b
c
-- diff --
diff old new
--- old
+++ new
@@ -1,7 +1,6 @@
-a
-b
-c
-a
-b
-b
-a
+c
+a
+b
+a
+b
+c


@@ -0,0 +1,171 @@
package jsonschema
import (
"bytes"
"encoding/json"
"errors"
)
// Schema holds a JSON schema.
type Schema struct {
// Name is the name of the property. For the parent/root property, this
// is "root". For child properties, this is the name of the property.
Name string `json:"-"`
// Type is the type of the property.
//
// TODO: Union types (e.g. make this a []string).
Type string
// PrefixItems is a list of schemas for each item in a tuple. By
// default, the tuple is "closed" unless Items is set to true or a
// valid Schema.
PrefixItems []*Schema
// Items is the schema for each item in a list.
//
// If it is missing, or its JSON value is "null" or "false", it is nil.
// If the JSON value is "true", it is set to the empty Schema. If the
// JSON value is an object, it will be decoded as a Schema.
Items *Schema
// MinItems specifies the minimum number of items allowed in a list.
MinItems int
// MaxItems specifies the maximum number of items allowed in a list.
MaxItems int
// Properties is the schema for each property of an object.
Properties []*Schema
// Format is the format of the property. This is used to validate the
// property against a specific format.
//
// It is the caller's responsibility to validate the property against
// the format.
Format string
// Minimum specifies the minimum value for numeric properties.
Minimum float64
// Maximum specifies the maximum value for numeric properties.
Maximum float64
// Enum is a list of valid values for the property.
Enum []json.RawMessage
}
func (s *Schema) UnmarshalJSON(data []byte) error {
type S Schema
w := struct {
Properties props
Items items
*S
}{
S: (*S)(s),
}
if err := json.Unmarshal(data, &w); err != nil {
return err
}
if w.Items.set {
s.Items = &w.Items.Schema
}
s.Properties = w.Properties
return nil
}
type items struct {
Schema
set bool
}
func (s *items) UnmarshalJSON(data []byte) error {
switch b := data[0]; b {
case 't':
*s = items{set: true}
case '{':
type I items
if err := json.Unmarshal(data, (*I)(s)); err != nil {
return err
}
s.set = true
case 'n', 'f':
default:
return errors.New("invalid Items")
}
return nil
}
// EffectiveType returns the effective type of the schema. If the Type field is
// not empty, it is returned; otherwise:
//
// - If the schema has both Properties and Items, it returns an empty string.
// - If the schema has Properties, it returns "object".
// - If the schema has Items, it returns "array".
// - If the schema has neither Properties nor Items, it returns "value".
//
// The returned string is never empty.
func (d *Schema) EffectiveType() string {
if d.Type == "" {
if len(d.Properties) > 0 {
return "object"
}
if len(d.PrefixItems) > 0 || d.Items != nil {
return "array"
}
return "value"
}
return d.Type
}
// props is an ordered list of properties. The order of the properties
// is the order in which they were defined in the schema.
type props []*Schema
var _ json.Unmarshaler = (*props)(nil)
func (v *props) UnmarshalJSON(data []byte) error {
if len(data) == 0 {
return nil
}
if data[0] != '{' {
return errors.New("expected object")
}
d := json.NewDecoder(bytes.NewReader(data))
// TODO(bmizerany): Consider DisallowUnknownFields. Currently, we, like
// llama.cpp, ignore unknown fields, which could lead to unexpected
// behavior for clients of this package, since they may not be aware
// that "additionalFields", "itemsPrefix", etc, are being ignored.
//
// For now, just do what llama.cpp does.
t, err := d.Token()
if err != nil {
return err
}
if t != json.Delim('{') {
return errors.New("expected object")
}
for d.More() {
// Use the first token (map key) as the property name, then
// decode the rest of the object fields into a Schema and
// append.
t, err := d.Token()
if err != nil {
return err
}
if t == json.Delim('}') {
return nil
}
s := &Schema{
Name: t.(string),
}
if err := d.Decode(s); err != nil {
return err
}
*v = append(*v, s)
}
return nil
}


@@ -0,0 +1,104 @@
package jsonschema
import (
"encoding/json"
"reflect"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
const testSchemaBasic = `
{
"properties": {
"tupleClosedEmpty": { "prefixItems": [] },
"tupleClosedMissing": { "prefixItems": [{}] },
"tupleClosedNull": { "prefixItems": [{}], "items": null },
"tupleClosedFalse": { "prefixItems": [{}], "items": false },
"tupleOpenTrue": { "prefixItems": [{}], "items": true },
"tupleOpenEmpty": { "prefixItems": [{}], "items": {} },
"tupleOpenTyped": { "prefixItems": [{}], "items": {"type": "boolean"} },
"tupleOpenMax": { "prefixItems": [{}], "items": true, "maxItems": 3},
"array": { "items": {"type": "number"} },
"null": { "type": "null" },
"string": { "type": "string" },
"boolean": { "type": "boolean" }
}
}
`
func TestSchemaUnmarshal(t *testing.T) {
var got *Schema
if err := json.Unmarshal([]byte(testSchemaBasic), &got); err != nil {
t.Fatalf("Unmarshal: %v", err)
}
want := &Schema{
Properties: []*Schema{
{Name: "tupleClosedEmpty", PrefixItems: []*Schema{}, Items: nil},
{Name: "tupleClosedMissing", PrefixItems: []*Schema{{}}, Items: nil},
{Name: "tupleClosedNull", PrefixItems: []*Schema{{}}, Items: nil},
{Name: "tupleClosedFalse", PrefixItems: []*Schema{{}}, Items: nil},
{Name: "tupleOpenTrue", PrefixItems: []*Schema{{}}, Items: &Schema{}},
{Name: "tupleOpenEmpty", PrefixItems: []*Schema{{}}, Items: &Schema{}},
{Name: "tupleOpenTyped", PrefixItems: []*Schema{{}}, Items: &Schema{Type: "boolean"}},
{Name: "tupleOpenMax", PrefixItems: []*Schema{{}}, Items: &Schema{}, MaxItems: 3},
{Name: "array", Items: &Schema{Type: "number"}},
{Name: "null", Type: "null"},
{Name: "string", Type: "string"},
{Name: "boolean", Type: "boolean"},
},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("(-want, +got)\n%s", diff)
}
}
func TestEffectiveType(t *testing.T) {
const schema = `
{"properties": {
"o": {"type": "object"},
"a": {"type": "array"},
"n": {"type": "number"},
"s": {"type": "string"},
"z": {"type": "null"},
"b": {"type": "boolean"},
"t0": {"prefixItems": [{}], "items": {"type": "number"}},
"t1": {"items": {"type": "number"}, "maxItems": 3},
"v": {"maxItems": 3}
}}
`
var s *Schema
if err := json.Unmarshal([]byte(schema), &s); err != nil {
t.Fatalf("json.Unmarshal: %v", err)
}
var got []string
for _, p := range s.Properties {
got = append(got, p.EffectiveType())
}
want := strings.Fields(`
object
array
number
string
null
boolean
array
array
value
`)
if !reflect.DeepEqual(want, got) {
t.Errorf("\ngot:\n\t%v\nwant:\n\t%v", got, want)
}
}

grammar/testdata/schemas.txt

@@ -0,0 +1,76 @@
# This file holds tests for JSON schema to EBNF grammar conversions.
#
# The format is a JSON schema, followed by the expected EBNF grammar. Each test
# MAY be preceded by a comment that describes the test (e.g. the test name), followed by
# the JSON schema and the expected EBNF grammar. If no comment is present, the test
# name is the test's number in the file (e.g. "#0", "#1", etc.)
#
# Blank lines signify the end or start of a new test. Comments can be added
# anywhere in the file, but they must be preceded by a '#' character and start at
# the beginning of the line.
# default
{}
root ::= value;
{"properties": {}}
root ::= value;
# array
{"properties": {"a": {"type": "array", "items": {"type": "string"}}}}
root_0_tuple_0 ::= string;
root_0 ::= "[" ( root_0_tuple_0 )* "]";
root ::= "{" "a" ":" root_0 "}";
# array with nested array
{"type": "array", "items": {"type": "array", "items": {"type": "string"}}}
root_tuple_0_tuple_0 ::= string;
root_tuple_0 ::= "[" ( root_tuple_0_tuple_0 )* "]";
root ::= "[" ( root_tuple_0 )* "]";
# object
{"properties": {"e": {}}}
root_0 ::= value;
root ::= "{" "e" ":" root_0 "}";
# object with nested object
{"properties": {"o": {"type": "object", "properties": {"e": {}}}}}
root_0_0 ::= value;
root_0 ::= "{" "e" ":" root_0_0 "}";
root ::= "{" "o" ":" root_0 "}";
# boolean
{"type": "boolean"}
root ::= boolean;
# number
{"properties": {"n": {"type": "number", "minimum": 123, "maximum": 4567}}}
root_0 ::= number;
root ::= "{" "n" ":" root_0 "}";
# string
{"type": "string"}
root ::= string;
# string with enum
{"type": "string", "enum": ["a", "b", "c"]}
root ::= ( "\"a\"" "|" "\"b\"" "|" "\"c\"" );
# spaces in key
{"properties": {"a b": {}}}
root_0 ::= value;
root ::= "{" "a b" ":" root_0 "}";
# issue7978
{ "type": "object", "properties": { "steps": { "type": "array", "items": { "type": "object", "properties": { "explanation": { "type": "string" }, "output": { "type": "string" } }, "required": [ "explanation", "output" ], "additionalProperties": false } }, "final_answer": { "type": "string" } }, "required": [ "steps", "final_answer" ], "additionalProperties": false }
root_0_tuple_0_0 ::= string;
root_0_tuple_0_1 ::= string;
root_0_tuple_0 ::= "{" "explanation" ":" root_0_tuple_0_0 "," "output" ":" root_0_tuple_0_1 "}";
root_0 ::= "[" ( root_0_tuple_0 )* "]";
root_1 ::= string;
root ::= "{" "steps" ":" root_0 "," "final_answer" ":" root_1 "}";
# !! # special characters in key
# !! {"properties": {"a!b": {}}}
# !! !invalid character '!' in key
# !!


@@ -66,6 +66,35 @@ func TestIntegrationMllama(t *testing.T) {
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
}
func TestIntegrationSplitBatch(t *testing.T) {
image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err)
req := api.GenerateRequest{
Model: "gemma3:4b",
// Fill up a chunk of the batch so the image will partially spill over into the next one
System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
Images: []api.ImageData{
image,
},
}
// Note: sometimes it returns "the ollamas", sometimes "the ollams"
resp := "the ollam"
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model))
// llava models on CPU can be quite slow to start
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
}
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA
AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6

View File

@@ -43,8 +43,13 @@ type Cache interface {
// ** cache management **
// Init sets up runtime parameters
Init(backend ml.Backend, dtype ml.DType, capacity int32)
// Init sets up runtime parameters.
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
// Close closes the cache and frees resources associated with it
Close()
@@ -52,7 +57,7 @@ type Cache interface {
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs.
StartForward(ctx ml.Context, opts input.Options) error
StartForward(ctx ml.Context, batch input.Batch) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)
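A minimal sketch of the new initialization shape, mirroring the updated cache tests later in this change (values are illustrative):

cache := kvcache.NewCausalCache(nil)
defer cache.Close()
// one sequence, 16 cache entries per sequence, batches of up to 16 tokens
cache.Init(backend, ml.DTypeF16, 1, 16, 16)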

View File

@@ -20,7 +20,6 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
// The mask is of shape history size, batch size
type Causal struct {
DType ml.DType
Capacity int32
windowSize int32
opts CausalOptions
@@ -98,7 +97,7 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
@@ -119,9 +118,16 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
c.config.MaskDType = ml.DTypeF32
}
var cacheSize int
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
cacheSize = maxSequences * capacity
} else {
cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
}
cacheSize = roundUp(cacheSize, c.config.CachePadding)
c.cells = make([]cacheCell, cacheSize)
c.DType = dtype
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
c.cells = make([]cacheCell, c.Capacity)
c.cellRanges = make(map[int]cellRange)
c.backend = backend
}
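A worked example of the sizing logic above, under assumed values:

// windowSize = 1024, maxBatch = 512, maxSequences = 1, capacity = 32768
// capacity >= windowSize + maxBatch, so the sliding-window branch applies:
// cacheSize = 1 * (1024 + 512) = 1536 cells (before CachePadding round-up),
// rather than 1 * 32768 cells for the full-context allocation.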
@@ -140,12 +146,14 @@ func (c *Causal) Close() {
}
}
func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
c.curBatchSize = len(opts.Positions)
c.curSequences = opts.Sequences
c.curPositions = opts.Positions
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
c.curBatchSize = len(batch.Positions)
c.curSequences = batch.Sequences
c.curPositions = batch.Positions
c.opts.Except = nil
c.updateSlidingWindow()
var err error
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
@@ -157,8 +165,8 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
}
c.curCellRange = newRange()
for i, pos := range opts.Positions {
seq := opts.Sequences[i]
for i, pos := range batch.Positions {
seq := batch.Sequences[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
@@ -210,7 +218,51 @@ func (c *Causal) findStartLoc() (int, error) {
}
}
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
}
func (c *Causal) updateSlidingWindow() {
if c.windowSize == math.MaxInt32 {
return
}
// create a map of unique sequences to the lowest position in that sequence
lowestPos := make(map[int]int32)
for i := range c.curPositions {
seq := c.curSequences[i]
pos, ok := lowestPos[seq]
if !ok {
pos = c.curPositions[i]
} else if c.curPositions[i] < pos {
pos = c.curPositions[i]
}
lowestPos[seq] = pos
}
// delete any entries that are beyond the window of the oldest position in the sequence
for seq, pos := range lowestPos {
oldRange, ok := c.cellRanges[seq]
if !ok {
continue
}
newRange := newRange()
for i := oldRange.min; i <= oldRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
if c.cells[i].pos < pos-c.windowSize {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
} else {
newRange.min = min(newRange.min, i)
newRange.max = max(newRange.max, i)
}
}
}
c.cellRanges[seq] = newRange
}
}
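An illustrative trace of the eviction above, assuming windowSize = 1:

// A new batch for seq 0 carries positions {4, 5}, so lowestPos[0] = 4.
// Cells of seq 0 with pos < 4 - 1 = 3 are dropped from the sequence and their
// slots become reusable; cells at positions 3 through 5 stay in the window.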
func roundDown(length, pad int) int {
@@ -265,7 +317,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
return maskTensor, nil
}
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
for i, key := range c.keys {
if key == nil {
continue
@@ -275,8 +327,8 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
value := c.values[i]
var vSrcView, vDstView ml.Tensor
@@ -284,14 +336,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
}
ctx.Forward(
@@ -321,7 +373,8 @@ func (c *Causal) defrag() {
ctx := c.backend.NewContext()
// For every move, 6 tensors are required per layer (2 views and a
// copy for each of k and v).
// copy for each of k and v). We also need to refer to the original
// k and v cache tensors - once per layer, not per move.
layers := 0
for _, key := range c.keys {
if key == nil {
@@ -330,7 +383,7 @@ func (c *Causal) defrag() {
layers++
}
maxMoves := ctx.MaxGraphNodes() / (6 * layers)
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
moves := 0
var pendingSrc, pendingDst, pendingLen int
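For example, with 32 layers and an assumed graph budget of 8192 nodes:

// maxMoves = (8192 - 2*32) / (6*32) = 8128 / 192 = 42 moves per defrag pass,
// after reserving the per-layer references to the original k/v cache tensors.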
@@ -479,14 +532,14 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
}
if _, ok := c.keys[c.curLayer]; !ok {
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
}
if _, ok := c.values[c.curLayer]; !ok {
if c.config.PermutedV {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
} else {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
}
}
@@ -497,7 +550,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
elemSize := c.values[c.curLayer].Stride(0)
value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
} else {
rowSize := c.values[c.curLayer].Stride(2)

View File

@@ -25,7 +25,7 @@ func TestStore(t *testing.T) {
cache := NewCausalCache(nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16)
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
@@ -58,11 +58,11 @@ func TestSWA(t *testing.T) {
cache := NewSWACache(1, nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF32, 16)
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
name: "SlidingWindow",
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0},
@@ -71,6 +71,16 @@ func TestSWA(t *testing.T) {
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
},
{
name: "SecondBatch",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{4, 5},
expected: []float32{5, 6, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
},
}
testCache(t, backend, cache, tests)
@@ -81,7 +91,7 @@ func TestSequences(t *testing.T) {
cache := NewCausalCache(nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16)
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
@@ -116,7 +126,7 @@ func TestRemove(t *testing.T) {
})
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16)
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
@@ -181,7 +191,7 @@ func TestDefrag(t *testing.T) {
})
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16)
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
@@ -229,7 +239,7 @@ func TestCopy(t *testing.T) {
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16)
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
@@ -270,7 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
context := backend.NewContext()
defer context.Close()
err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
if err != nil {
panic(err)
}

View File

@@ -49,7 +49,7 @@ func NewEncoderCache() *EncoderCache {
}
}
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
@@ -58,6 +58,10 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
c.config = &config
}
if maxSequences > 1 {
panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
}
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
}
@@ -79,10 +83,10 @@ func (c *EncoderCache) Close() {
}
}
func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
// We work with the most recent image
if len(opts.Multimodal) > 0 {
c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
if len(batch.Multimodal) > 0 {
c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
}
return nil

View File

@@ -23,9 +23,9 @@ func NewWrapperCache(caches ...Cache) *WrapperCache {
}
}
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
for _, cache := range c.caches {
cache.Init(backend, dtype, capacity)
cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
}
@@ -41,14 +41,14 @@ func (c *WrapperCache) Close() {
}
}
func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
for i, cache := range c.caches {
err := cache.StartForward(ctx, opts)
err := cache.StartForward(ctx, batch)
if err != nil {
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
for j := i - 1; j >= 0; j-- {
for k := range opts.Positions {
_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
for k := range batch.Positions {
_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
}
}
return err

View File

@@ -37,6 +37,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINICPM3, "minicpm3" },
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
@@ -804,6 +805,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_GEMMA3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_STARCODER2,
{

View File

@@ -41,6 +41,7 @@ enum llm_arch {
LLM_ARCH_MINICPM3,
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
LLM_ARCH_GEMMA3,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,

View File

@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_GEMMA3:
{
} break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_GEMMA3:
{
} break;
case LLM_ARCH_STARCODER2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
case LLM_ARCH_PHIMOE:
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
case LLM_ARCH_GEMMA3:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:

View File

@@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// This used to be a regex, but <regex> has an extreme cost to compile times.
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
// don't quantize vision stuff
quantize &= name.find("v.blk.") == std::string::npos;
quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos;
quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos;
quantize &= name.find("v.patch_embedding.weight") == std::string::npos;
quantize &= name.find("v.position_embedding.weight") == std::string::npos;
quantize &= name.find("v.post_layernorm.weight") == std::string::npos;
// quantize only 2D and 3D tensors (experts)
quantize &= (ggml_n_dims(tensor) >= 2);

View File

@@ -0,0 +1,113 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Patrick Devine <patrick@infrahq.com>
Date: Fri, 14 Mar 2025 16:33:23 -0700
Subject: [PATCH] gemma3 quantization
---
src/llama-arch.cpp | 19 +++++++++++++++++++
src/llama-arch.h | 1 +
src/llama-model.cpp | 7 +++++++
src/llama-quant.cpp | 9 +++++++++
4 files changed, 36 insertions(+)
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index b6f20286..b443fcd3 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -37,6 +37,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINICPM3, "minicpm3" },
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
+ { LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
@@ -804,6 +805,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
+ {
+ LLM_ARCH_GEMMA3,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+ },
+ },
{
LLM_ARCH_STARCODER2,
{
diff --git a/src/llama-arch.h b/src/llama-arch.h
index ec742224..aad92a5d 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -41,6 +41,7 @@ enum llm_arch {
LLM_ARCH_MINICPM3,
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
+ LLM_ARCH_GEMMA3,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index ab1a07d1..70183041 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_GEMMA3:
+ {
+ } break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
+ case LLM_ARCH_GEMMA3:
+ {
+ } break;
case LLM_ARCH_STARCODER2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
case LLM_ARCH_PHIMOE:
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
+ case LLM_ARCH_GEMMA3:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 6eb1da08..d2f3a510 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// This used to be a regex, but <regex> has an extreme cost to compile times.
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
+ // don't quantize vision stuff
+ quantize &= name.find("v.blk.") == std::string::npos;
+
+ quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos;
+ quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos;
+ quantize &= name.find("v.patch_embedding.weight") == std::string::npos;
+ quantize &= name.find("v.position_embedding.weight") == std::string::npos;
+ quantize &= name.find("v.post_layernorm.weight") == std::string::npos;
+
// quantize only 2D and 3D tensors (experts)
quantize &= (ggml_n_dims(tensor) >= 2);

View File

@@ -218,8 +218,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
layerSize = blk.Size()
layerSize += kv / f.KV().BlockCount()
memoryWeights += blk.Size()
}
memoryWeights += layerSize
if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
// Stop allocating on GPU(s) once we hit the user's target NumGPU
@@ -376,7 +376,7 @@ func (m MemoryEstimate) LogValue() slog.Value {
// memory of the weights
"total", format.HumanBytes2(m.memoryWeights),
// memory of repeating layers
"repeating", format.HumanBytes2(m.memoryWeights-m.memoryLayerOutput),
"repeating", format.HumanBytes2(m.memoryWeights),
// memory of non-repeating layers
"nonrepeating", format.HumanBytes2(m.memoryLayerOutput),
),

View File

@@ -29,6 +29,7 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/grammar"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model"
)
@@ -402,7 +403,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
}
slog.Info("starting llama server", "cmd", s.cmd.String())
slog.Info("starting llama server", "cmd", s.cmd)
if envconfig.Debug() {
filteredEnv := []string{}
for _, ev := range s.cmd.Env {
@@ -470,7 +471,7 @@ const ( // iota is reset to 0
ServerStatusError
)
func (s ServerStatus) ToString() string {
func (s ServerStatus) String() string {
switch s {
case ServerStatusReady:
return "llm server ready"
@@ -485,12 +486,9 @@ func (s ServerStatus) ToString() string {
}
}
type ServerStatusResp struct {
Status string `json:"status"`
SlotsIdle int `json:"slots_idle"`
SlotsProcessing int `json:"slots_processing"`
Error string `json:"error"`
Progress float32 `json:"progress"`
type ServerStatusResponse struct {
Status ServerStatus `json:"status"`
Progress float32 `json:"progress"`
}
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
@@ -502,7 +500,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
}
if s.cmd.ProcessState.ExitCode() == -1 {
// Most likely a signal killed it, log some more details to try to help troubleshoot
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState)
}
return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
}
@@ -527,21 +525,19 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
return ServerStatusError, fmt.Errorf("read health request: %w", err)
}
var status ServerStatusResp
if err := json.Unmarshal(body, &status); err != nil {
var ssr ServerStatusResponse
if err := json.Unmarshal(body, &ssr); err != nil {
return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
}
switch status.Status {
case "ok":
return ServerStatusReady, nil
case "no slot available":
return ServerStatusNoSlotsAvailable, nil
case "loading model":
s.loadProgress = status.Progress
return ServerStatusLoadingModel, nil
switch ssr.Status {
case ServerStatusLoadingModel:
s.loadProgress = ssr.Progress
return ssr.Status, nil
case ServerStatusReady, ServerStatusNoSlotsAvailable:
return ssr.Status, nil
default:
return ServerStatusError, fmt.Errorf("server error: %+v", status)
return ssr.Status, fmt.Errorf("server error: %+v", ssr)
}
}
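Since ServerStatus is an integer enum and no custom marshaller appears in this change, the health payload presumably round-trips numerically now; an assumed example while a model is loading (the numeric code depends on the const order):

GET /health -> {"status": 2, "progress": 0.42}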
@@ -616,7 +612,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
status, _ := s.getServerStatus(ctx)
if lastStatus != status && status != ServerStatusReady {
// Only log on status changes
slog.Info("waiting for server to become available", "status", status.ToString())
slog.Info("waiting for server to become available", "status", status)
}
switch status {
case ServerStatusReady:
@@ -630,7 +626,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
stallTimer = time.Now().Add(stallDuration)
} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
slog.Debug("model load completed, waiting for server to become available", "status", status)
stallTimer = time.Now().Add(stallDuration)
fullyLoaded = true
}
@@ -671,63 +667,26 @@ type ImageData struct {
AspectRatioID int `json:"aspect_ratio_id"`
}
type completion struct {
Content string `json:"content"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Stop bool `json:"stop"`
StoppedLimit bool `json:"stopped_limit"`
Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
}
}
type CompletionRequest struct {
Prompt string
Format json.RawMessage
Images []ImageData
Options *api.Options
Grammar string // set before sending the request to the subprocess
}
type CompletionResponse struct {
Content string
DoneReason string
Done bool
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
Content string `json:"content"`
DoneReason string `json:"done_reason"`
Done bool `json:"done"`
PromptEvalCount int `json:"prompt_eval_count"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
EvalCount int `json:"eval_count"`
EvalDuration time.Duration `json:"eval_duration"`
}
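With these tags, a final streamed response would marshal roughly as follows (illustrative values; time.Duration fields use Go's default nanosecond encoding):

{"content":"","done_reason":"stop","done":true,"prompt_eval_count":12,"prompt_eval_duration":150000000,"eval_count":40,"eval_duration":2000000000}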
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
request := map[string]any{
"prompt": req.Prompt,
"stream": true,
"n_predict": req.Options.NumPredict,
"n_keep": req.Options.NumKeep,
"main_gpu": req.Options.MainGPU,
"temperature": req.Options.Temperature,
"top_k": req.Options.TopK,
"top_p": req.Options.TopP,
"min_p": req.Options.MinP,
"typical_p": req.Options.TypicalP,
"repeat_last_n": req.Options.RepeatLastN,
"repeat_penalty": req.Options.RepeatPenalty,
"presence_penalty": req.Options.PresencePenalty,
"frequency_penalty": req.Options.FrequencyPenalty,
"mirostat": req.Options.Mirostat,
"mirostat_tau": req.Options.MirostatTau,
"mirostat_eta": req.Options.MirostatEta,
"seed": req.Options.Seed,
"stop": req.Options.Stop,
"image_data": req.Images,
"cache_prompt": true,
}
if len(req.Format) > 0 {
switch string(req.Format) {
case `null`, `""`:
@@ -735,21 +694,31 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
// these as "not set".
break
case `"json"`:
request["grammar"] = grammarJSON
req.Grammar = grammarJSON
default:
if req.Format[0] != '{' {
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
}
// User provided a JSON schema
g := llama.SchemaToGrammar(req.Format)
if g == nil {
return fmt.Errorf("invalid JSON schema in format")
g, err := grammar.FromSchema(nil, req.Format)
if err != nil {
return fmt.Errorf("invalid JSON schema in format: %w", err)
}
request["grammar"] = string(g)
req.Grammar = string(g)
}
}
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
}
if err := s.sem.Acquire(ctx, 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
@@ -764,13 +733,12 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
req.Options.NumPredict = 10 * s.options.NumCtx
}
// Make sure the server is ready
status, err := s.getServerStatusRetry(ctx)
if err != nil {
return err
} else if status != ServerStatusReady {
return fmt.Errorf("unexpected server status: %s", status.ToString())
return fmt.Errorf("unexpected server status: %s", status)
}
// Handling JSON marshaling with special characters unescaped.
@@ -778,7 +746,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
enc := json.NewEncoder(buffer)
enc.SetEscapeHTML(false)
if err := enc.Encode(request); err != nil {
if err := enc.Encode(req); err != nil {
return fmt.Errorf("failed to marshal data: %v", err)
}
@@ -829,7 +797,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
evt = line
}
var c completion
var c CompletionResponse
if err := json.Unmarshal(evt, &c); err != nil {
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
}
@@ -853,20 +821,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
})
}
if c.Stop {
doneReason := "stop"
if c.StoppedLimit {
doneReason = "length"
}
fn(CompletionResponse{
Done: true,
DoneReason: doneReason,
PromptEvalCount: c.Timings.PromptN,
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
EvalCount: c.Timings.PredictedN,
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
})
if c.Done {
fn(c)
return nil
}
}
@@ -914,7 +870,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
if err != nil {
return nil, err
} else if status != ServerStatusReady {
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
return nil, fmt.Errorf("unexpected server status: %s", status)
}
data, err := json.Marshal(EmbeddingRequest{Content: input})
@@ -1059,12 +1015,3 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
}
return 0
}
func parseDurationMs(ms float64) time.Duration {
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
if err != nil {
panic(err)
}
return dur
}

View File

@@ -2,6 +2,7 @@ package ml
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"os"
@@ -60,6 +61,10 @@ type CacheConfig struct {
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
// Progress is a callback function that allows reporting percentage completion
// of model loading
Progress func(float32)
// NumThreads sets the number of threads to use if running on the CPU
NumThreads int
@@ -76,9 +81,9 @@ type BackendParams struct {
FlashAttention bool
}
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok {
panic("backend: backend already registered")
}
@@ -86,9 +91,9 @@ func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, erro
backends[name] = f
}
func NewBackend(f *os.File, params BackendParams) (Backend, error) {
func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
if backend, ok := backends["ggml"]; ok {
return backend(f, params)
return backend(ctx, f, params)
}
return nil, fmt.Errorf("unsupported backend")
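A minimal sketch of the new loading flow with the context and progress callback threaded through (f is an already-opened model file; everything outside this diff is an assumption):

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
backend, err := ml.NewBackend(ctx, f, ml.BackendParams{
	NumThreads: runtime.GOMAXPROCS(0),
	Progress: func(p float32) {
		fmt.Printf("\rloading: %3.0f%%", p*100) // invoked as tensor bytes land
	},
})
if err != nil {
	return err
}
_ = backend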

View File

@@ -9,15 +9,17 @@ package ggml
import "C"
import (
"errors"
"context"
"fmt"
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strconv"
"strings"
"sync/atomic"
"unicode"
"unsafe"
@@ -58,7 +60,7 @@ type Backend struct {
maxGraphNodes int
}
func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
meta, n, err := fs.Decode(r, -1)
if err != nil {
return nil, err
@@ -297,12 +299,16 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
}
}
// concurrently read in tensor data. uses a section reader which is safe for concurrent reads
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
var g errgroup.Group
var doneBytes atomic.Uint64
totalBytes := uint64(n) - meta.Tensors().Offset
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(runtime.GOMAXPROCS(0))
for _, t := range meta.Tensors().Items() {
for _, target := range targets[t.Name] {
g.Go(func() error {
g.Go(func() error {
tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
for i := range tts {
target := targets[t.Name][i]
if target == "" {
target = t.Name
}
@@ -312,23 +318,44 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
return fmt.Errorf("unassigned tensor: %s", t.Name)
}
bts := make([]byte, t.Size())
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
tts[i] = tt
}
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
bts := make([]byte, 128*format.KibiByte)
var s uint64
for s < t.Size() {
n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
if err != nil {
return err
}
if n != len(bts) {
return errors.New("short read")
for _, tt := range tts {
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
}
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
return nil
})
}
s += uint64(n)
if params.Progress != nil {
done := doneBytes.Add(uint64(n))
params.Progress(float32(done) / float32(totalBytes))
}
}
return nil
})
}
if g.Wait() != nil {
// start a goroutine to cancel the errgroup if the parent context is done
go func() {
<-ctx.Done()
g.Go(func() error {
return ctx.Err()
})
}()
if err := g.Wait(); err != nil {
return nil, err
}
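A self-contained sketch of the chunked, concurrent load pattern above: per-tensor section readers, fixed 128KiB buffers, bounded parallelism, and atomic progress accounting. The real code copies each chunk into ggml tensors via cgo; this stand-in only reads:

import (
	"io"
	"os"
	"runtime"
	"sync/atomic"

	"golang.org/x/sync/errgroup"
)

func loadSections(f *os.File, offsets, sizes []int64, progress func(float32)) error {
	var total uint64
	for _, s := range sizes {
		total += uint64(s)
	}
	var done atomic.Uint64
	var g errgroup.Group
	g.SetLimit(runtime.GOMAXPROCS(0)) // bound the number of concurrent readers
	for i := range offsets {
		off, size := offsets[i], sizes[i]
		g.Go(func() error {
			sr := io.NewSectionReader(f, off, size) // safe for concurrent use
			buf := make([]byte, 128*1024)           // 128KiB cap per goroutine
			var read int64
			for read < size {
				n, err := io.ReadFull(sr, buf[:min(int64(len(buf)), size-read)])
				if err != nil {
					return err
				}
				// real code: ggml_backend_tensor_set(tensor, &buf[0], read, n)
				read += int64(n)
				progress(float32(done.Add(uint64(n))) / float32(total))
			}
			return nil
		})
	}
	return g.Wait()
}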
@@ -371,7 +398,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
C.int(len(schedBackends)),
C.size_t(maxGraphNodes),
true,
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
),
input: deviceBufferTypes[input.d],
output: deviceBufferTypes[output.d],

View File

@@ -1,5 +1,7 @@
package input
import "github.com/ollama/ollama/ml"
// Input represents one token in the input stream
type Input struct {
// Token is a single element of text.
@@ -15,6 +17,12 @@ type Input struct {
// stored in Multimodal, used for caching and comparing
// equality.
MultimodalHash uint64
// SameBatch forces the following number of tokens to be processed
// in a single batch, breaking and extending batches as needed.
// Useful for things like images that must be processed in one
// shot.
SameBatch int
}
// MultimodalIndex is a multimodal element (such as an image)
@@ -27,11 +35,24 @@ type MultimodalIndex struct {
Multimodal any
}
// Options contains the inputs for a model forward pass
type Options struct {
Inputs []int32
// Batch contains the inputs for a model forward pass
type Batch struct {
// Inputs is the input tokens, including placeholders for multimodal inputs.
Inputs ml.Tensor
// Multimodal is a set of multimodal embeddings previously created by
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
// models or for batches without multimodal elements.
Multimodal []MultimodalIndex
Positions []int32
Sequences []int
Outputs []int32
// Positions is the position for each Input, relative to its sequence. Equal
// in length to Inputs.
Positions []int32
// Sequences is the sequence for each Input. Equal in length to Inputs.
Sequences []int
// Outputs are the set of indices into Inputs for which output data should
// be returned.
Outputs []int32
}
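An illustrative text-only batch under the new layout (Inputs itself is an ml.Tensor, filled in by model.Forward as shown below):

batch := input.Batch{
	Positions: []int32{0, 1, 2}, // one position per input token
	Sequences: []int{0, 0, 0},   // all tokens belong to sequence 0
	Outputs:   []int32{2},       // only the last token's logits are needed
}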

View File

@@ -1,6 +1,7 @@
package model
import (
"context"
"errors"
"fmt"
_ "image/jpeg"
@@ -26,7 +27,7 @@ var ErrNoVisionModel = errors.New("this model is missing data required for image
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface {
Forward(ml.Context, input.Options) (ml.Tensor, error)
Forward(ml.Context, input.Batch) (ml.Tensor, error)
Backend() ml.Backend
Config() config
@@ -60,7 +61,7 @@ type MultimodalProcessor interface {
// This function is also responsible for updating MultimodalHash for any Multimodal
// that is modified to ensure that there is a unique hash value that accurately
// represents the contents.
PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
PostTokenize([]input.Input) ([]input.Input, error)
}
// Base implements the common fields and methods for all models
@@ -94,14 +95,14 @@ func Register(name string, f func(ml.Config) (Model, error)) {
}
// New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(modelPath string, params ml.BackendParams) (Model, error) {
func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
r, err := os.Open(modelPath)
if err != nil {
return nil, err
}
defer r.Close()
b, err := ml.NewBackend(r, params)
b, err := ml.NewBackend(ctx, r, params)
if err != nil {
return nil, err
}
@@ -280,24 +281,30 @@ func canNil(t reflect.Type) bool {
t.Kind() == reflect.Slice
}
func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
if len(opts.Positions) != len(opts.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
if len(batch.Positions) != len(batch.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
}
if len(opts.Positions) < 1 {
if len(batch.Positions) < 1 {
return nil, errors.New("batch size cannot be less than 1")
}
var err error
batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
if err != nil {
return nil, err
}
cache := m.Config().Cache
if cache != nil {
err := cache.StartForward(ctx, opts)
err := cache.StartForward(ctx, batch)
if err != nil {
return nil, err
}
}
t, err := m.Forward(ctx, opts)
t, err := m.Forward(ctx, batch)
if err != nil {
return nil, err
}
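Raw token IDs now travel alongside the batch and Forward builds the input tensor itself; a sketch of the new call site (token values illustrative):

logits, err := model.Forward(ctx, m, []int32{101, 102, 103}, input.Batch{
	Positions: []int32{0, 1, 2},
	Sequences: []int{0, 0, 0},
	Outputs:   []int32{2},
})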

View File

@@ -163,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {
type notTextProcessorModel struct{}
func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
panic("unimplemented")
}

View File

@@ -168,23 +168,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount {
@@ -211,8 +206,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
// final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
return hiddenState.Rows(ctx, outputs), nil
return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil
}
func init() {

View File

@@ -2,10 +2,9 @@ package gemma3
import (
"bytes"
"encoding/binary"
"hash/fnv"
"image"
"math"
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
@@ -112,36 +111,23 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return visionOutputs, nil
}
type imageToken struct {
embedding ml.Tensor
index int
}
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
fnvHash := fnv.New64a()
for _, inp := range inputs {
if inp.Multimodal == nil {
result = append(result, inp)
} else {
imageInputs := []input.Input{
{Token: 108}, // "\n\n"
{Token: 255999}, // "<start_of_image>"
}
result = append(result, imageInputs...)
// add image embeddings
inputMultimodal := inp.Multimodal.(ml.Tensor)
for i := range inputMultimodal.Dim(1) {
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
fnvHash.Write([]byte{byte(i)})
result = append(result,
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
input.Input{Token: 255999}, // "<start_of_image>"
input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
)
imageToken := imageToken{embedding: inputMultimodal, index: i}
result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
}
// add image token placeholders
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
result = append(result,
input.Input{Token: 256000}, // <end_of_image>
@@ -153,23 +139,18 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
return result, nil
}
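A quick count of the SameBatch span above, assuming the count includes the token that carries it:

// "\n\n" (1) + "<start_of_image>" (1) + the image embedding row (1)
// + zero-token placeholders (Dim(1) - 1) + "<end_of_image>" (1) = Dim(1) + 3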
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
}
func init() {

View File

@@ -171,53 +171,20 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual)
}
func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
var embedding ml.Tensor
var src, dst, length int
var except []int
for _, image := range multimodal {
imageToken := image.Multimodal.(imageToken)
imageSrc := imageToken.index
imageDst := image.Index
if embedding == nil {
embedding = imageToken.embedding
src = imageSrc
dst = imageDst
length = 1
} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
src = imageSrc
dst = imageDst
length++
} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
length++
} else {
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
embedding = imageToken.embedding
src = imageSrc
dst = imageDst
length = 1
}
except = append(except, imageDst)
}
if embedding != nil {
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
}
return except
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
// set image embeddings
var except []int
for _, image := range batch.Multimodal {
visionOutputs := image.Multimodal.(ml.Tensor)
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
for i := range visionOutputs.Dim(1) {
except = append(except, image.Index+i)
}
}
for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global)

View File

@@ -139,23 +139,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)

View File

@@ -106,17 +106,17 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return m.Projector.Forward(ctx, crossAttentionStates), nil
}
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var images []input.Input
fnvHash := fnv.New64a()
for i := range inputs {
if inputs[i].Multimodal == nil {
if len(images) > 0 {
inputs[i].Multimodal = images[0].Multimodal
inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)}
inputs[i].MultimodalHash = images[0].MultimodalHash
for j := 1; j < len(images); j++ {
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[j].Multimodal.(ml.Tensor))
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
@@ -135,29 +135,27 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
return inputs, nil
}
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor
if len(opts.Multimodal) > 0 {
crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
if len(batch.Multimodal) > 0 {
images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
if len(images) > 0 {
crossAttentionStates = images[len(images)-1]
}
}
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
// TODO: attention mask, cross attention mask
return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
}
func init() {

View File

@@ -32,6 +32,7 @@ type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
Vocab() *Vocabulary
}
type Vocabulary struct {

View File

@@ -53,6 +53,10 @@ func (spm SentencePieceModel) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm SentencePieceModel) Vocab() *Vocabulary {
return spm.vocab
}
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
return func(yield func(string) bool) {
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {

View File

@@ -24,6 +24,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/runner/common"
)
@@ -99,7 +100,7 @@ type NewSequenceParams struct {
embedding bool
}
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait()
startTime := time.Now()
@@ -163,7 +164,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// generating image embeddings for each image
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) {
var inputs []input
var parts []string
var matches [][]string
@@ -229,7 +230,7 @@ type Server struct {
image *ImageContext
// status for external health reporting - loading, ready to serve, etc.
status ServerStatus
status llm.ServerStatus
// current progress on loading the model
progress float32
@@ -541,75 +542,18 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
return nil
}
// TODO (jmorganca): use structs from the api package to avoid duplication
// this way the api acts as a proxy instead of using a different api for the
// runner
type Options struct {
api.Runner
NumKeep int `json:"n_keep"`
Seed int `json:"seed"`
NumPredict int `json:"n_predict"`
TopK int `json:"top_k"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TypicalP float32 `json:"typical_p"`
RepeatLastN int `json:"repeat_last_n"`
Temperature float32 `json:"temperature"`
RepeatPenalty float32 `json:"repeat_penalty"`
PresencePenalty float32 `json:"presence_penalty"`
FrequencyPenalty float32 `json:"frequency_penalty"`
Mirostat int `json:"mirostat"`
MirostatTau float32 `json:"mirostat_tau"`
MirostatEta float32 `json:"mirostat_eta"`
Stop []string `json:"stop"`
}
type ImageData struct {
Data []byte `json:"data"`
ID int `json:"id"`
AspectRatioID int `json:"aspect_ratio_id"`
}
type CompletionRequest struct {
Prompt string `json:"prompt"`
Images []ImageData `json:"image_data"`
Grammar string `json:"grammar"`
CachePrompt bool `json:"cache_prompt"`
Options
}
type Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
}
type CompletionResponse struct {
Content string `json:"content"`
Stop bool `json:"stop"`
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`
StoppedLimit bool `json:"stopped_limit,omitempty"`
PredictedN int `json:"predicted_n,omitempty"`
PredictedMS float64 `json:"predicted_ms,omitempty"`
PromptN int `json:"prompt_n,omitempty"`
PromptMS float64 `json:"prompt_ms,omitempty"`
Timings Timings `json:"timings"`
}
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var req CompletionRequest
req.Options = Options(api.DefaultOptions())
var req llm.CompletionRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
}
// Set the headers to indicate streaming
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Transfer-Encoding", "chunked")
@@ -620,26 +564,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
var samplingParams llama.SamplingParams
samplingParams.TopK = req.TopK
samplingParams.TopP = req.TopP
samplingParams.MinP = req.MinP
samplingParams.TypicalP = req.TypicalP
samplingParams.Temp = req.Temperature
samplingParams.RepeatLastN = req.RepeatLastN
samplingParams.PenaltyRepeat = req.RepeatPenalty
samplingParams.PenaltyFreq = req.FrequencyPenalty
samplingParams.PenaltyPresent = req.PresencePenalty
samplingParams.Mirostat = req.Mirostat
samplingParams.MirostatTau = req.MirostatTau
samplingParams.MirostatEta = req.MirostatEta
samplingParams.Seed = uint32(req.Seed)
samplingParams.Grammar = req.Grammar
// Extract options from the CompletionRequest
samplingParams := llama.SamplingParams{
TopK: req.Options.TopK,
TopP: req.Options.TopP,
MinP: req.Options.MinP,
TypicalP: req.Options.TypicalP,
Temp: req.Options.Temperature,
RepeatLastN: req.Options.RepeatLastN,
PenaltyRepeat: req.Options.RepeatPenalty,
PenaltyFreq: req.Options.FrequencyPenalty,
PenaltyPresent: req.Options.PresencePenalty,
Mirostat: req.Options.Mirostat,
MirostatTau: req.Options.MirostatTau,
MirostatEta: req.Options.MirostatEta,
Seed: uint32(req.Options.Seed),
Grammar: req.Grammar,
}
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict,
stop: req.Stop,
numKeep: req.NumKeep,
numPredict: req.Options.NumPredict,
stop: req.Options.Stop,
numKeep: req.Options.NumKeep,
samplingParams: &samplingParams,
embedding: false,
})
@@ -662,7 +608,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -691,7 +637,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&CompletionResponse{
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -702,15 +648,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
} else {
// Send the final response
if err := json.NewEncoder(w).Encode(&CompletionResponse{
Stop: true,
StoppedLimit: seq.doneReason == "limit",
Timings: Timings{
PromptN: seq.numPromptInputs,
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
PredictedN: seq.numDecoded,
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
},
doneReason := "stop"
if seq.doneReason == "limit" {
doneReason = "length"
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Done: true,
DoneReason: doneReason,
PromptEvalCount: seq.numPromptInputs,
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
EvalCount: seq.numDecoded,
EvalDuration: time.Since(seq.startGenerationTime),
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
}
@@ -721,17 +669,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
}
type EmbeddingRequest struct {
Content string `json:"content"`
CachePrompt bool `json:"cache_prompt"`
}
type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"`
}
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
var req EmbeddingRequest
var req llm.EmbeddingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
return
@@ -761,7 +700,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -782,41 +721,17 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
embedding := <-seq.embedding
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
Embedding: embedding,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
}
type HealthResponse struct {
Status string `json:"status"`
Progress float32 `json:"progress"`
}
type ServerStatus int
const (
ServerStatusReady ServerStatus = iota
ServerStatusLoadingModel
ServerStatusError
)
func (s ServerStatus) ToString() string {
switch s {
case ServerStatusReady:
return "ok"
case ServerStatusLoadingModel:
return "loading model"
default:
return "server error"
}
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&HealthResponse{
Status: s.status.ToString(),
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
Status: s.status,
Progress: s.progress,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -879,7 +794,7 @@ func (s *Server) loadModel(
panic(err)
}
s.status = ServerStatusReady
s.status = llm.ServerStatusReady
s.ready.Done()
}
@@ -937,7 +852,7 @@ func Execute(args []string) error {
parallel: *parallel,
seqs: make([]*Sequence, *parallel),
seqsSem: semaphore.NewWeighted(int64(*parallel)),
status: ServerStatusLoadingModel,
status: llm.ServerStatusLoadingModel,
}
var tensorSplitFloats []float32


@@ -31,8 +31,10 @@ type InputCache struct {
cache kvcache.Cache
}
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) {
if kvSize/int32(numSlots) < 1 {
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
numCtx := kvSize / int32(numSlots)
if numCtx < 1 {
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
}
@@ -44,11 +46,11 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
cache := model.Config().Cache
if cache != nil {
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize)
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize)
}
return &InputCache{
numCtx: kvSize / int32(numSlots),
numCtx: numCtx,
enabled: cache != nil,
slots: slots,
multiUserCache: multiUserCache,
@@ -89,7 +91,7 @@ type InputCacheSlot struct {
lastUsed time.Time
}
func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
var slot *InputCacheSlot
var numPast int32
var err error
@@ -107,10 +109,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
return nil, nil, err
}
if !cachePrompt {
numPast = 0
}
slot.InUse = true
slot.lastUsed = time.Now()


@@ -297,3 +297,131 @@ func TestShiftDiscard(t *testing.T) {
})
}
}
func TestLoadCacheSlot(t *testing.T) {
tests := []struct {
name string
cache InputCache
prompt []input.Input
wantErr bool
expectedSlotId int
expectedPrompt int // expected length of remaining prompt
}{
{
name: "Basic cache hit - single user",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
},
{
name: "Basic cache hit - multi user",
cache: InputCache{
multiUserCache: true,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
},
{
name: "Exact match - leave one input",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Should leave 1 token for sampling
},
{
name: "No available slots",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: true,
expectedSlotId: -1,
expectedPrompt: -1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
// Check error state
if (err != nil) != tt.wantErr {
t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return // Skip further checks if we expected an error
}
// Verify slot ID
if slot.Id != tt.expectedSlotId {
t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
}
// Verify slot is now marked in use
if !slot.InUse {
t.Errorf("LoadCacheSlot() slot not marked InUse")
}
// Verify remaining prompt length
if len(remainingPrompt) != tt.expectedPrompt {
t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
len(remainingPrompt), tt.expectedPrompt)
}
})
}
}


@@ -24,6 +24,7 @@ import (
"golang.org/x/sync/semaphore"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@@ -33,10 +34,14 @@ import (
_ "github.com/ollama/ollama/model/models"
)
type contextList struct {
list []ml.Context
}
type Sequence struct {
// ctx for allocating tensors that last the lifetime of the sequence, such as
// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
// multimodal embeddings
ctx ml.Context
ctxs *contextList
// batch index
iBatch int
@@ -94,13 +99,12 @@ type NewSequenceParams struct {
embedding bool
}
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait()
startTime := time.Now()
ctx := s.model.Backend().NewContext()
inputs, err := s.inputs(ctx, prompt, images)
inputs, ctxs, err := s.inputs(prompt, images)
if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err)
} else if len(inputs) == 0 {
@@ -111,6 +115,9 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
params.numKeep = int32(len(inputs))
}
// TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
// otherwise we might truncate or split the batch against the model's wishes
// Ensure that at least 1 input can be discarded during shift
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
@@ -126,7 +133,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// TODO(jessegross): Ingest cached history for grammar
return &Sequence{
ctx: ctx,
ctxs: ctxs,
inputs: inputs,
numPromptInputs: len(inputs),
startProcessingTime: startTime,
@@ -145,7 +152,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) {
var inputs []input.Input
var parts []string
var matches [][]string
@@ -160,12 +167,19 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in
parts = []string{prompt}
}
var contexts contextList
runtime.AddCleanup(&contexts, func(ctxs []ml.Context) {
for _, ctx := range ctxs {
ctx.Close()
}
}, contexts.list)
postTokenize := false
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
if err != nil {
return nil, err
return nil, nil, err
}
for _, t := range tokens {
@@ -185,12 +199,14 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in
}
if imageIndex < 0 {
return nil, fmt.Errorf("invalid image index: %d", n)
return nil, nil, fmt.Errorf("invalid image index: %d", n)
}
ctx := s.model.Backend().NewContext()
contexts.list = append(contexts.list, ctx)
imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
if err != nil {
return nil, err
return nil, nil, err
}
s.multimodalHash.Reset()
@@ -204,13 +220,13 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in
if visionModel && postTokenize {
var err error
inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
inputs, err = multimodalProcessor.PostTokenize(inputs)
if err != nil {
return nil, err
return nil, nil, err
}
}
return inputs, nil
return inputs, &contexts, nil
}
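The contextList plus runtime.AddCleanup pattern above ties the lifetime of per-image contexts to the sequence object instead of a defer. A minimal, self-contained sketch of the same pattern (assuming Go 1.24+ for runtime.AddCleanup; the resource and contextList types here are illustrative, not the runner's). Note that the third argument is captured by value at the call, so the list must already hold everything the cleanup should close:

package main

import (
	"fmt"
	"runtime"
)

type resource struct{ id int }

func (r *resource) Close() { fmt.Println("closed", r.id) }

type contextList struct{ list []*resource }

func main() {
	owner := &contextList{list: []*resource{{id: 1}, {id: 2}}}
	// The slice argument is evaluated now, by value; populate it before
	// registering the cleanup so later appends are not missed.
	runtime.AddCleanup(owner, func(rs []*resource) {
		for _, r := range rs {
			r.Close()
		}
	}, owner.list)
	owner = nil
	runtime.GC() // the cleanup runs some time after owner becomes unreachable
}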
type Server struct {
@@ -222,7 +238,7 @@ type Server struct {
model model.Model
// status for external health reporting - loading, ready to serve, etc.
status ServerStatus
status llm.ServerStatus
// current progress on loading the model
progress float32
@@ -305,7 +321,6 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
close(seq.responses)
close(seq.embedding)
seq.cache.InUse = false
seq.ctx.Close()
s.seqs[seqIndex] = nil
s.seqsSem.Release(1)
}
@@ -333,7 +348,8 @@ func (s *Server) processBatch() error {
}
defer s.mu.Unlock()
var options input.Options
var batchInputs []int32
var batch input.Batch
for i, seq := range s.seqs {
if seq == nil {
@@ -351,33 +367,46 @@ func (s *Server) processBatch() error {
seq.cache.Inputs = []input.Input{}
}
batchSize := s.batchSize
for j, inp := range seq.inputs {
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
if len(seq.pendingInputs) == 0 {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
return err
}
} else {
break
}
// If we are required to put following inputs into a single batch then extend the
// batch size. Since we are only extending the size the minimum amount possible, this
// will cause a break if we have pending inputs.
minBatch := 1 + inp.SameBatch
if minBatch > batchSize {
batchSize = minBatch
}
if j >= s.batchSize {
if len(seq.pendingInputs)+minBatch > batchSize {
break
}
options.Inputs = append(options.Inputs, inp.Token)
if inp.Multimodal != nil {
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
// If the sum of our working set (already processed tokens, tokens we added to this
// batch, required following tokens) exceeds the context size, then trigger a shift
// now so we don't have to do one later when we can't break the batch.
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
if len(seq.pendingInputs) != 0 {
break
}
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
return err
}
}
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
options.Sequences = append(options.Sequences, seq.cache.Id)
batchInputs = append(batchInputs, inp.Token)
if inp.Multimodal != nil {
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
}
seq.iBatch = len(options.Outputs)
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
batch.Sequences = append(batch.Sequences, seq.cache.Id)
seq.iBatch = len(batch.Outputs)
if j+1 == len(seq.inputs) {
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
}
seq.pendingInputs = append(seq.pendingInputs, inp)
}
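The admission rule in the loop above can be read in isolation: an input that must travel with its SameBatch successors may grow the batch by the minimum amount needed, but only when the slot has nothing pending yet. A hedged sketch of just that decision (illustrative names, not the runner's API):

// fits reports whether an input needing minBatch = 1 + sameBatch slots can be
// admitted, possibly extending batchSize by the minimum amount required.
func fits(pending, sameBatch, batchSize int) (int, bool) {
	minBatch := 1 + sameBatch
	if minBatch > batchSize {
		batchSize = minBatch
	}
	return batchSize, pending+minBatch <= batchSize
}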
@@ -385,14 +414,14 @@ func (s *Server) processBatch() error {
seq.inputs = seq.inputs[len(seq.pendingInputs):]
}
if len(options.Inputs) == 0 {
if len(batchInputs) == 0 {
return nil
}
ctx := s.model.Backend().NewContext()
defer ctx.Close()
modelOutput, err := model.Forward(ctx, s.model, options)
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
if err != nil {
return fmt.Errorf("failed to decode batch: %w", err)
}
@@ -432,13 +461,27 @@ func (s *Server) processBatch() error {
}
// sample a token
vocabSize := len(logits) / len(options.Outputs)
vocabSize := len(logits) / len(batch.Outputs)
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
if err != nil {
return fmt.Errorf("failed to sample token: %w", err)
}
if seq.sampler.JSONSampler != nil {
_, err = seq.sampler.JSONSampler.UpdateState([]int32{token})
if err != nil {
return fmt.Errorf("failed to update state: %w", err)
}
}
if seq.sampler.PythonSampler != nil {
err = seq.sampler.PythonSampler.UpdateState(token)
if err != nil {
return fmt.Errorf("failed to update state: %w", err)
}
}
// if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
// TODO (jmorganca): we should send this back
@@ -501,75 +544,18 @@ func (s *Server) processBatch() error {
return nil
}
// TODO (jmorganca): use structs from the api package to avoid duplication
// this way the api acts as a proxy instead of using a different api for the
// runner
type Options struct {
api.Runner
NumKeep int `json:"n_keep"`
Seed int `json:"seed"`
NumPredict int `json:"n_predict"`
TopK int `json:"top_k"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TypicalP float32 `json:"typical_p"`
RepeatLastN int `json:"repeat_last_n"`
Temperature float32 `json:"temperature"`
RepeatPenalty float32 `json:"repeat_penalty"`
PresencePenalty float32 `json:"presence_penalty"`
FrequencyPenalty float32 `json:"frequency_penalty"`
Mirostat int `json:"mirostat"`
MirostatTau float32 `json:"mirostat_tau"`
MirostatEta float32 `json:"mirostat_eta"`
Stop []string `json:"stop"`
}
type ImageData struct {
Data []byte `json:"data"`
ID int `json:"id"`
AspectRatioID int `json:"aspect_ratio_id"`
}
type CompletionRequest struct {
Prompt string `json:"prompt"`
Images []ImageData `json:"image_data"`
Grammar string `json:"grammar"`
CachePrompt bool `json:"cache_prompt"`
Options
}
type Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
}
type CompletionResponse struct {
Content string `json:"content"`
Stop bool `json:"stop"`
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`
StoppedLimit bool `json:"stopped_limit,omitempty"`
PredictedN int `json:"predicted_n,omitempty"`
PredictedMS float64 `json:"predicted_ms,omitempty"`
PromptN int `json:"prompt_n,omitempty"`
PromptMS float64 `json:"prompt_ms,omitempty"`
Timings Timings `json:"timings"`
}
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var req CompletionRequest
req.Options = Options(api.DefaultOptions())
var req llm.CompletionRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
}
// Set the headers to indicate streaming
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Transfer-Encoding", "chunked")
@@ -590,19 +576,37 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
}
// jsonSampler, err := sample.NewJSONSampler(s.model.(model.TextProcessor), nil)
// if err != nil {
// http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
// return
// }
// jsonSampler = nil
// prototyping: a hardcoded Python function signature to constrain sampling against
pythonSampler := &sample.PythonSampler{}
functions := []sample.PythonFunction{
{
Name: "add_two_strings",
Args: []string{"s1", "s2"},
Types: []string{"string", "string"},
},
}
pythonSampler.Init(functions, s.model.(model.TextProcessor))
sampler := sample.NewSampler(
req.Temperature,
req.TopK,
req.TopP,
req.MinP,
req.Seed,
req.Options.Temperature,
req.Options.TopK,
req.Options.TopP,
req.Options.MinP,
req.Options.Seed,
grammar,
nil,
pythonSampler,
// nil,
)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict,
stop: req.Stop,
numKeep: int32(req.NumKeep),
numPredict: req.Options.NumPredict,
stop: req.Options.Stop,
numKeep: int32(req.Options.NumKeep),
sampler: sampler,
embedding: false,
})
@@ -625,7 +629,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -652,7 +656,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&CompletionResponse{
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -663,15 +667,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
} else {
// Send the final response
if err := json.NewEncoder(w).Encode(&CompletionResponse{
Stop: true,
StoppedLimit: seq.doneReason == "limit",
Timings: Timings{
PromptN: seq.numPromptInputs,
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
PredictedN: seq.numPredicted,
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
},
doneReason := "stop"
if seq.doneReason == "limit" {
doneReason = "length"
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Done: true,
DoneReason: doneReason,
PromptEvalCount: seq.numPromptInputs,
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
EvalCount: seq.numPredicted,
EvalDuration: time.Since(seq.startGenerationTime),
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
}
@@ -682,43 +688,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
}
type EmbeddingRequest struct {
Content string `json:"content"`
CachePrompt bool `json:"cache_prompt"`
}
type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"`
}
type HealthResponse struct {
Status string `json:"status"`
Progress float32 `json:"progress"`
}
type ServerStatus int
const (
ServerStatusReady ServerStatus = iota
ServerStatusLoadingModel
ServerStatusError
)
func (s ServerStatus) ToString() string {
switch s {
case ServerStatusReady:
return "ok"
case ServerStatusLoadingModel:
return "loading model"
default:
return "server error"
}
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&HealthResponse{
Status: s.status.ToString(),
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
Status: s.status,
Progress: s.progress,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -737,6 +710,7 @@ func (m *multiLPath) String() string {
}
func (s *Server) loadModel(
ctx context.Context,
mpath string,
params ml.BackendParams,
lpath multiLPath,
@@ -746,7 +720,7 @@ func (s *Server) loadModel(
multiUserCache bool,
) {
var err error
s.model, err = model.New(mpath, params)
s.model, err = model.New(ctx, mpath, params)
if err != nil {
panic(err)
}
@@ -758,7 +732,7 @@ func (s *Server) loadModel(
panic("loras are not yet implemented")
}
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache)
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
if err != nil {
panic(err)
}
@@ -772,7 +746,7 @@ func (s *Server) loadModel(
s.seqs = make([]*Sequence, s.parallel)
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
s.status = ServerStatusReady
s.status = llm.ServerStatusReady
s.ready.Done()
}
@@ -824,7 +798,7 @@ func Execute(args []string) error {
server := &Server{
batchSize: *batchSize,
status: ServerStatusLoadingModel,
status: llm.ServerStatusLoadingModel,
}
// TODO(jessegross): Parameters that need to be implemented:
@@ -842,6 +816,9 @@ func Execute(args []string) error {
}
params := ml.BackendParams{
Progress: func(progress float32) {
server.progress = progress
},
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,
@@ -850,13 +827,13 @@ func Execute(args []string) error {
}
server.ready.Add(1)
go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
server.cond = sync.NewCond(&server.mu)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
server.cond = sync.NewCond(&server.mu)
go server.run(ctx)
addr := "127.0.0.1:" + strconv.Itoa(*port)

sample/gtf.go (new file, 53 lines)

@@ -0,0 +1,53 @@
package sample
var DefaultGrammar = map[string]string{
"unicode": `\x{hex}{2} | \u{hex}{4} | \U{hex}{8}`,
"null": `"null"`,
"object": `"{" (kv ("," kv)*)? "}"`,
"array": `"[" (value ("," value)*)? "]"`,
"kv": `string ":" value`,
"integer": `"0" | [1-9] [0-9]*`,
"number": `"-"? integer frac? exp?`,
"frac": `"." [0-9]+`,
"exp": `("e" | "E") ("+" | "-") [0-9]+`,
"string": `"\"" char* "\""`,
"escape": `["/" | "b" | "f" | "n" | "r" | "t" | unicode]`,
"char": `[^"\\] | escape`,
"space": `(" " | "\t" | "\n" | "\r")*`,
"hex": `[0-9] | [a-f] | [A-F]`,
"boolean": `"true" | "false"`,
"value": `object | array | string | number | boolean | "null"`,
}
const jsonString = `object | array`
type StateMachine struct {
states map[rune]State
}
type State struct {
NextStates []string
// bitmask?
Mask []bool
IsTerminal bool
}
func NewStateMachine(grammar map[string]string, startRule string) *StateMachine {
states := make(map[rune]State)
// WIP: scan the start rule and accumulate the characters of quoted literals;
// neither the accumulated string nor the state map is wired into transitions yet.
var cumu string
flag := false
for _, r := range startRule {
if r == '"' {
flag = !flag
}
if flag {
cumu += string(r)
}
}
sm := &StateMachine{
states: states,
}
return sm
}
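The quote-toggling scan above accumulates one flat string; a self-contained variant that collects each quoted terminal separately (a sketch, not part of the package) looks like:

func quotedLiterals(rule string) []string {
	var out []string
	var cur []rune
	inQuote := false
	for _, r := range rule {
		if r == '"' {
			if inQuote {
				out = append(out, string(cur))
				cur = cur[:0]
			}
			inQuote = !inQuote
			continue
		}
		if inQuote {
			cur = append(cur, r)
		}
	}
	return out
}

Applied to DefaultGrammar["object"], quotedLiterals would return the rule's terminals: "{", ",", and "}".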

sample/gtf_test.go (new file, 138 lines)

@@ -0,0 +1,138 @@
package sample
import (
"testing"
)
func TestGrammarParsing(t *testing.T) {
tests := []struct {
name string
grammar map[string]string
startRule string
input string
want bool
}{
{
name: "simple object",
grammar: map[string]string{
"object": `"{" "}"`,
},
startRule: "object",
input: "{}",
want: true,
},
{
name: "simple array",
grammar: map[string]string{
"array": `"[" "]"`,
},
startRule: "array",
input: "[]",
want: true,
},
{
name: "character class",
grammar: map[string]string{
"digit": `[0-9]`,
},
startRule: "digit",
input: "5",
want: true,
},
{
name: "alternation",
grammar: map[string]string{
"bool": `"true" | "false"`,
},
startRule: "bool",
input: "true",
want: true,
},
{
name: "repetition",
grammar: map[string]string{
"digits": `[0-9]+`,
},
startRule: "digits",
input: "123",
want: true,
},
{
name: "nested rules",
grammar: map[string]string{
"value": `object | array`,
"object": `"{" "}"`,
"array": `"[" "]"`,
},
startRule: "value",
input: "{}",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := NewParser(tt.grammar)
machine, err := parser.Parse(tt.startRule)
if err != nil {
t.Fatalf("Parse() error = %v", err)
}
matcher := NewMatcher(machine)
got, err := matcher.Match(tt.input)
if err != nil {
t.Fatalf("Match() error = %v", err)
}
if got != tt.want {
t.Errorf("Match() = %v, want %v", got, tt.want)
}
})
}
}
func TestJSONGrammar(t *testing.T) {
tests := []struct {
name string
input string
want bool
}{
{"empty object", "{}", true},
{"empty array", "[]", true},
{"simple string", `"hello"`, true},
{"simple number", "123", true},
{"simple boolean", "true", true},
{"simple null", "null", true},
{"object with string", `{"key": "value"}`, true},
{"array with numbers", "[1, 2, 3]", true},
{"nested object", `{"obj": {"key": "value"}}`, true},
{"nested array", `[1, [2, 3], 4]`, true},
{"invalid object", "{", false},
{"invalid array", "[1, 2", false},
{"invalid string", `"hello`, false},
}
parser := NewParser(DefaultGrammar)
machine, err := parser.Parse("value")
if err != nil {
t.Fatalf("Parse() error = %v", err)
}
matcher := NewMatcher(machine)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := matcher.Match(tt.input)
if tt.want {
if err != nil {
t.Errorf("Match() error = %v", err)
}
if !got {
t.Errorf("Match() = false, want true")
}
} else {
if err == nil && got {
t.Errorf("Match() = true, want false")
}
}
})
}
}

sample/json_types.go (new file, 160 lines)

@@ -0,0 +1,160 @@
package sample
import (
"fmt"
)
type JSONState int
const (
StateStart JSONState = iota
StateInObject
StateInObjectKey
StateInStructuredKey
StateInStructuredValue
StateNewline
StateTab
StateSpace
StateInString
StateInInt
StateInFloat
StateInBool
StateInNull
StateInColon
StateInComma
StateInTab
StateInSpaceToValue
StateInSpaceEndValue
StateInNewlineEndValue
StateInObjSpace
StateInList
StateInListComma
StateInValue
StateInValueEnd
StateInListEnd
StateInListObjectEnd
StateInNewline
StateInNumber
StateInNumberEnd
StateInStringEnd
StateInObjectKeyEnd
StateTerminate
StateInObjectEnd
StateTransitioningToTerminate
StateInListStartJSON
)
var JSONStates = []JSONState{
StateStart,
StateInObject,
StateInObjectKey,
StateInStructuredKey,
StateInStructuredValue,
StateNewline,
StateTab,
StateSpace,
StateInString,
StateInInt,
StateInFloat,
StateInBool,
StateInNull,
StateInColon,
StateInComma,
StateInTab,
StateInSpaceToValue,
StateInSpaceEndValue,
StateInNewlineEndValue,
StateInObjSpace,
StateInListStartJSON,
StateInList,
StateInListComma,
StateInValue,
StateInValueEnd,
StateInListEnd,
StateInListObjectEnd,
StateInNewline,
StateInNumber,
StateInNumberEnd,
StateInStringEnd,
StateInObjectKeyEnd,
StateTerminate,
StateInObjectEnd,
StateTransitioningToTerminate,
}
func (s JSONState) String() string {
switch s {
case StateStart:
return "StateStart"
case StateInObject:
return "StateInObject"
case StateInObjectKey:
return "StateInObjectKey"
case StateInStructuredKey:
return "StateInStructuredKey"
case StateInStructuredValue:
return "StateInStructuredValue"
case StateNewline:
return "StateNewline"
case StateTab:
return "StateTab"
case StateSpace:
return "StateSpace"
case StateInString:
return "StateInString"
case StateInInt:
return "StateInInt"
case StateInFloat:
return "StateInFloat"
case StateInBool:
return "StateInBool"
case StateInNull:
return "StateInNull"
case StateInColon:
return "StateInColon"
case StateInComma:
return "StateInComma"
case StateInTab:
return "StateInTab"
case StateInSpaceToValue:
return "StateInSpaceToValue"
case StateInSpaceEndValue:
return "StateInSpaceEndValue"
case StateInNewlineEndValue:
return "StateInNewlineEndValue"
case StateInObjSpace:
return "StateInObjSpace"
case StateInList:
return "StateInList"
case StateInListComma:
return "StateInListComma"
case StateInValue:
return "StateInValue"
case StateInValueEnd:
return "StateInValueEnd"
case StateInListEnd:
return "StateInListEnd"
case StateInListObjectEnd:
return "StateInListObjectEnd"
case StateInNewline:
return "StateInNewline"
case StateInNumber:
return "StateInNumber"
case StateInNumberEnd:
return "StateInNumberEnd"
case StateInStringEnd:
return "StateInStringEnd"
case StateInObjectKeyEnd:
return "StateInObjectKeyEnd"
case StateTerminate:
return "StateTerminate"
case StateInObjectEnd:
return "StateInObjectEnd"
case StateTransitioningToTerminate:
return "StateTransitioningToTerminate"
case StateInListStartJSON:
return "StateInListStartJSON"
default:
return fmt.Sprintf("Unknown state: %d", s)
}
}
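Because JSONState implements fmt.Stringer, the fmt.Println debugging sprinkled through this package prints symbolic state names rather than raw integers:

fmt.Println(StateInObjectKey) // prints "StateInObjectKey", not "2"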

sample/pushdown_automata.go (new file, 327 lines)

@@ -0,0 +1,327 @@
package sample
import (
"fmt"
"slices"
"github.com/ollama/ollama/model"
)
/*
Key JSON rules to consider:
1. Whitespace handling:
- Need to handle all valid JSON whitespace characters (\r, spaces between tokens)
- Current code only handles some whitespace cases
2. Number validation:
- Need proper validation for special number cases like -0
- Should handle .5 style decimals
- Need limits on scientific notation (e, E)
3. String escaping:
- Currently marks \ as invalid but should allow escaped sequences:
- \"
- \n
- \u1234 unicode escapes
4. Empty object/array transitions:
- Direct {} and [] cases could be more explicit
- Need clear transitions for these edge cases
5. Nested depth limits:
- No protection against excessive nesting
- Could cause stack overflow with deeply nested structures
*/
// TODO: '/' should be valid in strings, though it can also appear as an escape sequence
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
var (
intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
)
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
type PDA struct {
State JSONState
TransitionEdges map[rune]*PDA
MaskTokenIDToNode map[int32]*PDA
}
func NewPDANode(state JSONState) *PDA {
return &PDA{
State: state,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
}
type PDAGraphBuilder struct {
proc model.TextProcessor
decodedToks []string
stateToNodeMap map[JSONState]*PDA
tokenToStatesMap map[int32][]JSONState
}
func (b *PDAGraphBuilder) BuildGraph() error {
stateToNodeMap := make(map[JSONState]*PDA)
for _, state := range JSONStates {
stateToNodeMap[state] = NewPDANode(state)
}
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
// TODO: update naming here - and revisit values
stateToNodeMap[StateInListStartJSON].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInListStartJSON].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
stateToNodeMap[StateInObject].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
// new line
stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
// stateToNodeMap[StateInNewline].TransitionEdges['{'] = stateToNodeMap[StateInObject]
// new line end value
// stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
// TODO: see if this is needed for formatting
stateToNodeMap[StateInObjSpace].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInTab].TransitionEdges['\t'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInObjectKey].TransitionEdges[rune(-1)] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObjectKey].TransitionEdges['"'] = stateToNodeMap[StateInObjectKeyEnd]
stateToNodeMap[StateInObjectKeyEnd].TransitionEdges[':'] = stateToNodeMap[StateInColon]
stateToNodeMap[StateInObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInObjectEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
// where values should be
// this could be combined, but the problem might change; we're already doing a skip ahead
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInColon].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
// Leads to a value
stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap)
stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInSpaceToValue].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
// Values
// string node
stateToNodeMap[StateInString].TransitionEdges[rune(-1)] = stateToNodeMap[StateInString]
stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
// String end node
addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
// stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// TODO: add counters for allowable number of decimals, e, E, etc
// number node
for _, r := range validNumberRunes {
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
}
addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
// stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// list node
stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
// early end
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
// list end node
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
// stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// empty list
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
// null node
for _, r := range validNullRunes {
stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
}
addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// list comma
// should point to values
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInList] // overwrites the ' ' edge set above
stateToNodeMap[StateInListComma].TransitionEdges['\t'] = stateToNodeMap[StateInList]
addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
// list object end
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
// TODO: not sure if this is needed
stateToNodeMap[StateInListObjectEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// bool node
for _, r := range validBoolRunes {
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
}
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
// stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] // overwrites the '\n' edge set above
// comma node
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
// todo: review this space transition
// stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
// space end value
// stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
b.stateToNodeMap = stateToNodeMap
if err := b.preComputeValidStates(); err != nil {
return err
}
return nil
}
func addEnds(node *PDA, stateToNodeMap map[JSONState]*PDA) {
node.TransitionEdges[','] = stateToNodeMap[StateInComma]
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
}
func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) {
node.TransitionEdges['"'] = stateToNodeMap[StateInString]
for _, r := range validNumberRunes {
node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
}
// TODO(parthsareen): force the output and shift similar to structured outputs
node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
}
func (b *PDAGraphBuilder) preComputeValidStates() error {
for _, node := range b.stateToNodeMap {
// if node.State == StateInObjectKey {
// if len(b.stateToNodeMap[StateInString].MaskTokenIDToNode) > 0 {
// b.stateToNodeMap[StateInObjectKey].MaskTokenIDToNode = b.stateToNodeMap[StateInString].MaskTokenIDToNode
// fmt.Println("copying string mask to object key mask")
// }
// }
if err := b.CreateMask(node); err != nil {
return err
}
}
return nil
}
func (b *PDAGraphBuilder) preComputeTokenToStatesMap() error {
// TODO: the make() call can live somewhere else too
b.tokenToStatesMap = make(map[int32][]JSONState)
for i, t := range b.decodedToks {
for _, r := range t {
if r == '"' {
b.tokenToStatesMap[int32(i)] = append(b.tokenToStatesMap[int32(i)], StateInString)
}
}
}
return nil
}
// TODO: the mask for obj key and string should be the same?
func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
if node == nil {
return fmt.Errorf("node cannot be nil")
}
for i := range b.decodedToks {
token := b.decodedToks[i]
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
if b.proc.Is(int32(i), model.SpecialEOS) || b.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
continue
}
curNode := node
valid := true
consumedSpecialRunes := make(map[rune]bool)
for _, r := range token {
curNode, valid = isRuneValid(r, curNode, consumedSpecialRunes)
if curNode == nil || !valid {
break
}
}
if valid {
node.MaskTokenIDToNode[int32(i)] = curNode
}
}
return nil
}
func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA, bool) {
if consumedSpecialRunes[r] {
return nil, false
}
specialRune := slices.Contains(stringInvalidRunes, r)
if specialRune {
if curNode.State == StateInString || curNode.State == StateInObjectKey {
return nil, false
}
}
// Check for specific rune transition
if nextNode, ok := curNode.TransitionEdges[r]; ok {
// fmt.Println("next node", nextNode)
if specialRune {
if curNode.State == nextNode.State {
return nil, false
}
consumedSpecialRunes[r] = true
}
return nextNode, true
}
// Check for sentinel value - if present, any rune is valid
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
return nextNode, true
}
return nil, false
}
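Point 3 of the comment block at the top of this file notes that escape sequences are currently rejected outright. A hedged sketch of what accepting them could look like (standalone helpers, not the package's API):

// validEscape checks a JSON escape starting at s[i] == '\\' and returns the
// index just past the sequence.
func validEscape(s []rune, i int) (int, bool) {
	if i+1 >= len(s) {
		return 0, false
	}
	switch s[i+1] {
	case '"', '\\', '/', 'b', 'f', 'n', 'r', 't':
		return i + 2, true
	case 'u': // \uXXXX: exactly four hex digits
		if i+5 >= len(s) {
			return 0, false
		}
		for j := i + 2; j <= i+5; j++ {
			if !isHexDigit(s[j]) {
				return 0, false
			}
		}
		return i + 6, true
	}
	return 0, false
}

func isHexDigit(r rune) bool {
	return (r >= '0' && r <= '9') || (r >= 'a' && r <= 'f') || (r >= 'A' && r <= 'F')
}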

sample/pushdown_runner.go (new file, 264 lines)

@@ -0,0 +1,264 @@
package sample
import (
"fmt"
"math"
"runtime"
"time"
"github.com/ollama/ollama/model"
)
// TODO: safety in case of invalid json
// TODO: partial JSON matching?
// TODO: interfaces to cleanup with return values
// TODO: this interface shouldn't be the sampler - should just use Sampler
// TODO: add penalties for string \n stuff
// TODO: minimize number of fwd passes if there is only one match
// TODO: greedy sample initially and then backtrack if no match
type PushdownSampler struct {
PDAGraphBuilder
curNode *PDA
braceStack []rune
stateCounter uint32
}
// graph should be built once and reused per tokenizer
func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) {
start := time.Now()
fmt.Println("--------------------------------")
fmt.Println("PDA sampler")
fmt.Println("--------------------------------")
var m runtime.MemStats
runtime.ReadMemStats(&m)
before := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
vocab := proc.Vocab()
decodedToks := make([]string, len(vocab.Values))
for i := range vocab.Values {
token, err := proc.Decode([]int32{int32(i)})
if err != nil {
return nil, err
}
decodedToks[i] = token
}
gb := &PDAGraphBuilder{
proc: proc,
decodedToks: decodedToks,
}
if err := gb.BuildGraph(); err != nil {
return nil, err
}
runtime.ReadMemStats(&m)
after := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Graph build time = %v\n", time.Since(start))
// TODO: this can be simplified
return &PushdownSampler{
curNode: gb.stateToNodeMap[StateStart],
PDAGraphBuilder: *gb,
braceStack: []rune{},
stateCounter: 0,
}, nil
}
// TODO: need to add resampling logic if the first sample was not good
// greedy sample + backtrack?
func (s *PushdownSampler) Apply(logits []float32) ([]float32, error) {
switch s.curNode.State {
case StateInString:
return s.maskLogits(logits, s.curNode)
case StateInListEnd:
// force finish if no braces left
if len(s.braceStack) == 0 {
s.curNode = NewPDANode(StateTerminate)
return forceFinish(s, logits)
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
case StateTerminate:
return forceFinish(s, logits)
case StateInObjectEnd:
// force finish if no braces left
if len(s.braceStack) == 0 {
s.curNode = NewPDANode(StateTerminate)
return forceFinish(s, logits)
}
peek := s.braceStack[len(s.braceStack)-1]
if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListObjectEnd]
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
case StateInComma:
peek := s.braceStack[len(s.braceStack)-1]
if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListComma]
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
default:
fmt.Println("masking logits current state", s.curNode.State)
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
}
}
func forceFinish(s *PushdownSampler, logits []float32) ([]float32, error) {
for i := range logits {
if s.proc.Is(int32(i), model.SpecialEOS) {
logits[i] = 1.0
} else {
logits[i] = float32(math.Inf(-1))
}
}
return logits, nil
}
func (s *PushdownSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
fmt.Println("current state - updating", s.curNode.State)
mappedString, err := s.proc.Decode(tokenSlice)
if err != nil {
return nil, err
}
fmt.Printf(">>> mappedString: %q\n", mappedString)
// Special handling for EOS token in terminate state
if s.curNode.State == StateTerminate {
for _, tokenID := range tokenSlice {
if s.proc.Is(tokenID, model.SpecialEOS) {
return tokenSlice, nil
}
}
}
// flag := -1
// endBraceRunes := []rune{'}', ']'}
for _, r := range mappedString {
// TODO: if this is enabled again, make sure to appropriately handle the state transitions
// if slices.Contains(endBraceRunes, r) && len(s.braceStack) == 0 {
// fmt.Printf("stack is empty, extra closing brace %c\n", r)
// // flag = i
// break
// }
if r == rune('{') {
s.braceStack = append(s.braceStack, r)
}
if r == rune('[') {
s.braceStack = append(s.braceStack, r)
}
if r == rune('}') {
if len(s.braceStack) == 0 {
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
}
top := s.braceStack[len(s.braceStack)-1]
if top != rune('{') {
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
}
s.braceStack = s.braceStack[:len(s.braceStack)-1]
}
if r == rune(']') {
if len(s.braceStack) == 0 {
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
}
top := s.braceStack[len(s.braceStack)-1]
if top != rune('[') {
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
}
s.braceStack = s.braceStack[:len(s.braceStack)-1]
}
}
// if flag != -1 {
// tokenSlice = tokenSlice[:flag]
// }
// fmt.Println("flag!", flag)
for _, tokenID := range tokenSlice {
// transition to the next node
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
if !ok {
return nil, fmt.Errorf("invalid token: %q", mappedString)
}
fmt.Println("transitioning to", nextNode.State)
// TODO: add a penalty for staying in the same state too long
if nextNode.State == s.curNode.State {
s.stateCounter++
} else {
s.stateCounter = 0
}
s.curNode = nextNode
fmt.Println("updated curNode state", s.curNode.State)
}
return tokenSlice, nil
}
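Apply and UpdateState are meant to bracket each decode step: mask, sample, then advance the automaton with whatever was chosen. A hedged sketch of that wiring, where pd is a *PushdownSampler and nextLogits, pick, and isEOS are illustrative placeholders rather than real APIs:

for {
	logits := nextLogits()          // hypothetical forward pass
	masked, err := pd.Apply(logits) // constrain to grammar-legal tokens
	if err != nil {
		return err
	}
	tok := pick(masked) // any sampling strategy over the masked logits
	if _, err := pd.UpdateState([]int32{tok}); err != nil {
		return err
	}
	if isEOS(tok) { // hypothetical end-of-sequence check; see forceFinish above
		break
	}
}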
// greedy sample + backtrack?
func (s *PushdownSampler) maskLogits(logits []float32, node *PDA) ([]float32, error) {
// Create a new slice with same length as logits, initialized to -Inf
maskedLogits := make([]float32, len(logits))
for i := range maskedLogits {
maskedLogits[i] = float32(math.Inf(-1))
}
// Only update values for valid token IDs from the mask map
for tokenID := range node.MaskTokenIDToNode {
if int(tokenID) < len(logits) {
maskedLogits[tokenID] = logits[tokenID]
}
}
return maskedLogits, nil
}
func (s *PushdownSampler) fastMaskLogits(logits []float32, node *PDA) ([]float32, error) {
maxLogit := float32(math.Inf(-1))
maxIndex := -1
// Find the maximum logit value among valid tokens
for tokenID := range node.MaskTokenIDToNode {
if int(tokenID) < len(logits) && logits[tokenID] > maxLogit {
maxLogit = logits[tokenID]
maxIndex = int(tokenID)
}
}
if maxIndex == -1 {
return nil, fmt.Errorf("no valid tokens found in mask")
}
logits[0] = float32(maxIndex) // hack: smuggle the argmax index back through logits[0]
return logits, nil
// return maxIndex, nil
}


@@ -17,15 +17,34 @@ type token struct {
}
type Sampler struct {
rng *rand.Rand
topK int
topP float32
minP float32
temperature float32
grammar *Grammar
rng *rand.Rand
topK int
topP float32
minP float32
temperature float32
grammar *Grammar
JSONSampler *JSONSampler
PythonSampler *PythonSampler
}
func (s *Sampler) Sample(logits []float32) (int32, error) {
if len(logits) == 0 {
return -1, errors.New("sample: no logits provided to sample")
}
var err error
if s.JSONSampler != nil {
logits, err = s.JSONSampler.Apply(logits)
if err != nil {
return -1, err
}
}
if s.PythonSampler != nil {
logits, err = s.PythonSampler.ApplyMask(logits)
if err != nil {
return -1, err
}
}
tokens := make([]token, len(logits))
for i := range logits {
tokens[i].id = int32(i)
@@ -87,19 +106,13 @@ func (s *Sampler) sample(tokens []token) (token, error) {
// topK also sorts the tokens in descending order of logits
tokens = topK(tokens, s.topK)
tokens = temperature(tokens, s.temperature)
tokens = softmax(tokens)
// scale and normalize the tokens in place
temperature(tokens, s.temperature)
softmax(tokens)
tokens = topP(tokens, s.topP)
tokens = minP(tokens, s.minP)
// TODO: this should fall back to greedy sampling
// or topP, topK values etc should be such that
// there are always tokens to sample from
if len(tokens) == 0 {
return token{}, errors.New("no tokens to sample from")
}
var r float32
if s.rng != nil {
r = s.rng.Float32()
@@ -122,11 +135,14 @@ func (s *Sampler) sample(tokens []token) (token, error) {
return 1
})
if math.IsNaN(float64(sum)) {
return token{}, errors.New("sample: logits sum to NaN, check model output")
}
return tokens[idx], nil
}
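The reordered chain above scales by temperature before applying softmax. Minimal in-place versions of those two steps, as a sketch over raw float32 slices (the package's real helpers operate on the token struct) and assuming temperature > 0, with a temperature of 0 special-cased to greedy sampling by the caller:

import "math"

func temperatureInPlace(logits []float32, t float32) {
	for i := range logits {
		logits[i] /= t
	}
}

func softmaxInPlace(logits []float32) {
	max := logits[0]
	for _, l := range logits {
		if l > max {
			max = l
		}
	}
	var sum float64
	for i, l := range logits {
		e := math.Exp(float64(l - max)) // subtract max for numerical stability
		logits[i] = float32(e)
		sum += e
	}
	for i := range logits {
		logits[i] = float32(float64(logits[i]) / sum)
	}
}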
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar, jsonSampler *JSONSampler, pythonSampler *PythonSampler) Sampler {
var rng *rand.Rand
if seed != -1 {
// PCG requires two parameters: sequence and stream
@@ -154,12 +170,14 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
}
return Sampler{
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
grammar: grammar,
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
grammar: grammar,
JSONSampler: jsonSampler,
PythonSampler: pythonSampler,
}
}
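Existing call sites now have to pass the two extra constraint samplers explicitly; with none in play the call is simply (values illustrative):

s := NewSampler(0.7, 40, 0.9, 0.05, -1, nil, nil, nil)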


@@ -1,6 +1,7 @@
package sample
import (
"math"
"math/rand/v2"
"testing"
)
@@ -29,6 +30,29 @@ func TestWeighted(t *testing.T) {
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
// Test very high p
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
// Use extremely small topP to filter out all tokens
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil, nil, nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
return
}
// Should get the token with the highest logit
want = int32(0)
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil, nil, nil)
got, err = sampler.Sample(logits)
if err == nil {
t.Errorf("expected error, got %d", got)
return
}
}
func BenchmarkSample(b *testing.B) {


@@ -0,0 +1,299 @@
package sample
import (
"fmt"
"log/slog"
"runtime"
"time"
"github.com/ollama/ollama/grammar/jsonschema"
"github.com/ollama/ollama/model"
)
type JSONSampler struct {
schema *jsonschema.Schema
propIdx int
propToNodeMap map[string]*PDA
pdaSampler *PushdownSampler
decodedToks []string
}
func NewJSONSampler(proc model.TextProcessor, schema *jsonschema.Schema) (*JSONSampler, error) {
slog.Info("NewJSONSampler", "schema", schema)
if proc == nil {
return nil, fmt.Errorf("TextProcessor cannot be nil")
}
pdaSampler, err := NewPushdownSampler(proc)
if err != nil {
return nil, fmt.Errorf("failed to create PushdownSampler: %w", err)
}
if schema == nil {
return &JSONSampler{
schema: nil,
propIdx: -1,
propToNodeMap: nil,
pdaSampler: pdaSampler,
}, nil
}
// fmt.Println("schema not nil")
so := &JSONSampler{
schema: schema,
propIdx: -1,
propToNodeMap: make(map[string]*PDA),
pdaSampler: pdaSampler,
}
so.schemaToGraph()
// Benchmark token decoding
start := time.Now()
var m runtime.MemStats
runtime.ReadMemStats(&m)
before := m.Alloc
vocab := proc.Vocab()
decodedToks := make([]string, len(vocab.Values))
for i := range vocab.Values {
token, err := proc.Decode([]int32{int32(i)})
if err != nil {
return nil, err
}
decodedToks[i] = token
}
so.decodedToks = decodedToks
runtime.ReadMemStats(&m)
after := m.Alloc
fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Token decode time = %v\n", time.Since(start))
fmt.Println("--------------------------------")
fmt.Println("SOSampler")
fmt.Println("--------------------------------")
// Benchmark this section
start = time.Now()
runtime.ReadMemStats(&m)
before = m.Alloc
// TODO: still messed up
// TODO: recursion use case
// key masks
for _, prop := range so.schema.Properties {
node := so.propToNodeMap[prop.Name]
// propName -> node
curState := node.State
fromNode := node
so.pdaSampler.CreateMask(fromNode)
for curState == StateInStructuredKey {
// there is only one edge
for r, toNode := range fromNode.TransitionEdges {
fmt.Println("rune", r, "edge", toNode.State)
so.pdaSampler.CreateMask(toNode)
fmt.Printf("created mask for %c\n", r)
curState = toNode.State
fmt.Println("next state", curState)
// TODO: there's an extra generation step for '"' right now
fromNode = toNode
}
}
if curState != StateInColon {
return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState)
}
// so.pdaSampler.CreateMask(fromNode)
fromNode = fromNode.TransitionEdges[' ']
so.pdaSampler.CreateMask(fromNode)
curState = fromNode.State
for _, toNode := range fromNode.TransitionEdges {
fmt.Println("toNode", toNode.State)
}
}
// runtime.ReadMemStats(&m)
// after = m.Alloc
// fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
// fmt.Printf("Mask creation time = %v\n", time.Since(start))
// fmt.Println("--------------------------------")
return so, nil
}
func (s *JSONSampler) schemaToGraph() {
schemaType := s.schema.EffectiveType()
switch schemaType {
case "object":
// TODO: see if we need to connect these to the JSON graph
// each prop is a key
for _, prop := range s.schema.Properties {
// name of key
name := prop.Name
keyNode := &PDA{
State: StateInStructuredKey, // this is unchanging, will impact sampling
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode := keyNode
for _, r := range name {
runeNode := &PDA{
State: StateInStructuredKey, // this is unchanging, will impact sampling
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
// fmt.Println("runeNode created", runeNode.State)
// fmt.Printf("runeNode created %c\n", r)
// since nodes are allocated on the heap, the connections will still map
prevNode.TransitionEdges[r] = runeNode
prevNode = runeNode
}
// point to end of object key node after all chars are done
// prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
// link to value node
// Create a node for the end of the key (after the closing quote)
stringEndNode := &PDA{
State: StateInStructuredKey,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode.TransitionEdges['"'] = stringEndNode
prevNode = stringEndNode
// Add transition for colon after key
colonNode := &PDA{
State: StateInColon,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode.TransitionEdges[':'] = colonNode
prevNode = colonNode
// Add transition for space after colon
spaceNode := &PDA{
State: StateInSpaceToValue,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode.TransitionEdges[' '] = spaceNode
prevNode = spaceNode
value := prop.Type
switch value {
case "object":
fmt.Println("object under key: ", name)
case "array":
fmt.Println("array under key: ", name)
case "string":
fmt.Println("string under key: ", name)
prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString]
case "number":
fmt.Println("number under key: ", name)
for _, r := range validNumberRunes {
prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber]
}
case "boolean":
fmt.Println("boolean under key: ", name)
prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool]
prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool]
prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull]
}
// points to start of the key
s.propToNodeMap[name] = keyNode
fmt.Println("name", name, "keyNode", keyNode.State)
}
}
// TODO: do values + recursion
}
func (s *JSONSampler) Apply(logits []float32) ([]float32, error) {
if s.schema == nil {
return s.pdaSampler.Apply(logits)
}
switch s.pdaSampler.curNode.State {
// TODO: doesn't account for the multi-rune case
case StateInObjectKey:
if s.propIdx > len(s.schema.Properties)-1 {
return nil, fmt.Errorf("propIdx out of bounds")
}
// fmt.Println("in object key - structured outputs")
// TODO: this tracking should probably be coming from a stack to track nested objects
// simple case
s.propIdx++
fmt.Println("propIdx", s.propIdx)
prop := s.schema.Properties[s.propIdx]
fmt.Println("prop", prop.Name)
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
if err != nil {
return nil, err
}
return logits, nil
default:
// Will only happen for the last prop - can also be precomputed.
if s.propIdx == len(s.schema.Properties)-1 {
// TODO: if propIdx is incremented here, we also know we are in the last value
switch s.pdaSampler.curNode.State {
case StateInObjectEnd:
fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
s.pdaSampler.curNode = NewPDANode(StateTerminate)
s.propIdx++
// TODO: this needs to be optimized in some way, computing mask on the fly is expensive
case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
delete(s.pdaSampler.curNode.TransitionEdges, ',')
s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
s.pdaSampler.CreateMask(s.pdaSampler.curNode)
s.propIdx++
}
}
return s.pdaSampler.Apply(logits)
}
}
func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice)
if err != nil {
return nil, err
}
if s.schema == nil {
// Don't need to update state for unconstrained JSON sampling
return tokenSlice, nil
}
switch s.pdaSampler.curNode.State {
case StateInObjectKey:
s.propIdx++
fmt.Println("propIdx", s.propIdx)
prop := s.schema.Properties[s.propIdx]
fmt.Println("prop", prop.Name)
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
// TODO: this does not work - mike
// str, err := s.pdaSampler.proc.Decode(tokenSlice)
// if err != nil {
// return nil, err
// }
// fmt.Println("str", str)
return tokenSlice, nil
default:
return tokenSlice, nil
}
}

352
sample/structured_python.go Normal file
View File

@@ -0,0 +1,352 @@
package sample
import (
"fmt"
"math"
"slices"
"github.com/ollama/ollama/model"
)
type PythonState int
const (
PythonStateStart PythonState = iota
StateInFunction
StateInFunctionArgs
StateInFunctionArgsType
StateInFunctionEnd
PStateInString
PStateInStringEnd
PStateInNumber
PStateInList
PStateInListEnd
PStateInDict
PStateInDictEnd
PStateInTuple
PStateInTupleEnd
PStateTerminate
)
func (s PythonState) String() string {
switch s {
case PythonStateStart:
return "PythonStateStart"
case StateInFunction:
return "StateInFunction"
case StateInFunctionArgs:
return "StateInFunctionArgs"
case StateInFunctionArgsType:
return "StateInFunctionArgsType"
case StateInFunctionEnd:
return "StateInFunctionEnd"
case PStateInString:
return "PStateInString"
case PStateInStringEnd:
return "PStateInStringEnd"
case PStateInNumber:
return "PStateInNumber"
case PStateInList:
return "PStateInList"
case PStateInListEnd:
return "PStateInListEnd"
case PStateInDict:
return "PStateInDict"
case PStateInDictEnd:
return "PStateInDictEnd"
case PStateInTuple:
return "PStateInTuple"
case PStateInTupleEnd:
return "PStateInTupleEnd"
case PStateTerminate:
return "PStateTerminate"
default:
return fmt.Sprintf("PythonState(%d)", s)
}
}
var PythonStates = []PythonState{
PythonStateStart,
StateInFunction,
StateInFunctionArgs,
StateInFunctionArgsType,
StateInFunctionEnd,
PStateInString,
PStateInStringEnd,
PStateInNumber,
PStateInList,
PStateInListEnd,
PStateInDict,
PStateInDictEnd,
PStateInTuple,
PStateInTupleEnd,
PStateTerminate,
}
type Node struct {
State PythonState
TransitionEdges map[rune]*Node
MaskTokenIDToNode map[int32]*Node
}
func NewNode(state PythonState) *Node {
return &Node{
State: state,
TransitionEdges: make(map[rune]*Node),
MaskTokenIDToNode: make(map[int32]*Node),
}
}
type PythonFunction struct {
Name string
Args []string
Types []string
}
type PythonSampler struct {
stateToNodes map[PythonState]*Node
proc model.TextProcessor
decodedToks []string
curNode *Node
completed int
functions []PythonFunction
}
func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
s.proc = proc
s.functions = functions
decodedToks := make([]string, len(proc.Vocab().Values))
for i := range proc.Vocab().Values {
token, err := proc.Decode([]int32{int32(i)})
if err != nil {
return err
}
decodedToks[i] = token
}
s.decodedToks = decodedToks
s.BuildGraph()
for _, function := range functions {
prevNode := s.stateToNodes[PythonStateStart]
for _, r := range function.Name {
nextNode := NewNode(StateInFunction)
prevNode.TransitionEdges[r] = nextNode
if err := s.CreateMask(nextNode); err != nil {
return err
}
fmt.Println("prevNode", prevNode.State)
fmt.Printf("transition edge: %q\n", r)
fmt.Println("nextNode", nextNode.State)
prevNode = nextNode
}
prevNode.TransitionEdges['('] = s.stateToNodes[StateInFunctionArgs]
s.CreateMask(prevNode)
prevNode = s.stateToNodes[StateInFunctionArgs]
for i, arg := range function.Args {
for _, r := range arg {
nextNode := NewNode(StateInFunctionArgs)
prevNode.TransitionEdges[r] = nextNode
s.CreateMask(prevNode)
prevNode = nextNode
}
prevNode.TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
// prevNode = s.stateToNodes[StateInFunctionArgs]
prevNode.TransitionEdges['='] = NewNode(StateInFunctionArgsType)
s.CreateMask(prevNode)
prevNode = prevNode.TransitionEdges['=']
switch function.Types[i] {
case "string":
prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInString]
s.CreateMask(prevNode.TransitionEdges['"'])
case "number":
prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInNumber]
s.CreateMask(prevNode.TransitionEdges['"'])
}
}
}
s.curNode = s.stateToNodes[PythonStateStart]
fmt.Println("curNode", s.curNode.State)
fmt.Println("transition edges", s.curNode.TransitionEdges)
if err := s.CreateMask(s.curNode); err != nil {
return err
}
fmt.Println("maskTokenIDToNode", s.curNode.MaskTokenIDToNode)
for tokenID, node := range s.curNode.MaskTokenIDToNode {
fmt.Printf("tokenID: %d, node: %v\n", s.decodedToks[tokenID], node.State)
}
return nil
}
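As a rough wiring sketch for Init, assuming a model.TextProcessor is already in hand from a loaded model; the helper name and the function list below are purely illustrative:

package sample

import "github.com/ollama/ollama/model"

// initPythonSampler is a hypothetical helper showing how a PythonSampler
// might be wired up; the function signature shown here is made up.
func initPythonSampler(proc model.TextProcessor) (*PythonSampler, error) {
    fns := []PythonFunction{
        {Name: "get_weather", Args: []string{"city", "days"}, Types: []string{"string", "number"}},
    }
    s := &PythonSampler{}
    // Init decodes the whole vocabulary and precomputes token masks, so it
    // is the expensive step; do it once per schema, not per request.
    if err := s.Init(fns, proc); err != nil {
        return nil, err
    }
    return s, nil
}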
func (s *PythonSampler) BuildGraph() error {
s.stateToNodes = make(map[PythonState]*Node)
for _, state := range PythonStates {
s.stateToNodes[state] = NewNode(state)
}
for _, state := range s.stateToNodes {
if err := s.CreateMask(state); err != nil {
return err
}
}
// String
s.stateToNodes[PStateInString].TransitionEdges[rune(-1)] = s.stateToNodes[PStateInString]
s.stateToNodes[PStateInString].TransitionEdges['"'] = s.stateToNodes[PStateInStringEnd]
// String end
s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
// s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
// Number
for _, r := range validNumberRunes {
s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
}
s.stateToNodes[PStateInNumber].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
s.stateToNodes[PStateInNumber].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
s.stateToNodes[PStateInNumber].TransitionEdges[' '] = s.stateToNodes[StateInFunctionArgs]
return nil
}
func (s *PythonSampler) ApplyMask(logits []float32) ([]float32, error) {
if s.curNode.State == PStateTerminate {
logits, err := finish(s, logits)
if err != nil {
return nil, err
}
return logits, nil
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
}
func (s *PythonSampler) UpdateState(token int32) error {
mappedString, err := s.proc.Decode([]int32{token})
if err != nil {
return err
}
fmt.Printf(">>> mappedString: %q\n", mappedString)
if s.curNode.State == PStateTerminate {
if s.proc.Is(token, model.SpecialEOS) {
return nil
}
}
nextNode, ok := s.curNode.MaskTokenIDToNode[token]
if !ok {
return fmt.Errorf("invalid token: %q", mappedString)
}
if mappedString == "\"" {
if s.curNode.State == PStateInStringEnd {
s.completed++
}
if s.completed == len(s.functions) {
s.curNode.TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
s.CreateMask(s.curNode)
}
}
s.curNode = nextNode
fmt.Println("curNode", s.curNode.State)
for r, node := range s.curNode.TransitionEdges {
fmt.Printf("transition edge: %q -> %v\n", r, node.State)
}
if err := s.CreateMask(s.curNode); err != nil {
return err
}
return nil
}
func (s *PythonSampler) CreateMask(node *Node) error {
if node == nil {
return fmt.Errorf("node cannot be nil")
}
for i := range s.decodedToks {
token := s.decodedToks[i]
// Skip EOS/BOS tokens and empty tokens since they are never valid here
if s.proc.Is(int32(i), model.SpecialEOS) || s.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
continue
}
curNode := node
valid := true
consumedSpecialRunes := make(map[rune]bool)
for _, r := range token {
curNode, valid = isRValid(r, curNode, consumedSpecialRunes)
if curNode == nil || !valid {
break
}
}
if valid {
if curNode.State == StateInFunction {
// fmt.Println("cm curNode", curNode.State)
// fmt.Println("cm token", s.decodedToks[i])
}
node.MaskTokenIDToNode[int32(i)] = curNode
}
}
return nil
}
func isRValid(r rune, curNode *Node, consumedSpecialRunes map[rune]bool) (*Node, bool) {
if consumedSpecialRunes[r] {
return nil, false
}
specialRune := slices.Contains(stringInvalidRunes, r)
if specialRune {
if curNode.State == PStateInString || curNode.State == PStateInStringEnd {
return nil, false
}
}
// Check for specific rune transition
if nextNode, ok := curNode.TransitionEdges[r]; ok {
// fmt.Println("next node", nextNode)
if specialRune {
if curNode.State == nextNode.State {
return nil, false
}
consumedSpecialRunes[r] = true
}
return nextNode, true
}
// Check for sentinel value - if present, any rune is valid
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
return nextNode, true
}
return nil, false
}
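Note that the rune(-1) edge acts as a wildcard, and the lookup order matters: an exact edge wins, then the sentinel is tried as a catch-all. A small runnable sketch of that order, with a hypothetical two-node string automaton:

package main

import "fmt"

type node struct {
    name  string
    edges map[rune]*node
}

// step mirrors the lookup order in isRValid: exact edges first,
// then the rune(-1) sentinel as the wildcard fallback.
func step(cur *node, r rune) *node {
    if next, ok := cur.edges[r]; ok {
        return next
    }
    return cur.edges[rune(-1)] // nil if no wildcard either
}

func main() {
    inString := &node{name: "inString", edges: map[rune]*node{}}
    end := &node{name: "end", edges: map[rune]*node{}}
    inString.edges['"'] = end           // closing quote ends the string
    inString.edges[rune(-1)] = inString // anything else stays in the string
    fmt.Println(step(inString, 'x').name) // inString
    fmt.Println(step(inString, '"').name) // end
}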
func (s *PythonSampler) maskLogits(logits []float32, node *Node) ([]float32, error) {
// Create a new slice with same length as logits, initialized to -Inf
maskedLogits := make([]float32, len(logits))
for i := range maskedLogits {
maskedLogits[i] = float32(math.Inf(-1))
}
// Only update values for valid token IDs from the mask map
for tokenID := range node.MaskTokenIDToNode {
if int(tokenID) < len(logits) {
maskedLogits[tokenID] = logits[tokenID]
}
}
return maskedLogits, nil
}
func finish(s *PythonSampler, logits []float32) ([]float32, error) {
for i := range logits {
if s.proc.Is(int32(i), model.SpecialEOS) {
logits[i] = 1.0
} else {
logits[i] = float32(math.Inf(-1))
}
}
return logits, nil
}
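Both maskLogits and finish rely on the same trick: default every logit to -Inf so softmax assigns it zero probability, then copy through only the permitted token IDs. A standalone sketch of that masking step (maskToAllowed is a made-up name):

package main

import (
    "fmt"
    "math"
)

// maskToAllowed defaults every logit to -Inf and restores only the
// scores of explicitly allowed token IDs, as maskLogits does above.
func maskToAllowed(logits []float32, allowed map[int32]struct{}) []float32 {
    out := make([]float32, len(logits))
    for i := range out {
        out[i] = float32(math.Inf(-1))
    }
    for id := range allowed {
        if int(id) < len(logits) {
            out[id] = logits[id]
        }
    }
    return out
}

func main() {
    logits := []float32{0.1, 2.3, -0.5, 1.7}
    allowed := map[int32]struct{}{1: {}, 3: {}}
    fmt.Println(maskToAllowed(logits, allowed)) // [-Inf 2.3 -Inf 1.7]
}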

View File

@@ -26,17 +26,16 @@ func (h *tokenHeap) Pop() any {
}
// temperature applies scaling to the logits
func temperature(ts []token, temp float32) []token {
func temperature(ts []token, temp float32) {
// Clamp temperature to a small positive value (1e-7) to avoid numerical instability
temp = max(temp, 1e-7)
for i := range ts {
ts[i].value = ts[i].value / temp
}
return ts
}
// softmax applies normalization to the logits
func softmax(ts []token) []token {
func softmax(ts []token) {
// Find max logit for numerical stability
maxLogit := float32(math.Inf(-1))
for _, t := range ts {
@@ -56,8 +55,6 @@ func softmax(ts []token) []token {
for i := range ts {
ts[i].value /= sum
}
return ts
}
// topK limits the number of tokens considered to the k highest logits
@@ -99,6 +96,7 @@ func topK(ts []token, k int) []token {
}
// topP limits tokens to those with cumulative probability p
// requires ts to be sorted in descending order of probabilities
func topP(ts []token, p float32) []token {
if p == 1.0 {
return ts
@@ -109,37 +107,24 @@ func topP(ts []token, p float32) []token {
for i, t := range ts {
sum += t.value
if sum > float32(p) {
ts = ts[:i+1]
return ts
return ts[:i+1]
}
}
return ts
}
// minP limits tokens to those with cumulative probability p
// minP filters tokens with probabilities >= p * max_prob
// requires ts to be sorted in descending order of probabilities
func minP(ts []token, p float32) []token {
if p == 1.0 {
return ts
}
maxProb := ts[0].value
maxProb := float32(math.Inf(-1))
for _, token := range ts {
if token.value > maxProb {
maxProb = token.value
threshold := maxProb * p
for i, t := range ts {
if t.value < threshold {
return ts[:i]
}
}
threshold := maxProb * float32(p)
// Filter tokens in-place
validTokens := ts[:0]
for i, token := range ts {
if token.value >= threshold {
validTokens = append(validTokens, ts[i])
}
}
ts = validTokens
return ts
}
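After this change the transforms have two calling conventions: temperature and softmax mutate in place and return nothing, while topK, topP, and minP return a reslice that must be reassigned. A hypothetical helper showing the intended pipeline order:

// applyTransforms is an illustrative sketch of the post-refactor calling
// convention, assuming this package's token type.
func applyTransforms(ts []token, temp, p, mp float32, k int) []token {
    temperature(ts, temp) // in place
    ts = topK(ts, k)      // sorts descending and truncates
    softmax(ts)           // in place; monotonic, so order is preserved
    ts = topP(ts, p)      // requires the descending order topK provides
    return minP(ts, mp)
}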

View File

@@ -34,17 +34,22 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {
func TestTemperature(t *testing.T) {
input := []float32{1.0, 4.0, -2.0, 0.0}
got := temperature(toTokens(input), 0.5)
tokens := toTokens(input)
temperature(tokens, 0.5)
want := []float32{2.0, 8.0, -4.0, 0.0}
compareLogits(t, "temperature(0.5)", want, got)
compareLogits(t, "temperature(0.5)", want, tokens)
got = temperature(toTokens(input), 1.0)
input = []float32{1.0, 4.0, -2.0, 0.0}
tokens = toTokens(input)
temperature(tokens, 1.0)
want = []float32{1.0, 4.0, -2.0, 0.0}
compareLogits(t, "temperature(1)", want, got)
compareLogits(t, "temperature(1)", want, tokens)
got = temperature(toTokens(input), 0.0)
input = []float32{1.0, 4.0, -2.0, 0.0}
tokens = toTokens(input)
temperature(tokens, 0.0)
want = []float32{1e7, 4e7, -2e7, 0.0}
compareLogits(t, "temperature(0)", want, got)
compareLogits(t, "temperature(0)", want, tokens)
}
func TestSoftmax(t *testing.T) {
@@ -90,16 +95,17 @@ func TestSoftmax(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := softmax(toTokens(tt.input))
tokens := toTokens(tt.input)
softmax(tokens)
if tt.expected != nil {
compareLogits(t, tt.name, tt.expected, got)
compareLogits(t, tt.name, tt.expected, tokens)
return
}
// Check probabilities sum to 1
var sum float32
for _, token := range got {
for _, token := range tokens {
sum += token.value
if token.value < 0 || token.value > 1 {
t.Errorf("probability out of range [0,1]: got %f", token.value)
@@ -114,38 +120,44 @@ func TestSoftmax(t *testing.T) {
func TestTopK(t *testing.T) {
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
// Test k=5
got := topK(toTokens(input), 5)
if len(got) != 5 {
t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
tokens := toTokens(input)
tokens = topK(tokens, 5)
if len(tokens) != 5 {
t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens))
}
// Should keep the highest 5 values in descending order
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
compareLogits(t, "topK(5)", want, got)
compareLogits(t, "topK(5)", want, tokens)
got = topK(toTokens(input), 20)
if len(got) != len(input) {
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
tokens = toTokens(input)
tokens = topK(tokens, 20)
if len(tokens) != len(input) {
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens))
}
// Test k=-1
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
got = topK(toTokens(input), -1)
if len(got) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
tokens = toTokens(input)
tokens = topK(tokens, -1)
if len(tokens) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
}
compareLogits(t, "topK(-1)", want, got)
compareLogits(t, "topK(-1)", want, tokens)
// Test k=0
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
got = topK(toTokens(input), 0)
if len(got) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
tokens = toTokens(input)
tokens = topK(tokens, 0)
if len(tokens) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
}
compareLogits(t, "topK(-1)", want, tokens)
input = []float32{-1e7, -2e7, -3e7, -4e7}
tokens = toTokens(input)
tokens = topK(tokens, 1)
if len(tokens) < 1 {
t.Error("topK should keep at least one token")
}
compareLogits(t, "topK(-1)", want, got)
}
func TestTopP(t *testing.T) {
@@ -153,50 +165,134 @@ func TestTopP(t *testing.T) {
tokens := toTokens(input)
// First apply temperature and softmax to get probabilities
tokens = softmax(tokens)
softmax(tokens)
tokens = topK(tokens, 20)
// Then apply topP
got := topP(tokens, 0.95)
// Test with very high p value
got := topP(tokens, 1.0)
// Should keep all tokens since p is 1
if len(got) != len(input) {
t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
}
// Test with normal p value
got = topP(tokens, 0.95)
// Should keep tokens until cumsum > 0.95
if len(got) > 3 {
t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
t.Logf("got: %v", got)
}
// Test edge case - ensure at least one token remains
input = []float32{-1e6, -1e6, -1e7}
tokens = toTokens(input)
tokens = topK(tokens, 20)
softmax(tokens)
got = topP(tokens, 0.0)
if len(got) < 1 {
t.Error("topP should keep at least one token")
}
// Test with zero p value
got = topP(tokens, 0.0)
// Should keep only the highest probability token
if len(got) != 1 {
t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
t.Logf("got: %v", got)
}
tokens = toTokens(input)
tokens = topK(tokens, 20)
softmax(tokens)
got = topP(tokens, 1e-10)
if len(got) == 0 {
t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
t.Logf("got: %v", got)
}
}
func TestMinP(t *testing.T) {
input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
tokens := toTokens(input)
// First apply temperature and softmax
tokens = softmax(tokens)
tokens = topK(tokens, 20)
softmax(tokens)
// Then apply minP
got := minP(tokens, 0.2)
tokens = minP(tokens, 1.0)
if len(tokens) != 1 {
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens))
}
// Test with normal p value
tokens = toTokens(input) // Reset tokens
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.2)
// Should keep tokens with prob >= 0.2 * max_prob
if len(got) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
if len(tokens) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
t.Logf("got: %v", tokens)
}
}
func TestSortLogits(t *testing.T) {
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
tokens := toTokens(input)
// Test with zero p value
tokens = toTokens(input) // Reset tokens
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.0)
for i := 1; i < len(tokens); i++ {
if tokens[i].value > tokens[i-1].value {
t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
i, tokens[i].value, tokens[i-1].value)
// Should keep all tokens, since a zero threshold filters nothing
if len(tokens) != len(input) {
t.Errorf("minP(0.0): should keep all tokens, got %d, want %d", len(tokens), len(input))
t.Logf("got: %v", tokens)
}
// Test with single token
tokens = toTokens(input[:1])
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.1)
// Should keep only the highest probability token
if len(tokens) != 1 {
t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
t.Logf("got: %v", tokens)
}
input = []float32{1e-10, 1e-10, 1e-10}
tokens = toTokens(input)
softmax(tokens)
tokens = minP(tokens, 1.0)
if len(tokens) < 1 {
t.Error("minP should keep at least one token even with extreme probabilities")
got := minP(tokens, 1.0)
if len(got) != 1 {
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
}
// Test with normal p value
got = minP(tokens, 0.2)
// Should keep tokens with prob >= 0.2 * max_prob
if len(got) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
t.Logf("got: %v", got)
}
// Test with zero p value
got = minP(tokens, 0.0)
// Should keep all tokens, since a zero threshold filters nothing
if len(got) != len(tokens) {
t.Errorf("minP(0.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
t.Logf("got: %v", got)
}
}
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
compareLogits(t, "sortLogits", want, tokens)
}
func BenchmarkTransforms(b *testing.B) {
@@ -231,7 +327,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
topK(tokensCopy, 10)
tokens = topK(tokensCopy, 10)
}
})
@@ -239,7 +335,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
topP(tokensCopy, 0.9)
tokens = topP(tokensCopy, 0.9)
}
})
@@ -247,7 +343,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
minP(tokensCopy, 0.2)
tokens = minP(tokensCopy, 0.2)
}
})
@@ -255,7 +351,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
topK(tokensCopy, 200000)
tokens = topK(tokensCopy, 200000)
}
})
}

View File

@@ -8,7 +8,7 @@ usage() {
exit 1
}
export VERSION=${VERSION:-$(git describe --tags --dirty)}
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'"
export CGO_CPPFLAGS='-mmacosx-version-min=11.3'

View File

@@ -146,7 +146,7 @@ func debugger(err *error) func(step string) {
// be in either of the following forms:
//
// @<digest>
// <name>
// <name>@<digest>
// <name>
//
// If a digest is provided, it is returned as is and nothing else happens.
@@ -160,8 +160,6 @@ func debugger(err *error) func(step string) {
// hashed is passed to a PutBytes call to ensure that the manifest is in the
// blob store. This is done to ensure that future calls to [Get] succeed in
// these cases.
//
// TODO(bmizerany): Move Links/Resolve/etc. out of this package.
func (c *DiskCache) Resolve(name string) (Digest, error) {
name, digest := splitNameDigest(name)
if digest != "" {
@@ -279,18 +277,6 @@ func (c *DiskCache) Get(d Digest) (Entry, error) {
// It returns an error if either the name or digest is invalid, or if link
// creation encounters any issues.
func (c *DiskCache) Link(name string, d Digest) error {
// TODO(bmizerany): Move link handling from cache to registry.
//
// We originally placed links in the cache due to its storage
// knowledge. However, the registry likely offers better context for
// naming concerns, and our API design shouldn't be tightly coupled to
// our on-disk format.
//
// Links work effectively when independent from physical location -
// they can reference content with matching SHA regardless of storage
// location. In an upcoming change, we plan to shift this
// responsibility to the registry where it better aligns with the
// system's conceptual model.
manifest, err := c.manifestPath(name)
if err != nil {
return err
@@ -341,7 +327,9 @@ func (c *DiskCache) GetFile(d Digest) string {
return absJoin(c.dir, "blobs", filename)
}
// Links returns a sequence of links in the cache in lexical order.
// Links returns a sequence of link names. The sequence is in lexical order.
// Names are converted from their relative path form to their name form but are
// not guaranteed to be valid. Callers should validate the names before using them.
func (c *DiskCache) Links() iter.Seq2[string, error] {
return func(yield func(string, error) bool) {
for path, err := range c.links() {
@@ -414,12 +402,14 @@ func (c *DiskCache) links() iter.Seq2[string, error] {
}
type checkWriter struct {
d Digest
size int64
n int64
h hash.Hash
d Digest
f *os.File
err error
h hash.Hash
w io.Writer // underlying writer; set by creator
n int64
err error
testHookBeforeFinalWrite func(*os.File)
}
@@ -435,6 +425,10 @@ func (w *checkWriter) seterr(err error) error {
// underlying writer is guaranteed to be the last byte of p as verified by the
// hash.
func (w *checkWriter) Write(p []byte) (int, error) {
if w.err != nil {
return 0, w.err
}
_, err := w.h.Write(p)
if err != nil {
return 0, w.seterr(err)
@@ -453,7 +447,7 @@ func (w *checkWriter) Write(p []byte) (int, error) {
if nextSize > w.size {
return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size))
}
n, err := w.f.Write(p)
n, err := w.w.Write(p)
w.n += int64(n)
return n, w.seterr(err)
}
@@ -493,10 +487,12 @@ func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size
// Copy file to f, but also into h to double-check hash.
cw := &checkWriter{
d: out,
size: size,
h: sha256.New(),
f: f,
d: out,
size: size,
h: sha256.New(),
f: f,
w: f,
testHookBeforeFinalWrite: c.testHookBeforeFinalWrite,
}
n, err := io.Copy(cw, file)
@@ -532,11 +528,6 @@ func splitNameDigest(s string) (name, digest string) {
var errInvalidName = errors.New("invalid name")
func nameToPath(name string) (_ string, err error) {
if strings.Contains(name, "@") {
// TODO(bmizerany): HACK: Fix names.Parse to validate.
// TODO(bmizerany): merge with default parts (maybe names.Merge(a, b))
return "", errInvalidName
}
n := names.Parse(name)
if !n.IsFullyQualified() {
return "", errInvalidName
@@ -547,8 +538,7 @@ func nameToPath(name string) (_ string, err error) {
func absJoin(pp ...string) string {
abs, err := filepath.Abs(filepath.Join(pp...))
if err != nil {
// Likely a bug or a bad OS problem. Just panic.
panic(err)
panic(err) // this should never happen
}
return abs
}

73
server/internal/cache/blob/chunked.go vendored Normal file
View File

@@ -0,0 +1,73 @@
package blob
import (
"crypto/sha256"
"errors"
"io"
"os"
)
// Chunk represents a range of bytes in a blob.
type Chunk struct {
Start int64
End int64
}
// Size returns end minus start plus one.
func (c Chunk) Size() int64 {
return c.End - c.Start + 1
}
// Chunker writes to a blob in chunks.
// Its zero value is invalid. Use [DiskCache.Chunked] to create a new Chunker.
type Chunker struct {
digest Digest
size int64
f *os.File // nil means pre-validated
}
// Chunked returns a new Chunker, ready for use storing a blob of the given
// size in chunks.
//
// Use [Chunker.Put] to write data to the blob at specific offsets.
func (c *DiskCache) Chunked(d Digest, size int64) (*Chunker, error) {
name := c.GetFile(d)
info, err := os.Stat(name)
if err == nil && info.Size() == size {
return &Chunker{}, nil
}
f, err := os.OpenFile(name, os.O_CREATE|os.O_WRONLY, 0o666)
if err != nil {
return nil, err
}
return &Chunker{digest: d, size: size, f: f}, nil
}
// Put copies chunk.Size() bytes from r to the blob at the given offset,
// merging the data with the existing blob. It returns an error if any. As a
// special case, if r has less than chunk.Size() bytes, Put returns
// io.ErrUnexpectedEOF.
func (c *Chunker) Put(chunk Chunk, d Digest, r io.Reader) error {
if c.f == nil {
return nil
}
cw := &checkWriter{
d: d,
size: chunk.Size(),
h: sha256.New(),
f: c.f,
w: io.NewOffsetWriter(c.f, chunk.Start),
}
_, err := io.CopyN(cw, r, chunk.Size())
if err != nil && errors.Is(err, io.EOF) {
return io.ErrUnexpectedEOF
}
return err
}
// Close closes the underlying file.
func (c *Chunker) Close() error {
return c.f.Close()
}
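A hypothetical sketch of driving a Chunker end to end, assuming it lives alongside this package; the open callback stands in for an HTTP Range request and is not part of this API:

// downloadInChunks issues one Put per chunk, each with that chunk's own
// digest and a reader positioned at its range, then a single Close once
// every chunk has landed.
func downloadInChunks(c *DiskCache, d Digest, size int64, parts []Chunk,
    open func(Chunk) (io.ReadCloser, Digest, error)) error {
    chunker, err := c.Chunked(d, size)
    if err != nil {
        return err
    }
    defer chunker.Close()
    for _, part := range parts {
        rc, pd, err := open(part) // e.g. a ranged GET against the blob URL
        if err != nil {
            return err
        }
        err = chunker.Put(part, pd, rc)
        rc.Close()
        if err != nil {
            return err
        }
    }
    return nil
}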

View File

@@ -63,6 +63,10 @@ func (d Digest) Short() string {
return fmt.Sprintf("%x", d.sum[:4])
}
func (d Digest) Sum() [32]byte {
return d.sum
}
func (d Digest) Compare(other Digest) int {
return slices.Compare(d.sum[:], other.sum[:])
}

View File

@@ -1,78 +0,0 @@
package chunks
import (
"fmt"
"iter"
"strconv"
"strings"
)
type Chunk struct {
Start, End int64
}
func New(start, end int64) Chunk {
return Chunk{start, end}
}
// ParseRange parses a string in the form "unit=range" where unit is a string
// and range is a string in the form "start-end". It returns the unit and the
// range as a Chunk.
func ParseRange(s string) (unit string, _ Chunk, _ error) {
unit, r, _ := strings.Cut(s, "=")
if r == "" {
return unit, Chunk{}, nil
}
c, err := Parse(r)
if err != nil {
return "", Chunk{}, err
}
return unit, c, err
}
// Parse parses a string in the form "start-end" and returns the Chunk.
func Parse(s string) (Chunk, error) {
startStr, endStr, _ := strings.Cut(s, "-")
start, err := strconv.ParseInt(startStr, 10, 64)
if err != nil {
return Chunk{}, fmt.Errorf("invalid start: %v", err)
}
end, err := strconv.ParseInt(endStr, 10, 64)
if err != nil {
return Chunk{}, fmt.Errorf("invalid end: %v", err)
}
if start > end {
return Chunk{}, fmt.Errorf("invalid range %d-%d: start > end", start, end)
}
return Chunk{start, end}, nil
}
// Of returns a sequence of contiguous Chunks of size chunkSize that cover
// the range [0, size), in order.
func Of(size, chunkSize int64) iter.Seq[Chunk] {
return func(yield func(Chunk) bool) {
for start := int64(0); start < size; start += chunkSize {
end := min(start+chunkSize-1, size-1)
if !yield(Chunk{start, end}) {
break
}
}
}
}
// Count returns the number of Chunks of size chunkSize needed to cover the
// range [0, size).
func Count(size, chunkSize int64) int64 {
return (size + chunkSize - 1) / chunkSize
}
// Size returns end minus start plus one.
func (c Chunk) Size() int64 {
return c.End - c.Start + 1
}
// String returns the string representation of the Chunk in the form
// "{start}-{end}".
func (c Chunk) String() string {
return fmt.Sprintf("%d-%d", c.Start, c.End)
}

View File

@@ -1,65 +0,0 @@
package chunks
import (
"slices"
"testing"
)
func TestOf(t *testing.T) {
cases := []struct {
total int64
chunkSize int64
want []Chunk
}{
{0, 1, nil},
{1, 1, []Chunk{{0, 0}}},
{1, 2, []Chunk{{0, 0}}},
{2, 1, []Chunk{{0, 0}, {1, 1}}},
{10, 9, []Chunk{{0, 8}, {9, 9}}},
}
for _, tt := range cases {
got := slices.Collect(Of(tt.total, tt.chunkSize))
if !slices.Equal(got, tt.want) {
t.Errorf("[%d/%d]: got %v; want %v", tt.total, tt.chunkSize, got, tt.want)
}
}
}
func TestSize(t *testing.T) {
cases := []struct {
c Chunk
want int64
}{
{Chunk{0, 0}, 1},
{Chunk{0, 1}, 2},
{Chunk{3, 4}, 2},
}
for _, tt := range cases {
got := tt.c.Size()
if got != tt.want {
t.Errorf("%v: got %d; want %d", tt.c, got, tt.want)
}
}
}
func TestCount(t *testing.T) {
cases := []struct {
total int64
chunkSize int64
want int64
}{
{0, 1, 0},
{1, 1, 1},
{1, 2, 1},
{2, 1, 2},
{10, 9, 2},
}
for _, tt := range cases {
got := Count(tt.total, tt.chunkSize)
if got != tt.want {
t.Errorf("[%d/%d]: got %d; want %d", tt.total, tt.chunkSize, got, tt.want)
}
}
}

View File

@@ -19,11 +19,13 @@ import (
"fmt"
"io"
"io/fs"
"iter"
"log/slog"
"net/http"
"os"
"path/filepath"
"runtime"
"runtime/debug"
"slices"
"strconv"
"strings"
@@ -35,10 +37,7 @@ import (
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/internal/backoff"
"github.com/ollama/ollama/server/internal/internal/names"
"github.com/ollama/ollama/server/internal/internal/syncs"
_ "embed"
)
@@ -60,18 +59,18 @@ var (
// ErrCached is passed to [Trace.PushUpdate] when a layer already
// exists. It is a non-fatal error and is never returned by [Registry.Push].
ErrCached = errors.New("cached")
// ErrIncomplete is returned by [Registry.Pull] when a model pull was
// incomplete due to one or more layer download failures. Users that
// want specific errors should use [WithTrace].
ErrIncomplete = errors.New("incomplete")
)
// Defaults
const (
// DefaultChunkingThreshold is the threshold at which a layer should be
// split up into chunks when downloading.
DefaultChunkingThreshold = 128 << 20
// DefaultMaxChunkSize is the default maximum size of a chunk to
// download. It is configured based on benchmarks and aims to strike a
// balance between download speed and memory usage.
DefaultMaxChunkSize = 8 << 20
DefaultChunkingThreshold = 64 << 20
)
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
@@ -211,20 +210,13 @@ type Registry struct {
// pushing or pulling models. If zero, the number of streams is
// determined by [runtime.GOMAXPROCS].
//
// Clients that want "unlimited" streams should set this to a large
// number.
// A negative value means no limit.
MaxStreams int
// ChunkingThreshold is the maximum size of a layer to download in a single
// request. If zero, [DefaultChunkingThreshold] is used.
ChunkingThreshold int64
// MaxChunkSize is the maximum size of a chunk to download. If zero,
// the default is [DefaultMaxChunkSize].
//
// It is only used when a layer is larger than [MaxChunkingThreshold].
MaxChunkSize int64
// Mask, if set, is the name used to convert non-fully qualified names
// to fully qualified names. If empty, [DefaultMask] is used.
Mask string
@@ -266,6 +258,7 @@ func DefaultRegistry() (*Registry, error) {
}
var rc Registry
rc.UserAgent = UserAgent()
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
if err != nil {
return nil, err
@@ -281,25 +274,35 @@ func DefaultRegistry() (*Registry, error) {
return &rc, nil
}
func (r *Registry) maxStreams() int {
n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
func UserAgent() string {
buildinfo, _ := debug.ReadBuildInfo()
// Large downloads require a writer stream, so ensure we have at least
// two streams to avoid a deadlock.
return max(n, 2)
version := buildinfo.Main.Version
if version == "(devel)" {
// When using `go run .` the version is "(devel)". This is seen
// as an invalid version by ollama.com and so it defaults to
// "needs upgrade" for some requests, such as pulls. These
// checks can be skipped by using the special version "v0.0.0",
// so we set it to that here.
version = "v0.0.0"
}
return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
version,
runtime.GOARCH,
runtime.GOOS,
runtime.Version(),
)
}
func (r *Registry) maxStreams() int {
return cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
}
func (r *Registry) maxChunkingThreshold() int64 {
return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
}
// chunkSizeFor returns the chunk size for a layer of the given size. If the
// size is less than or equal to the max chunking threshold, the size is
// returned; otherwise, the max chunk size is returned.
func (r *Registry) maxChunkSize() int64 {
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
}
type PushParams struct {
// From is an optional destination name for the model. If empty, the
// destination name is the same as the source name.
@@ -426,6 +429,22 @@ func canRetry(err error) bool {
return re.Status >= 500
}
// trackingReader is an io.Reader that tracks the number of bytes read and
// calls the update function with the layer and the number of bytes read.
//
// It always calls update with a nil error.
type trackingReader struct {
l *Layer
r io.Reader
update func(l *Layer, n int64, err error)
}
func (r *trackingReader) Read(p []byte) (n int, err error) {
n, err = r.r.Read(p)
r.update(r.l, int64(n), nil)
return
}
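The same wrapper idea in a runnable, self-contained form (progressReader is a made-up name); the caller accumulates totals, since each Read reports only its own byte count:

package main

import (
    "fmt"
    "io"
    "strings"
)

// progressReader forwards Read to the underlying reader and reports
// each read's byte count to a callback, like trackingReader above.
type progressReader struct {
    r      io.Reader
    update func(n int64)
}

func (p *progressReader) Read(b []byte) (int, error) {
    n, err := p.r.Read(b)
    p.update(int64(n))
    return n, err
}

func main() {
    var total int64
    pr := &progressReader{
        r:      strings.NewReader("hello, world"),
        update: func(n int64) { total += n },
    }
    io.Copy(io.Discard, pr)
    fmt.Println("bytes read:", total) // 12
}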
// Pull pulls the model with the given name from the remote registry into the
// cache.
//
@@ -434,15 +453,15 @@ func canRetry(err error) bool {
// typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, name string) error {
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
return err
}
m, err := r.Resolve(ctx, name)
if err != nil {
return err
}
// TODO(bmizerany): decide if this should be considered valid. Maybe
// server-side we special case '{}' to have some special meaning? Maybe
// "archiving" a tag (which is how we reason about it in the registry
// already, just with a different twist).
if len(m.Layers) == 0 {
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
}
@@ -452,142 +471,105 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
return err
}
exists := func(l *Layer) bool {
info, err := c.Get(l.Digest)
return err == nil && info.Size == l.Size
}
t := traceFromContext(ctx)
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(r.maxStreams())
// TODO(bmizerany): work to remove the need to do this
layers := m.Layers
if m.Config != nil && m.Config.Digest.IsValid() {
layers = append(layers, m.Config)
}
// Send initial layer trace events to allow clients to have an
// understanding of work to be done before work starts.
var expected int64
t := traceFromContext(ctx)
for _, l := range layers {
if exists(l) {
t.update(l, 0, nil)
expected += l.Size
}
var received atomic.Int64
var g errgroup.Group
g.SetLimit(r.maxStreams())
for _, l := range layers {
info, err := c.Get(l.Digest)
if err == nil && info.Size == l.Size {
received.Add(l.Size)
t.update(l, l.Size, ErrCached)
continue
}
blobURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), l.Digest)
req, err := r.newRequest(ctx, "GET", blobURL, nil)
var wg sync.WaitGroup
chunked, err := c.Chunked(l.Digest, l.Size)
if err != nil {
t.update(l, 0, err)
continue
}
t.update(l, 0, nil)
for cs, err := range r.chunksums(ctx, name, l) {
if err != nil {
// Chunksum stream interrupted. Note in trace
// log and let in-flight downloads complete.
// This will naturally trigger ErrIncomplete
// since received < expected bytes.
t.update(l, 0, err)
break
}
if l.Size <= r.maxChunkingThreshold() {
g.Go(func() error {
// TODO(bmizerany): retry/backoff like below in
// the chunking case
wg.Add(1)
g.Go(func() (err error) {
defer func() {
if err == nil {
received.Add(cs.Chunk.Size())
} else {
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
}
wg.Done()
}()
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
err = c.Put(l.Digest, res.Body, l.Size)
if err == nil {
t.update(l, l.Size, nil)
}
return err
body := &trackingReader{l: l, r: res.Body, update: t.update}
return chunked.Put(cs.Chunk, cs.Digest, body)
})
} else {
q := syncs.NewRelayReader()
g.Go(func() (err error) {
defer func() { q.CloseWithError(err) }()
return c.Put(l.Digest, q, l.Size)
})
var progress atomic.Int64
// We want to avoid extra round trips per chunk due to
// redirects from the registry to the blob store, so
// fire an initial request to get the final URL and
// then use that URL for the chunk requests.
req.Header.Set("Range", "bytes=0-0")
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
res.Body.Close()
req = res.Request.WithContext(req.Context())
wp := writerPool{size: r.maxChunkSize()}
for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
if ctx.Err() != nil {
break
}
ticket := q.Take()
g.Go(func() (err error) {
defer func() {
if err != nil {
q.CloseWithError(err)
}
ticket.Close()
t.update(l, progress.Load(), err)
}()
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil {
return err
}
err := func() error {
req := req.Clone(req.Context())
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
tw := wp.get()
tw.Reset(ticket)
defer wp.put(tw)
_, err = io.CopyN(tw, res.Body, chunk.Size())
if err != nil {
return maybeUnexpectedEOF(err)
}
if err := tw.Flush(); err != nil {
return err
}
total := progress.Add(chunk.Size())
if total >= l.Size {
q.Close()
}
return nil
}()
if !canRetry(err) {
return err
}
}
return nil
})
}
}
}
// Close writer immediately after downloads finish, not at Pull
// exit. Using defer would keep file descriptors open until all
// layers complete, potentially exhausting system limits with
// many layers.
//
// The WaitGroup tracks when all chunks finish downloading,
// allowing precise writer closure in a background goroutine.
// Each layer briefly uses one extra goroutine while at most
// maxStreams()-1 chunks download in parallel.
//
// This caps file descriptors at maxStreams() instead of
// growing with layer count.
g.Go(func() error {
wg.Wait()
chunked.Close()
return nil
})
}
if err := g.Wait(); err != nil {
return err
}
if received.Load() != expected {
return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected)
}
// store the manifest blob
md := blob.DigestFromBytes(m.Data)
if err := blob.PutBytes(c, md, m.Data); err != nil {
return err
}
// commit the manifest with a link
return c.Link(m.Name, md)
}
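The close-on-WaitGroup strategy described in the comment above, reduced to a runnable shape: per-resource workers hold WaitGroup slots, and one extra goroutine closes the resource the moment its own workers finish instead of deferring every close to the end of the caller. A sketch, with trivial workers standing in for chunk downloads:

package main

import (
    "fmt"
    "sync"

    "golang.org/x/sync/errgroup"
)

func main() {
    var g errgroup.Group
    g.SetLimit(4)
    for res := 0; res < 3; res++ {
        var wg sync.WaitGroup
        for part := 0; part < 2; part++ {
            wg.Add(1)
            g.Go(func() error {
                defer wg.Done()
                return nil // download one part here
            })
        }
        res := res
        g.Go(func() error {
            // Close as soon as this resource's parts finish, capping
            // open resources at the group's limit rather than at the
            // total resource count.
            wg.Wait()
            fmt.Println("closing resource", res)
            return nil
        })
    }
    _ = g.Wait()
}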
@@ -615,8 +597,6 @@ type Manifest struct {
Config *Layer `json:"config"`
}
var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
// Layer returns the layer with the given
// digest, or nil if not found.
func (m *Manifest) Layer(d blob.Digest) *Layer {
@@ -643,10 +623,9 @@ func (m Manifest) MarshalJSON() ([]byte, error) {
// last phase of the commit which expects it, but does nothing
// with it. This will be fixed in a future release of
// ollama.com.
Config *Layer `json:"config"`
Config Layer `json:"config"`
}{
M: M(m),
Config: &Layer{Digest: emptyDigest},
M: M(m),
}
return json.Marshal(v)
}
@@ -736,6 +715,123 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error)
return m, nil
}
type chunksum struct {
URL string
Chunk blob.Chunk
Digest blob.Digest
}
// chunksums returns a sequence of chunksums for the given layer. If the layer is under the
// chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
// is over the chunking threshold, the chunksums are read from the chunksums endpoint.
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
return func(yield func(chunksum, error) bool) {
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
yield(chunksum{}, err)
return
}
if l.Size < r.maxChunkingThreshold() {
// any layer under the threshold should be downloaded
// in one go.
cs := chunksum{
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
),
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
Digest: l.Digest,
}
yield(cs, nil)
return
}
// A chunksums response is a sequence of chunksums in a
// simple, easy to parse line-oriented format.
//
// Example:
//
// >> GET /v2/<namespace>/<model>/chunksums/<digest>
//
// << HTTP/1.1 200 OK
// << Content-Location: <blobURL>
// <<
// << <digest> <start>-<end>
// << ...
//
// The blobURL is the URL to download the chunks from.
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
)
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
if err != nil {
yield(chunksum{}, err)
return
}
res, err := sendRequest(r.client(), req)
if err != nil {
yield(chunksum{}, err)
return
}
defer res.Body.Close()
if res.StatusCode != 200 {
err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
yield(chunksum{}, err)
return
}
blobURL := res.Header.Get("Content-Location")
s := bufio.NewScanner(res.Body)
s.Split(bufio.ScanWords)
for {
if !s.Scan() {
if s.Err() != nil {
yield(chunksum{}, s.Err())
}
return
}
d, err := blob.ParseDigest(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
return
}
if !s.Scan() {
err := s.Err()
if err == nil {
err = fmt.Errorf("missing chunk range for digest %s", d)
}
yield(chunksum{}, err)
return
}
chunk, err := parseChunk(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
return
}
cs := chunksum{
URL: blobURL,
Chunk: chunk,
Digest: d,
}
if !yield(cs, nil) {
return
}
}
}
}
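A hypothetical consumer of this iterator, assuming ctx, name, and a layer are in scope: errors arrive in-band as the second range value, so the caller chooses between breaking (as Pull does, letting received < expected surface as ErrIncomplete) and aborting outright.

// logChunksums is an illustrative sketch, not part of this package.
func logChunksums(ctx context.Context, r *Registry, name string, l *Layer) {
    for cs, err := range r.chunksums(ctx, name, l) {
        if err != nil {
            fmt.Printf("chunksum stream interrupted: %v\n", err)
            break
        }
        fmt.Printf("chunk %d-%d from %s\n", cs.Chunk.Start, cs.Chunk.End, cs.URL)
    }
}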
func (r *Registry) client() *http.Client {
if r.HTTPClient != nil {
return r.HTTPClient
@@ -898,13 +994,6 @@ func checkData(url string) string {
return fmt.Sprintf("GET,%s,%s", url, zeroSum)
}
func maybeUnexpectedEOF(err error) error {
if errors.Is(err, io.EOF) {
return io.ErrUnexpectedEOF
}
return err
}
type publicError struct {
wrapped error
message string
@@ -991,27 +1080,22 @@ func splitExtended(s string) (scheme, name, digest string) {
return scheme, s, digest
}
type writerPool struct {
size int64 // set by the caller
mu sync.Mutex
ws []*bufio.Writer
}
func (p *writerPool) get() *bufio.Writer {
p.mu.Lock()
defer p.mu.Unlock()
if len(p.ws) == 0 {
return bufio.NewWriterSize(nil, int(p.size))
// parseChunk parses a string in the form "start-end" and returns the Chunk.
func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
startPart, endPart, found := strings.Cut(string(s), "-")
if !found {
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s)
}
w := p.ws[len(p.ws)-1]
p.ws = p.ws[:len(p.ws)-1]
return w
}
func (p *writerPool) put(w *bufio.Writer) {
p.mu.Lock()
defer p.mu.Unlock()
w.Reset(nil)
p.ws = append(p.ws, w)
start, err := strconv.ParseInt(startPart, 10, 64)
if err != nil {
return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err)
}
end, err := strconv.ParseInt(endPart, 10, 64)
if err != nil {
return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err)
}
if start > end {
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s)
}
return blob.Chunk{Start: start, End: end}, nil
}
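The ~string | ~[]byte constraint lets callers hand over the scanner's raw bytes directly, with the conversion handled inside. For example (values are arbitrary):

c1, _ := parseChunk("0-1048575")
c2, _ := parseChunk([]byte("1048576-2097151"))
fmt.Println(c1.Size(), c2.Size()) // 1048576 1048576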

View File

@@ -17,14 +17,36 @@ import (
"reflect"
"slices"
"strings"
"sync"
"testing"
"time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/testutil"
)
func ExampleRegistry_cancelOnFirstError() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = WithTrace(ctx, &Trace{
Update: func(l *Layer, n int64, err error) {
if err != nil {
// Discontinue pulling layers if there is an
// error instead of continuing to pull more
// data.
cancel()
}
},
})
var r Registry
if err := r.Pull(ctx, "model"); err != nil {
// panic for demo purposes
panic(err)
}
}
func TestManifestMarshalJSON(t *testing.T) {
// All manifests should contain an "empty" config object.
var m Manifest
@@ -57,21 +79,21 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
// newClient constructs a cache with predefined manifests for testing. The manifests are:
//
// empty: no data
// zero: no layers
// single: one layer with the contents "exists"
// multiple: two layers with the contents "exists" and "here"
// notfound: a layer that does not exist in the cache
// null: one null layer (e.g. [null])
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
// invalid: a layer with invalid JSON data
// empty: no data
// zero: no layers
// single: one layer with the contents "exists"
// multiple: two layers with the contents "exists" and "here"
// notfound: a layer that does not exist in the cache
// null: one null layer (e.g. [null])
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
// invalid: a layer with invalid JSON data
//
// Tests that want to ensure the client does not communicate with the upstream
// registry should pass a nil handler, which will cause a panic if
// communication is attempted.
//
// To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
t.Helper()
c, err := blob.Open(t.TempDir())
@@ -89,7 +111,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
r := &Registry{
Cache: c,
HTTPClient: &http.Client{
Transport: recordRoundTripper(h),
Transport: recordRoundTripper(upstreamRegistry),
},
}
@@ -428,7 +450,7 @@ func TestRegistryPullCached(t *testing.T) {
err := rc.Pull(ctx, "single")
testutil.Check(t, err)
want := []int64{6}
want := []int64{0, 6}
if !errors.Is(errors.Join(errs...), ErrCached) {
t.Errorf("errs = %v; want %v", errs, ErrCached)
}
@@ -531,54 +553,6 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
}
}
func TestRegistryPullChunking(t *testing.T) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
if r.URL.Host != "blob.store" {
// The production registry redirects to the blob store.
http.Redirect(w, r, "http://blob.store"+r.URL.Path, http.StatusFound)
return
}
if strings.Contains(r.URL.Path, "/blobs/") {
rng := r.Header.Get("Range")
if rng == "" {
http.Error(w, "missing range", http.StatusBadRequest)
return
}
_, c, err := chunks.ParseRange(r.Header.Get("Range"))
if err != nil {
panic(err)
}
io.WriteString(w, "remote"[c.Start:c.End+1])
return
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, blob.DigestFromBytes("remote"))
})
// Force chunking by setting the threshold to less than the size of the
// layer.
rc.ChunkingThreshold = 3
rc.MaxChunkSize = 3
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(d *Layer, n int64, err error) {
if err != nil {
t.Errorf("update %v %d %v", d, n, err)
}
reads = append(reads, n)
},
})
err := rc.Pull(ctx, "remote")
testutil.Check(t, err)
want := []int64{0, 3, 6}
if !slices.Equal(reads, want) {
t.Errorf("reads = %v; want %v", reads, want)
}
}
func TestRegistryResolveByDigest(t *testing.T) {
check := testutil.Checker(t)
@@ -816,3 +790,79 @@ func TestUnlink(t *testing.T) {
}
})
}
func TestPullChunksums(t *testing.T) {
check := testutil.Checker(t)
content := "hello"
var chunksums string
contentDigest := func() blob.Digest {
return blob.DigestFromBytes(content)
}
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/manifests/latest"):
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
w.Header().Set("Content-Location", loc)
io.WriteString(w, chunksums)
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
default:
t.Errorf("unexpected request: %v", r)
http.NotFound(w, r)
}
})
rc.MaxStreams = 1 // prevent concurrent chunk downloads
rc.ChunkingThreshold = 1 // for all blobs to be chunked
var mu sync.Mutex
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Logf("Update: %v %d %v", l, n, err)
mu.Lock()
reads = append(reads, n)
mu.Unlock()
},
})
chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
blob.DigestFromBytes("hel"),
blob.DigestFromBytes("lo"),
)
err := rc.Pull(ctx, "test")
check(err)
wantReads := []int64{
0, // initial signaling of layer pull starting
3, // first chunk read
2, // second chunk read
}
if !slices.Equal(reads, wantReads) {
t.Errorf("reads = %v; want %v", reads, wantReads)
}
mw, err := rc.Resolve(t.Context(), "test")
check(err)
mg, err := rc.ResolveLocal("test")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
}
for i := range mg.Layers {
_, err = c.Get(mg.Layers[i].Digest)
if err != nil {
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
}
}
// missing chunks
content = "llama"
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
err = rc.Pull(ctx, "missingchunks")
if err == nil {
t.Error("expected error because of missing chunks")
}
}

View File

@@ -1,11 +0,0 @@
package main
import (
"fmt"
"os"
)
func main() {
fmt.Println("Run as 'go test -bench=.' to run the benchmarks")
os.Exit(1)
}

View File

@@ -1,107 +0,0 @@
package main
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"sync/atomic"
"testing"
"time"
"github.com/ollama/ollama/server/internal/chunks"
"golang.org/x/sync/errgroup"
)
func BenchmarkDownload(b *testing.B) {
run := func(fileSize, chunkSize int64) {
name := fmt.Sprintf("size=%d/chunksize=%d", fileSize, chunkSize)
b.Run(name, func(b *testing.B) { benchmarkDownload(b, fileSize, chunkSize) })
}
run(100<<20, 8<<20)
run(100<<20, 16<<20)
run(100<<20, 32<<20)
run(100<<20, 64<<20)
run(100<<20, 128<<20) // 1 chunk
}
func run(ctx context.Context, c *http.Client, chunk chunks.Chunk) error {
const blobURL = "https://ollama.com/v2/x/x/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d"
req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
res, err := c.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
_, err = io.CopyN(io.Discard, res.Body, chunk.Size()) // will io.EOF on short read
return err
}
var sleepTime atomic.Int64
func benchmarkDownload(b *testing.B, fileSize, chunkSize int64) {
client := &http.Client{
Transport: func() http.RoundTripper {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DisableKeepAlives = true
return tr
}(),
}
defer client.CloseIdleConnections()
// warm up the client
run(context.Background(), client, chunks.New(0, 1<<20))
b.SetBytes(fileSize)
b.ReportAllocs()
// Give our CDN a min to breathe between benchmarks.
time.Sleep(time.Duration(sleepTime.Swap(3)))
for b.Loop() {
g, ctx := errgroup.WithContext(b.Context())
g.SetLimit(runtime.GOMAXPROCS(0))
for chunk := range chunks.Of(fileSize, chunkSize) {
g.Go(func() error { return run(ctx, client, chunk) })
}
if err := g.Wait(); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkWrite(b *testing.B) {
b.Run("chunksize=1MiB", func(b *testing.B) { benchmarkWrite(b, 1<<20) })
}
func benchmarkWrite(b *testing.B, chunkSize int) {
b.ReportAllocs()
dir := b.TempDir()
f, err := os.Create(filepath.Join(dir, "write-single"))
if err != nil {
b.Fatal(err)
}
defer f.Close()
data := make([]byte, chunkSize)
b.SetBytes(int64(chunkSize))
r := bytes.NewReader(data)
for b.Loop() {
r.Reset(data)
_, err := io.Copy(f, r)
if err != nil {
b.Fatal(err)
}
}
}

View File

@@ -1,6 +1,5 @@
// Package registry provides an http.Handler for handling local Ollama API
// requests for performing tasks related to the ollama.com model registry and
// the local disk cache.
// Package registry implements an http.Handler for handling local Ollama API
// model management requests. See [Local] for details.
package registry
import (
@@ -10,6 +9,7 @@ import (
"fmt"
"io"
"log/slog"
"maps"
"net/http"
"sync"
"time"
@@ -18,16 +18,11 @@ import (
"github.com/ollama/ollama/server/internal/client/ollama"
)
// Local is an http.Handler for handling local Ollama API requests for
// performing tasks related to the ollama.com model registry combined with the
// local disk cache.
// Local implements an http.Handler for handling local Ollama API model
// management requests, such as pushing, pulling, and deleting models.
//
// It is not a concern of Local, or this package, to handle model creation, which
// precedes any registry operations for models it produces.
//
// NOTE: The package built for dealing with model creation should use
// [DefaultCache] to access the blob store and not attempt to read or write
// directly to the blob disk cache.
// It can be arranged for all unknown requests to be passed through to a
// fallback handler, if one is provided.
type Local struct {
Client *ollama.Registry // required
Logger *slog.Logger // required
@@ -63,6 +58,7 @@ func (e serverError) Error() string {
var (
errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
errNotFound = &serverError{404, "not_found", "not found"}
errModelNotFound = &serverError{404, "not_found", "model not found"}
errInternalError = &serverError{500, "internal_error", "internal server error"}
)
@@ -175,8 +171,16 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
}
type params struct {
DeprecatedName string `json:"name"` // Use [params.model]
Model string `json:"model"` // Use [params.model]
// DeprecatedName is the name of the model to push, pull, or delete,
// but is deprecated. New clients should use [Model] instead.
//
// Use [model()] to get the model name for both old and new API requests.
DeprecatedName string `json:"name"`
// Model is the name of the model to push, pull, or delete.
//
// Use [model()] to get the model name for both old and new API requests.
Model string `json:"model"`
// AllowNonTLS is a flag that indicates a client using HTTP
// is doing so, deliberately.
@@ -189,9 +193,18 @@ type params struct {
// confusing flags such as this.
AllowNonTLS bool `json:"insecure"`
// ProgressStream is a flag that indicates the client is expecting a stream of
// progress updates.
ProgressStream bool `json:"stream"`
// Stream, if true, will make the server send progress updates in a
// stream of JSON objects. If false, the server will send a single
// JSON object with the final status as "success", or an error object
// if an error occurred.
//
// Unfortunately, this API was designed to be a bit awkward. Stream is
// defined to default to true if not present, so we need a way to check
// if the client explicitly set it to false. So, we use a pointer to a
// bool. Gross.
//
// Use [stream()] to get the correct value for this field.
Stream *bool `json:"stream"`
}
// model returns the model name for both old and new API requests.
@@ -199,6 +212,13 @@ func (p params) model() string {
return cmp.Or(p.Model, p.DeprecatedName)
}
func (p params) stream() bool {
if p.Stream == nil {
return true
}
return *p.Stream
}
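The pointer is what makes stream() possible: encoding/json leaves a *bool nil when the key is absent, which is the only way to tell an omitted "stream" apart from an explicit false. A minimal sketch of that behavior:

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	var absent, explicit struct {
		Stream *bool `json:"stream"`
	}
	json.Unmarshal([]byte(`{}`), &absent)                 // key missing
	json.Unmarshal([]byte(`{"stream":false}`), &explicit) // set to false
	fmt.Println(absent.Stream == nil) // true  -> default to streaming
	fmt.Println(*explicit.Stream)     // false -> single final JSON object
}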
func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
if r.Method != "DELETE" {
return errMethodNotAllowed
@@ -212,16 +232,16 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
return err
}
if !ok {
return &serverError{404, "not_found", "model not found"}
return errModelNotFound
}
if s.Prune == nil {
return nil
if s.Prune != nil {
return s.Prune()
}
return s.Prune()
return nil
}
type progressUpdateJSON struct {
Status string `json:"status"`
Status string `json:"status,omitempty,omitzero"`
Digest blob.Digest `json:"digest,omitempty,omitzero"`
Total int64 `json:"total,omitempty,omitzero"`
Completed int64 `json:"completed,omitempty,omitzero"`
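The added omitzero option (new to encoding/json in Go 1.24) is what lets the struct-typed Digest field drop out of the output: omitempty never omits a zero-valued struct, while omitzero does. A small sketch of the difference, using time.Time in place of blob.Digest:

package main

import (
	"encoding/json"
	"fmt"
	"time"
)

type update struct {
	A time.Time `json:"a,omitempty"` // zero-valued struct still encoded
	B time.Time `json:"b,omitzero"`  // omitted when IsZero() reports true
}

func main() {
	out, _ := json.Marshal(update{})
	fmt.Println(string(out)) // {"a":"0001-01-01T00:00:00Z"}
}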
@@ -237,6 +257,17 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
return err
}
enc := json.NewEncoder(w)
if !p.stream() {
if err := s.Client.Pull(r.Context(), p.model()); err != nil {
if errors.Is(err, ollama.ErrModelNotFound) {
return errModelNotFound
}
return err
}
return enc.Encode(progressUpdateJSON{Status: "success"})
}
maybeFlush := func() {
fl, _ := w.(http.Flusher)
if fl != nil {
@@ -246,69 +277,74 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
defer maybeFlush()
var mu sync.Mutex
enc := json.NewEncoder(w)
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
progress := make(map[*ollama.Layer]int64)
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
mu.Lock()
defer mu.Unlock()
progressCopy := make(map[*ollama.Layer]int64, len(progress))
flushProgress := func() {
defer maybeFlush()
// TODO(bmizerany): coalesce these updates; writing per
// update is expensive
// TODO(bmizerany): Flushing every layer in one update doesn't
// scale well. We could flush only the modified layers or track
// the full download. Needs further consideration, though it's
// fine for now.
mu.Lock()
maps.Copy(progressCopy, progress)
mu.Unlock()
for l, n := range progressCopy {
enc.Encode(progressUpdateJSON{
Digest: l.Digest,
Status: "pulling",
Total: l.Size,
Completed: n,
})
}
}
defer flushProgress()
t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
start := sync.OnceFunc(func() {
flushProgress() // flush initial state
t.Reset(100 * time.Millisecond)
})
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
if n > 0 {
// Block flushing progress updates until every
// layer is accounted for. Clients depend on a
// complete model size to calculate progress
// correctly; if they use an incomplete total,
// progress indicators would erratically jump
// as new layers are registered.
start()
}
mu.Lock()
progress[l] += n
mu.Unlock()
},
})
done := make(chan error, 1)
go func() {
// TODO(bmizerany): continue to support non-streaming responses
done <- s.Client.Pull(ctx, p.model())
}()
func() {
t := time.NewTicker(100 * time.Millisecond)
defer t.Stop()
for {
select {
case <-t.C:
mu.Lock()
maybeFlush()
mu.Unlock()
case err := <-done:
if err != nil {
var status string
if errors.Is(err, ollama.ErrModelNotFound) {
status = fmt.Sprintf("error: model %q not found", p.model())
enc.Encode(progressUpdateJSON{Status: status})
} else {
status = fmt.Sprintf("error: %v", err)
enc.Encode(progressUpdateJSON{Status: status})
}
return
for {
select {
case <-t.C:
flushProgress()
case err := <-done:
flushProgress()
if err != nil {
var status string
if errors.Is(err, ollama.ErrModelNotFound) {
status = fmt.Sprintf("error: model %q not found", p.model())
} else {
status = fmt.Sprintf("error: %v", err)
}
// These final updates are not strictly necessary, because they have
// already happened at this point. Our pull handler code used to do
// these steps after, not during, the pull, and they were slow, so we
// wanted to provide feedback to users about what was happening. For now, we
// keep them to not jar users who are used to seeing them. We can phase
// them out with a new and nicer UX later. One without progress bars
// and digests that no one cares about.
enc.Encode(progressUpdateJSON{Status: "verifying layers"})
enc.Encode(progressUpdateJSON{Status: "writing manifest"})
enc.Encode(progressUpdateJSON{Status: "success"})
return
enc.Encode(progressUpdateJSON{Status: status})
}
return nil
}
}()
return nil
}
}
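Two idioms in the handler above are easy to miss: a ticker created with an absurdly long interval acts as a paused ticker, and sync.OnceFunc arms it exactly once, from whichever goroutine reports the first byte of progress. Isolated into a sketch that assumes nothing beyond the standard library:

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	t := time.NewTicker(1000 * time.Hour) // effectively "unstarted"
	defer t.Stop()

	start := sync.OnceFunc(func() {
		t.Reset(100 * time.Millisecond) // arm on first real progress
	})

	go func() {
		time.Sleep(50 * time.Millisecond)
		start() // safe to call concurrently; the body runs once
	}()

	<-t.C // fires ~100ms after start(), not 1000h after creation
	fmt.Println("tick")
}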
func decodeUserJSON[T any](r io.Reader) (T, error) {

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/fs"
"net"
@@ -160,7 +159,6 @@ var registryFS = sync.OnceValue(func() fs.FS {
// to \n when parsing the txtar on Windows.
data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
a := txtar.Parse(data)
fmt.Printf("%q\n", a.Comment)
fsys, err := txtar.FS(a)
if err != nil {
panic(err)
@@ -179,7 +177,7 @@ func TestServerPull(t *testing.T) {
w.WriteHeader(404)
io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
default:
t.Logf("serving file: %s", r.URL.Path)
t.Logf("serving blob: %s", r.URL.Path)
modelsHandler.ServeHTTP(w, r)
}
})
@@ -188,7 +186,7 @@ func TestServerPull(t *testing.T) {
t.Helper()
if got.Code != 200 {
t.Fatalf("Code = %d; want 200", got.Code)
t.Errorf("Code = %d; want 200", got.Code)
}
gotlines := got.Body.String()
t.Logf("got:\n%s", gotlines)
@@ -197,35 +195,29 @@ func TestServerPull(t *testing.T) {
want, unwanted := strings.CutPrefix(want, "!")
want = strings.TrimSpace(want)
if !unwanted && !strings.Contains(gotlines, want) {
t.Fatalf("! missing %q in body", want)
t.Errorf("! missing %q in body", want)
}
if unwanted && strings.Contains(gotlines, want) {
t.Fatalf("! unexpected %q in body", want)
t.Errorf("! unexpected %q in body", want)
}
}
}
got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
`)
got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
checkResponse(got, `
{"status":"pulling manifest"}
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
{"status":"verifying layers"}
{"status":"writing manifest"}
{"status":"success"}
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
`)
got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: model \"unknown\" not found"}
`)
@@ -240,19 +232,39 @@ func TestServerPull(t *testing.T) {
got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: invalid or missing name: \"\""}
!verifying
!writing
!success
`)
// Non-streaming pulls
got = s.send(t, "POST", "/api/pull", `{"model": "://", "stream": false}`)
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
got = s.send(t, "POST", "/api/pull", `{"model": "smol", "stream": false}`)
checkResponse(got, `
{"status":"success"}
!digest
!total
!completed
`)
got = s.send(t, "POST", "/api/pull", `{"model": "unknown", "stream": false}`)
checkErrorResponse(t, got, 404, "not_found", "model not found")
}
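In client terms, the non-streaming cases above mean a pull with "stream": false blocks until completion and answers with exactly one JSON object. A rough client sketch against the default local port; the host and model name are assumptions, and the response shapes follow the expectations in this test:

package main

import (
	"io"
	"net/http"
	"os"
	"strings"
)

func main() {
	// 200 -> {"status":"success"}; an unknown model -> a 404 error object.
	resp, err := http.Post("http://localhost:11434/api/pull",
		"application/json",
		strings.NewReader(`{"model": "smol", "stream": false}`))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	io.Copy(os.Stdout, resp.Body)
}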
func TestServerUnknownPath(t *testing.T) {
s := newTestServer(t, nil)
got := s.send(t, "DELETE", "/api/unknown", `{}`)
checkErrorResponse(t, got, 404, "not_found", "not found")
var fellback bool
s.Fallback = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fellback = true
})
got = s.send(t, "DELETE", "/api/unknown", `{}`)
if !fellback {
t.Fatal("expected Fallback to be called")
}
if got.Code != 200 {
t.Fatalf("Code = %d; want 200", got.Code)
}
}
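For reference, wiring the fallback this test exercises might look like the following sketch. The Local field names match this diff; the nil client stands in for a properly constructed *ollama.Registry:

package main

import (
	"log"
	"log/slog"
	"net/http"

	"github.com/ollama/ollama/server/internal/client/ollama"
	"github.com/ollama/ollama/server/internal/registry"
)

func main() {
	var client *ollama.Registry  // assumed initialized elsewhere
	legacy := http.NewServeMux() // pre-existing routes keep working
	local := &registry.Local{
		Client:   client,
		Logger:   slog.Default(),
		Fallback: legacy, // unknown requests are passed through here
	}
	log.Fatal(http.ListenAndServe("127.0.0.1:11434", local))
}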
func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {

View File

@@ -82,7 +82,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
for _, layer := range layers {
if s := layer.GGML.KV().ChatTemplate(); s != "" {
if t, err := template.Named(s); err != nil {
slog.Debug("template detection", "error", err)
slog.Debug("template detection", "error", err, "template", s)
} else {
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
if err != nil {

View File

@@ -26,7 +26,6 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
var system []api.Message
isMllama := checkMllamaModelFamily(m)
isGemma3 := checkGemma3ModelFamily(m)
var imageNumTokens int
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@@ -41,7 +40,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n; i >= 0; i-- {
if (isMllama || isGemma3) && len(msgs[i].Images) > 1 {
if isMllama && len(msgs[i].Images) > 1 {
return "", nil, errTooManyImages
}
@@ -158,12 +157,3 @@ func checkMllamaModelFamily(m *Model) bool {
}
return false
}
func checkGemma3ModelFamily(m *Model) bool {
for _, arch := range m.Config.ModelFamilies {
if arch == "gemma3" {
return true
}
}
return false
}

View File

@@ -0,0 +1,13 @@
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 }}
{{- if eq .Role "user" }}<start_of_turn>user
{{- if and (eq $i 1) $.System }}
{{ $.System }}
{{ end }}
{{ .Content }}<end_of_turn>
{{ else if eq .Role "assistant" }}<start_of_turn>model
{{ .Content }}<end_of_turn>
{{ end }}
{{- if $last }}<start_of_turn>model
{{ end }}
{{- end }}

View File

@@ -0,0 +1,6 @@
{
"stop": [
"<end_of_turn>"
],
"temperature": 0.1
}

View File

@@ -87,6 +87,10 @@
"template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
"name": "gemma-instruct"
},
{
"template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n",
"name": "gemma3-instruct"
},
{
"template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"name": "llama3-instruct"

View File

@@ -0,0 +1,10 @@
<start_of_turn>user
You are a helpful assistant.
Hello, how are you?<end_of_turn>
<start_of_turn>model
I'm doing great. How can I help you today?<end_of_turn>
<start_of_turn>user
I'd like to show off how chat templating works!<end_of_turn>
<start_of_turn>model

View File

@@ -0,0 +1,4 @@
<start_of_turn>user
Hello, how are you?<end_of_turn>
<start_of_turn>model

View File

@@ -0,0 +1,8 @@
<start_of_turn>user
Hello, how are you?<end_of_turn>
<start_of_turn>model
I'm doing great. How can I help you today?<end_of_turn>
<start_of_turn>user
I'd like to show off how chat templating works!<end_of_turn>
<start_of_turn>model