Compare commits

..

1 Commits

Author SHA1 Message Date
Patrick Devine
73a1e99f8a logging: add a new customer logger and trace method
This change addresses over logging with debug in the SPM tokenizer by
adding a trace level to slog.
2025-03-13 16:10:59 -07:00
42 changed files with 1224 additions and 1304 deletions

View File

@@ -56,7 +56,7 @@
"name": "ROCm 6", "name": "ROCm 6",
"inherits": [ "ROCm" ], "inherits": [ "ROCm" ],
"cacheVariables": { "cacheVariables": {
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-" "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
} }
} }
], ],

View File

@@ -392,8 +392,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool) - [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration) - [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.) - [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
### Cloud ### Cloud

View File

@@ -757,132 +757,3 @@ func TestCreateHandler(t *testing.T) {
}) })
} }
} }
func TestNewCreateRequest(t *testing.T) {
tests := []struct {
name string
from string
opts runOptions
expected *api.CreateRequest
}{
{
"basic test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "",
Prompt: "You are a fun AI agent",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"parent model test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
},
},
{
"parent model as filepath test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "/some/file/like/etc/passwd",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"parent model as windows filepath test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "D:\\some\\file\\like\\etc\\passwd",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"options test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
Options: map[string]any{
"temperature": 1.0,
},
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
Parameters: map[string]any{
"temperature": 1.0,
},
},
},
{
"messages test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
System: "You are a fun AI agent",
Messages: []api.Message{
{
Role: "user",
Content: "hello there!",
},
{
Role: "assistant",
Content: "hello to you!",
},
},
WordWrap: true,
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
System: "You are a fun AI agent",
Messages: []api.Message{
{
Role: "user",
Content: "hello there!",
},
{
Role: "assistant",
Content: "hello to you!",
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := NewCreateRequest(tt.from, tt.opts)
if !cmp.Equal(actual, tt.expected) {
t.Errorf("expected output %#v, got %#v", tt.expected, actual)
}
})
}
}

View File

@@ -18,7 +18,6 @@ import (
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/readline" "github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
) )
type MultilineState int type MultilineState int
@@ -460,16 +459,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} }
func NewCreateRequest(name string, opts runOptions) *api.CreateRequest { func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
parentModel := opts.ParentModel
modelName := model.ParseName(parentModel)
if !modelName.IsValid() {
parentModel = ""
}
req := &api.CreateRequest{ req := &api.CreateRequest{
Model: name, Name: name,
From: cmp.Or(parentModel, opts.Model), From: cmp.Or(opts.ParentModel, opts.Model),
} }
if opts.System != "" { if opts.System != "" {

View File

@@ -187,13 +187,6 @@ cloudflared tunnel --url http://localhost:11434 --http-host-header="localhost:11
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`. Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
For browser extensions, you'll need to explicitly allow the extension's origin pattern. Set `OLLAMA_ORIGINS` to include `chrome-extension://*`, `moz-extension://*`, and `safari-web-extension://*` if you wish to allow all browser extensions access, or specific extensions as needed:
```
# Allow all Chrome, Firefox, and Safari extensions
OLLAMA_ORIGINS=chrome-extension://*,moz-extension://*,safari-web-extension://* ollama serve
```
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform. Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
## Where are models stored? ## Where are models stored?

View File

@@ -583,52 +583,39 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
} }
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) { func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
if llm.KV().Uint("vision.block_count") == 0 {
return
}
for name, layer := range llm.Tensors().GroupLayers() {
if name == "v" || strings.HasPrefix(name, "v.") {
for _, tensor := range layer {
weights += tensor.Size()
}
}
}
imageSize := uint64(llm.KV().Uint("vision.image_size"))
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
if patchSize == 0 {
slog.Warn("unknown patch size for vision model")
return
}
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
switch llm.KV().Architecture() { switch llm.KV().Architecture() {
case "mllama": case "mllama":
for _, layer := range llm.Tensors().GroupLayers()["v"] {
weights += layer.Size()
}
kv := func(n string) uint64 {
if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
return uint64(v)
}
return 0
}
imageSize := kv("image_size")
maxNumTiles := kv("max_num_tiles")
embeddingLength := kv("embedding_length")
headCount := kv("attention.head_count")
numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
numPaddedPatches := numPatches + 8 - (numPatches%8)%8 numPaddedPatches := numPatches + 8 - (numPatches%8)%8
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
graphSize = 4 * (8 + graphSize = 4 * (8 +
imageSize*imageSize*numChannels*maxNumTiles + imageSize*imageSize*kv("num_channels")*maxNumTiles +
embeddingLength*numPatches*maxNumTiles + embeddingLength*numPatches*maxNumTiles +
9*embeddingLength*numPaddedPatches*maxNumTiles + 9*embeddingLength*numPaddedPatches*maxNumTiles +
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount) numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
case "gemma3":
graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize +
numPatches*numPatches*headCount)
} }
return weights, graphSize return weights, graphSize
} }

View File

@@ -66,35 +66,6 @@ func TestIntegrationMllama(t *testing.T) {
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second) DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
} }
func TestIntegrationSplitBatch(t *testing.T) {
image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err)
req := api.GenerateRequest{
Model: "gemma3:4b",
// Fill up a chunk of the batch so the image will partially spill over into the next one
System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
Images: []api.ImageData{
image,
},
}
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
resp := "the ollam"
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model))
// llava models on CPU can be quite slow to start,
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
}
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA
AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6 AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6

View File

@@ -37,7 +37,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINICPM3, "minicpm3" }, { LLM_ARCH_MINICPM3, "minicpm3" },
{ LLM_ARCH_GEMMA, "gemma" }, { LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_XVERSE, "xverse" },
@@ -805,24 +804,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
}, },
}, },
{
LLM_ARCH_GEMMA3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{ {
LLM_ARCH_STARCODER2, LLM_ARCH_STARCODER2,
{ {

View File

@@ -41,7 +41,6 @@ enum llm_arch {
LLM_ARCH_MINICPM3, LLM_ARCH_MINICPM3,
LLM_ARCH_GEMMA, LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2, LLM_ARCH_GEMMA2,
LLM_ARCH_GEMMA3,
LLM_ARCH_STARCODER2, LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA, LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE, LLM_ARCH_XVERSE,

View File

@@ -878,9 +878,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN; default: type = LLM_TYPE_UNKNOWN;
} }
} break; } break;
case LLM_ARCH_GEMMA3:
{
} break;
case LLM_ARCH_STARCODER2: case LLM_ARCH_STARCODER2:
{ {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2540,9 +2537,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
} }
} break; } break;
case LLM_ARCH_GEMMA3:
{
} break;
case LLM_ARCH_STARCODER2: case LLM_ARCH_STARCODER2:
{ {
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4035,7 +4029,6 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
case LLM_ARCH_PHIMOE: case LLM_ARCH_PHIMOE:
case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA2:
case LLM_ARCH_GEMMA3:
case LLM_ARCH_STARCODER2: case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM: case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX: case LLM_ARCH_GPTNEOX:

View File

@@ -737,15 +737,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// This used to be a regex, but <regex> has an extreme cost to compile times. // This used to be a regex, but <regex> has an extreme cost to compile times.
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
// don't quantize vision stuff
quantize &= name.find("v.blk.") == std::string::npos;
quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos;
quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos;
quantize &= name.find("v.patch_embedding.weight") == std::string::npos;
quantize &= name.find("v.position_embedding.weight") == std::string::npos;
quantize &= name.find("v.post_layernorm.weight") == std::string::npos;
// quantize only 2D and 3D tensors (experts) // quantize only 2D and 3D tensors (experts)
quantize &= (ggml_n_dims(tensor) >= 2); quantize &= (ggml_n_dims(tensor) >= 2);

View File

@@ -1,113 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Patrick Devine <patrick@infrahq.com>
Date: Fri, 14 Mar 2025 16:33:23 -0700
Subject: [PATCH] gemma3 quantization
---
src/llama-arch.cpp | 19 +++++++++++++++++++
src/llama-arch.h | 1 +
src/llama-model.cpp | 7 +++++++
src/llama-quant.cpp | 9 +++++++++
4 files changed, 36 insertions(+)
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index b6f20286..b443fcd3 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -37,6 +37,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINICPM3, "minicpm3" },
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
+ { LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
@@ -804,6 +805,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
+ {
+ LLM_ARCH_GEMMA3,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+ },
+ },
{
LLM_ARCH_STARCODER2,
{
diff --git a/src/llama-arch.h b/src/llama-arch.h
index ec742224..aad92a5d 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -41,6 +41,7 @@ enum llm_arch {
LLM_ARCH_MINICPM3,
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
+ LLM_ARCH_GEMMA3,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index ab1a07d1..70183041 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_GEMMA3:
+ {
+ } break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
+ case LLM_ARCH_GEMMA3:
+ {
+ } break;
case LLM_ARCH_STARCODER2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
case LLM_ARCH_PHIMOE:
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
+ case LLM_ARCH_GEMMA3:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 6eb1da08..d2f3a510 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// This used to be a regex, but <regex> has an extreme cost to compile times.
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
+ // don't quantize vision stuff
+ quantize &= name.find("v.blk.") == std::string::npos;
+
+ quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos;
+ quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos;
+ quantize &= name.find("v.patch_embedding.weight") == std::string::npos;
+ quantize &= name.find("v.position_embedding.weight") == std::string::npos;
+ quantize &= name.find("v.post_layernorm.weight") == std::string::npos;
+
// quantize only 2D and 3D tensors (experts)
quantize &= (ggml_n_dims(tensor) >= 2);

View File

@@ -218,8 +218,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok { if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
layerSize = blk.Size() layerSize = blk.Size()
layerSize += kv / f.KV().BlockCount() layerSize += kv / f.KV().BlockCount()
memoryWeights += blk.Size()
} }
memoryWeights += layerSize
if opts.NumGPU >= 0 && layerCount >= opts.NumGPU { if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
// Stop allocating on GPU(s) once we hit the users target NumGPU // Stop allocating on GPU(s) once we hit the users target NumGPU
@@ -376,7 +376,7 @@ func (m MemoryEstimate) LogValue() slog.Value {
// memory of the weights // memory of the weights
"total", format.HumanBytes2(m.memoryWeights), "total", format.HumanBytes2(m.memoryWeights),
// memory of repeating layers // memory of repeating layers
"repeating", format.HumanBytes2(m.memoryWeights), "repeating", format.HumanBytes2(m.memoryWeights-m.memoryLayerOutput),
// memory of non-repeating layers // memory of non-repeating layers
"nonrepeating", format.HumanBytes2(m.memoryLayerOutput), "nonrepeating", format.HumanBytes2(m.memoryLayerOutput),
), ),

View File

@@ -402,7 +402,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal) s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
} }
slog.Info("starting llama server", "cmd", s.cmd) slog.Info("starting llama server", "cmd", s.cmd.String())
if envconfig.Debug() { if envconfig.Debug() {
filteredEnv := []string{} filteredEnv := []string{}
for _, ev := range s.cmd.Env { for _, ev := range s.cmd.Env {
@@ -470,7 +470,7 @@ const ( // iota is reset to 0
ServerStatusError ServerStatusError
) )
func (s ServerStatus) String() string { func (s ServerStatus) ToString() string {
switch s { switch s {
case ServerStatusReady: case ServerStatusReady:
return "llm server ready" return "llm server ready"
@@ -485,9 +485,12 @@ func (s ServerStatus) String() string {
} }
} }
type ServerStatusResponse struct { type ServerStatusResp struct {
Status ServerStatus `json:"status"` Status string `json:"status"`
Progress float32 `json:"progress"` SlotsIdle int `json:"slots_idle"`
SlotsProcessing int `json:"slots_processing"`
Error string `json:"error"`
Progress float32 `json:"progress"`
} }
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
@@ -499,7 +502,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
} }
if s.cmd.ProcessState.ExitCode() == -1 { if s.cmd.ProcessState.ExitCode() == -1 {
// Most likely a signal killed it, log some more details to try to help troubleshoot // Most likely a signal killed it, log some more details to try to help troubleshoot
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState) slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
} }
return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg) return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
} }
@@ -524,19 +527,21 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
return ServerStatusError, fmt.Errorf("read health request: %w", err) return ServerStatusError, fmt.Errorf("read health request: %w", err)
} }
var ssr ServerStatusResponse var status ServerStatusResp
if err := json.Unmarshal(body, &ssr); err != nil { if err := json.Unmarshal(body, &status); err != nil {
return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err) return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
} }
switch ssr.Status { switch status.Status {
case ServerStatusLoadingModel: case "ok":
s.loadProgress = ssr.Progress return ServerStatusReady, nil
return ssr.Status, nil case "no slot available":
case ServerStatusReady, ServerStatusNoSlotsAvailable: return ServerStatusNoSlotsAvailable, nil
return ssr.Status, nil case "loading model":
s.loadProgress = status.Progress
return ServerStatusLoadingModel, nil
default: default:
return ssr.Status, fmt.Errorf("server error: %+v", ssr) return ServerStatusError, fmt.Errorf("server error: %+v", status)
} }
} }
@@ -611,7 +616,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
status, _ := s.getServerStatus(ctx) status, _ := s.getServerStatus(ctx)
if lastStatus != status && status != ServerStatusReady { if lastStatus != status && status != ServerStatusReady {
// Only log on status changes // Only log on status changes
slog.Info("waiting for server to become available", "status", status) slog.Info("waiting for server to become available", "status", status.ToString())
} }
switch status { switch status {
case ServerStatusReady: case ServerStatusReady:
@@ -625,7 +630,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress)) slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
stallTimer = time.Now().Add(stallDuration) stallTimer = time.Now().Add(stallDuration)
} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 { } else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
slog.Debug("model load completed, waiting for server to become available", "status", status) slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
stallTimer = time.Now().Add(stallDuration) stallTimer = time.Now().Add(stallDuration)
fullyLoaded = true fullyLoaded = true
} }
@@ -666,26 +671,63 @@ type ImageData struct {
AspectRatioID int `json:"aspect_ratio_id"` AspectRatioID int `json:"aspect_ratio_id"`
} }
type completion struct {
Content string `json:"content"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Stop bool `json:"stop"`
StoppedLimit bool `json:"stopped_limit"`
Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
}
}
type CompletionRequest struct { type CompletionRequest struct {
Prompt string Prompt string
Format json.RawMessage Format json.RawMessage
Images []ImageData Images []ImageData
Options *api.Options Options *api.Options
Grammar string // set before sending the request to the subprocess
} }
type CompletionResponse struct { type CompletionResponse struct {
Content string `json:"content"` Content string
DoneReason string `json:"done_reason"` DoneReason string
Done bool `json:"done"` Done bool
PromptEvalCount int `json:"prompt_eval_count"` PromptEvalCount int
PromptEvalDuration time.Duration `json:"prompt_eval_duration"` PromptEvalDuration time.Duration
EvalCount int `json:"eval_count"` EvalCount int
EvalDuration time.Duration `json:"eval_duration"` EvalDuration time.Duration
} }
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
request := map[string]any{
"prompt": req.Prompt,
"stream": true,
"n_predict": req.Options.NumPredict,
"n_keep": req.Options.NumKeep,
"main_gpu": req.Options.MainGPU,
"temperature": req.Options.Temperature,
"top_k": req.Options.TopK,
"top_p": req.Options.TopP,
"min_p": req.Options.MinP,
"typical_p": req.Options.TypicalP,
"repeat_last_n": req.Options.RepeatLastN,
"repeat_penalty": req.Options.RepeatPenalty,
"presence_penalty": req.Options.PresencePenalty,
"frequency_penalty": req.Options.FrequencyPenalty,
"mirostat": req.Options.Mirostat,
"mirostat_tau": req.Options.MirostatTau,
"mirostat_eta": req.Options.MirostatEta,
"seed": req.Options.Seed,
"stop": req.Options.Stop,
"image_data": req.Images,
"cache_prompt": true,
}
if len(req.Format) > 0 { if len(req.Format) > 0 {
switch string(req.Format) { switch string(req.Format) {
case `null`, `""`: case `null`, `""`:
@@ -693,7 +735,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
// these as "not set". // these as "not set".
break break
case `"json"`: case `"json"`:
req.Grammar = grammarJSON request["grammar"] = grammarJSON
default: default:
if req.Format[0] != '{' { if req.Format[0] != '{' {
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format) return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
@@ -704,15 +746,10 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if g == nil { if g == nil {
return fmt.Errorf("invalid JSON schema in format") return fmt.Errorf("invalid JSON schema in format")
} }
req.Grammar = string(g) request["grammar"] = string(g)
} }
} }
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
}
if err := s.sem.Acquire(ctx, 1); err != nil { if err := s.sem.Acquire(ctx, 1); err != nil {
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection") slog.Info("aborting completion request due to client closing the connection")
@@ -733,7 +770,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if err != nil { if err != nil {
return err return err
} else if status != ServerStatusReady { } else if status != ServerStatusReady {
return fmt.Errorf("unexpected server status: %s", status) return fmt.Errorf("unexpected server status: %s", status.ToString())
} }
// Handling JSON marshaling with special characters unescaped. // Handling JSON marshaling with special characters unescaped.
@@ -741,7 +778,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
enc := json.NewEncoder(buffer) enc := json.NewEncoder(buffer)
enc.SetEscapeHTML(false) enc.SetEscapeHTML(false)
if err := enc.Encode(req); err != nil { if err := enc.Encode(request); err != nil {
return fmt.Errorf("failed to marshal data: %v", err) return fmt.Errorf("failed to marshal data: %v", err)
} }
@@ -792,7 +829,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
evt = line evt = line
} }
var c CompletionResponse var c completion
if err := json.Unmarshal(evt, &c); err != nil { if err := json.Unmarshal(evt, &c); err != nil {
return fmt.Errorf("error unmarshalling llm prediction response: %v", err) return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
} }
@@ -816,8 +853,20 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}) })
} }
if c.Done { if c.Stop {
fn(c) doneReason := "stop"
if c.StoppedLimit {
doneReason = "length"
}
fn(CompletionResponse{
Done: true,
DoneReason: doneReason,
PromptEvalCount: c.Timings.PromptN,
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
EvalCount: c.Timings.PredictedN,
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
})
return nil return nil
} }
} }
@@ -865,7 +914,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
if err != nil { if err != nil {
return nil, err return nil, err
} else if status != ServerStatusReady { } else if status != ServerStatusReady {
return nil, fmt.Errorf("unexpected server status: %s", status) return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
} }
data, err := json.Marshal(EmbeddingRequest{Content: input}) data, err := json.Marshal(EmbeddingRequest{Content: input})
@@ -1010,3 +1059,12 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
} }
return 0 return 0
} }
func parseDurationMs(ms float64) time.Duration {
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
if err != nil {
panic(err)
}
return dur
}

40
logging/log.go Normal file
View File

@@ -0,0 +1,40 @@
package logging
import (
"context"
"log/slog"
"os"
)
// LevelTrace is a custom log level below slog.LevelDebug, intended for
// very high-volume diagnostics (such as per-token tokenizer output)
// that would overwhelm ordinary debug logging.
const LevelTrace slog.Level = slog.LevelDebug - 4

// level is the minimum level emitted by loggers returned from NewLogger.
// It defaults to the slog zero value, slog.LevelInfo, so behavior is
// unchanged until SetLevel is called.
//
// NOTE: the previous implementation passed nil HandlerOptions, which
// pinned the handler at LevelInfo and made Trace/Debug output impossible
// to enable; routing the level through a shared LevelVar fixes that
// while keeping the same default.
var level = new(slog.LevelVar)

// SetLevel sets the minimum level for all loggers created by NewLogger,
// including loggers created before the call. Pass LevelTrace to enable
// trace output.
func SetLevel(l slog.Level) {
	level.Set(l)
}

// Logger wraps *slog.Logger to add a Trace method for the custom
// LevelTrace level.
type Logger struct {
	logger *slog.Logger
}

// NewLogger returns a Logger that writes text-formatted records to
// stdout. Records below the package level (default: info) are dropped.
func NewLogger() *Logger {
	handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level})
	return &Logger{
		logger: slog.New(handler),
	}
}

// Trace logs msg and key/value args at LevelTrace. It uses
// context.Background() because no request context is threaded through
// the call sites.
func (l *Logger) Trace(msg string, args ...any) {
	l.logger.Log(context.Background(), LevelTrace, msg, args...)
}

// Debug logs at slog.LevelDebug.
func (l *Logger) Debug(msg string, args ...any) {
	l.logger.Debug(msg, args...)
}

// Info logs at slog.LevelInfo.
func (l *Logger) Info(msg string, args ...any) {
	l.logger.Info(msg, args...)
}

// Warn logs at slog.LevelWarn.
func (l *Logger) Warn(msg string, args ...any) {
	l.logger.Warn(msg, args...)
}

// Error logs at slog.LevelError.
func (l *Logger) Error(msg string, args ...any) {
	l.logger.Error(msg, args...)
}

View File

@@ -312,19 +312,17 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
return fmt.Errorf("unassigned tensor: %s", t.Name) return fmt.Errorf("unassigned tensor: %s", t.Name)
} }
bts := C.malloc(C.size_t(t.Size())) bts := make([]byte, t.Size())
if bts == nil { n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
return errors.New("failed to allocate tensor buffer") if err != nil {
} return err
defer C.free(bts)
buf := unsafe.Slice((*byte)(bts), t.Size())
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf)
if err != nil || n != len(buf) {
return errors.New("read failed")
} }
C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size())) if n != len(bts) {
return errors.New("short read")
}
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
return nil return nil
}) })
} }
@@ -373,7 +371,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])), (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
C.int(len(schedBackends)), C.int(len(schedBackends)),
C.size_t(maxGraphNodes), C.size_t(maxGraphNodes),
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)), true,
), ),
input: deviceBufferTypes[input.d], input: deviceBufferTypes[input.d],
output: deviceBufferTypes[output.d], output: deviceBufferTypes[output.d],

View File

@@ -15,12 +15,6 @@ type Input struct {
// stored in Multimodal, used for caching and comparing // stored in Multimodal, used for caching and comparing
// equality. // equality.
MultimodalHash uint64 MultimodalHash uint64
// SameBatch forces the following number of tokens to be processed
// in a single batch, breaking and extending batches as needed.
// Useful for things like images that must be processed in one
// shot.
SameBatch int
} }
// MultimodalIndex is a multimodal element (such as an image) // MultimodalIndex is a multimodal element (such as an image)

View File

@@ -60,7 +60,7 @@ type MultimodalProcessor interface {
// This function is also responsible for updating MultimodalHash for any Multimodal // This function is also responsible for updating MultimodalHash for any Multimodal
// that is modified to ensure that there is a unique hash value that accurately // that is modified to ensure that there is a unique hash value that accurately
// represents the contents. // represents the contents.
PostTokenize([]input.Input) ([]input.Input, error) PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
} }
// Base implements the common fields and methods for all models // Base implements the common fields and methods for all models

View File

@@ -2,9 +2,10 @@ package gemma3
import ( import (
"bytes" "bytes"
"encoding/binary"
"hash/fnv"
"image" "image"
"math" "math"
"slices"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
@@ -111,23 +112,36 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return visionOutputs, nil return visionOutputs, nil
} }
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { type imageToken struct {
embedding ml.Tensor
index int
}
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
var result []input.Input var result []input.Input
fnvHash := fnv.New64a()
for _, inp := range inputs { for _, inp := range inputs {
if inp.Multimodal == nil { if inp.Multimodal == nil {
result = append(result, inp) result = append(result, inp)
} else { } else {
imageInputs := []input.Input{
{Token: 108}, // "\n\n"
{Token: 255999}, // "<start_of_image>""
}
result = append(result, imageInputs...)
// add image embeddings
inputMultimodal := inp.Multimodal.(ml.Tensor) inputMultimodal := inp.Multimodal.(ml.Tensor)
result = append(result, for i := range inputMultimodal.Dim(1) {
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n" fnvHash.Reset()
input.Input{Token: 255999}, // "<start_of_image>"" binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder fnvHash.Write([]byte{byte(i)})
)
// add image token placeholders imageToken := imageToken{embedding: inputMultimodal, index: i}
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...) result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
}
result = append(result, result = append(result,
input.Input{Token: 256000}, // <end_of_image> input.Input{Token: 256000}, // <end_of_image>

View File

@@ -171,20 +171,53 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual) return hiddenState.Add(ctx, residual)
} }
// setImageEmbeddings copies image-embedding columns into hiddenState at
// the token positions recorded in multimodal, and returns the list of
// destination indices so the caller can treat those positions specially.
//
// Consecutive image tokens that reference adjacent columns of the same
// embedding tensor are coalesced into a single View/Copy — extending the
// current run either forwards or backwards — to minimize the number of
// copy operations recorded on the context.
func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
	var embedding ml.Tensor     // tensor backing the current run of image tokens
	var src, dst, length int    // run start (src: embedding column, dst: hiddenState position) and run length
	var except []int

	for _, image := range multimodal {
		imageToken := image.Multimodal.(imageToken)
		imageSrc := imageToken.index
		imageDst := image.Index

		if embedding == nil {
			// First image token: start a new run.
			embedding = imageToken.embedding
			src = imageSrc
			dst = imageDst
			length = 1
		} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
			// Token sits immediately before the current run on both the
			// source and destination side: extend the run backwards.
			src = imageSrc
			dst = imageDst
			length++
		} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
			// Token sits immediately after the current run on both sides:
			// extend the run forwards.
			length++
		} else {
			// Non-contiguous token or a different tensor: flush the
			// accumulated run with one copy, then start a new run.
			visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
			ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))

			embedding = imageToken.embedding
			src = imageSrc
			dst = imageDst
			length = 1
		}

		except = append(except, imageDst)
	}

	// Flush the final pending run, if any image tokens were seen.
	if embedding != nil {
		visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
	}

	return except
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor { func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs) hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize))) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
// set image embeddings except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
var except []int
for _, image := range opts.Multimodal {
visionOutputs := image.Multimodal.(ml.Tensor)
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
for i := range visionOutputs.Dim(1) {
except = append(except, image.Index+i)
}
}
for i, layer := range m.Layers { for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global) // gemma alternates between the sliding window (local) and causal (global)

View File

@@ -106,17 +106,17 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return m.Projector.Forward(ctx, crossAttentionStates), nil return m.Projector.Forward(ctx, crossAttentionStates), nil
} }
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
var images []input.Input var images []input.Input
fnvHash := fnv.New64a() fnvHash := fnv.New64a()
for i := range inputs { for i := range inputs {
if inputs[i].Multimodal == nil { if inputs[i].Multimodal == nil {
if len(images) > 0 { if len(images) > 0 {
inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)} inputs[i].Multimodal = images[0].Multimodal
inputs[i].MultimodalHash = images[0].MultimodalHash inputs[i].MultimodalHash = images[0].MultimodalHash
for j := 1; j < len(images); j++ { for j := 1; j < len(images); j++ {
inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[0].Multimodal.(ml.Tensor)) inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
fnvHash.Reset() fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash) binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash) binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
@@ -138,10 +138,7 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor var crossAttentionStates ml.Tensor
if len(opts.Multimodal) > 0 { if len(opts.Multimodal) > 0 {
images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor) crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
if len(images) > 0 {
crossAttentionStates = images[len(images)-1]
}
} }
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))

View File

@@ -2,15 +2,18 @@ package model
import ( import (
"iter" "iter"
"log/slog"
"strings" "strings"
"github.com/dlclark/regexp2" "github.com/dlclark/regexp2"
queue "github.com/emirpasic/gods/v2/queues/priorityqueue" queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
"github.com/ollama/ollama/logging"
) )
const spmWhitespaceSep = "▁" const spmWhitespaceSep = "▁"
var log = logging.NewLogger()
func replaceWhitespaceBySeperator(s string) string { func replaceWhitespaceBySeperator(s string) string {
return strings.ReplaceAll(s, " ", spmWhitespaceSep) return strings.ReplaceAll(s, " ", spmWhitespaceSep)
} }
@@ -24,7 +27,7 @@ type SentencePieceModel struct {
var _ TextProcessor = (*SentencePieceModel)(nil) var _ TextProcessor = (*SentencePieceModel)(nil)
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel { func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5]) log.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{} counter := map[int]int{}
var maxTokenLen int var maxTokenLen int
@@ -38,7 +41,7 @@ func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
} }
} }
slog.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL], log.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE], "user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen) "max token len", maxTokenLen)
@@ -91,7 +94,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...) fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
} }
} }
slog.Debug("fragments", "frags", fragments) log.Trace("fragments", "frags", fragments)
var ids []int32 var ids []int32
for _, frag := range fragments { for _, frag := range fragments {
@@ -129,7 +132,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
} }
} }
slog.Debug("tokenizer", "merges", merges) log.Trace("tokenizer", "merges", merges)
pairwise := func(a, b int) *candidate { pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) { if a < 0 || b >= len(runes) {
@@ -156,7 +159,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
pqv := pq.Values() pqv := pq.Values()
for _, v := range pqv { for _, v := range pqv {
e := v.(*candidate) e := v.(*candidate)
slog.Debug("candidate", "candidate", e) log.Trace("candidate", "candidate", e)
} }
for !pq.Empty() { for !pq.Empty() {
@@ -164,7 +167,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
pair := v.(*candidate) pair := v.(*candidate)
left, right := merges[pair.a], merges[pair.b] left, right := merges[pair.a], merges[pair.b]
slog.Debug("pair", "left", left, "right", right) log.Trace("pair", "left", left, "right", right)
if len(left.runes) == 0 || len(right.runes) == 0 { if len(left.runes) == 0 || len(right.runes) == 0 {
continue continue
} }
@@ -189,14 +192,14 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
} }
} }
slog.Debug("merges", "merges", merges) log.Trace("merges", "merges", merges)
for _, merge := range merges { for _, merge := range merges {
if len(merge.runes) > 0 { if len(merge.runes) > 0 {
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 { if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id) ids = append(ids, id)
} else { } else {
slog.Debug("missing token", "token", string(merge.runes)) log.Error("missing token", "token", string(merge.runes))
} }
} }
} }
@@ -206,19 +209,19 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
if addSpecial && len(ids) > 0 { if addSpecial && len(ids) > 0 {
if spm.vocab.AddBOS { if spm.vocab.AddBOS {
if ids[0] == spm.vocab.BOS { if ids[0] == spm.vocab.BOS {
slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS) log.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
} }
slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS) log.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
ids = append([]int32{spm.vocab.BOS}, ids...) ids = append([]int32{spm.vocab.BOS}, ids...)
} }
if spm.vocab.AddEOS { if spm.vocab.AddEOS {
if ids[len(ids)-1] == spm.vocab.EOS { if ids[len(ids)-1] == spm.vocab.EOS {
slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS) log.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
} }
slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS) log.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
ids = append(ids, spm.vocab.EOS) ids = append(ids, spm.vocab.EOS)
} }
} }
@@ -241,6 +244,6 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
} }
} }
slog.Debug("decoded", "ids", ids, "text", sb.String()) log.Debug("decoded", "ids", ids, "text", sb.String())
return sb.String(), nil return sb.String(), nil
} }

View File

@@ -24,7 +24,6 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llama" "github.com/ollama/ollama/llama"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/runner/common" "github.com/ollama/ollama/runner/common"
) )
@@ -100,7 +99,7 @@ type NewSequenceParams struct {
embedding bool embedding bool
} }
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) { func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait() s.ready.Wait()
startTime := time.Now() startTime := time.Now()
@@ -164,7 +163,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// inputs processes the prompt and images into a list of inputs // inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and // by splitting the prompt on [img-<n>] tags, tokenizing text and
// generating image embeddings for each image // generating image embeddings for each image
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) { func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
var inputs []input var inputs []input
var parts []string var parts []string
var matches [][]string var matches [][]string
@@ -230,7 +229,7 @@ type Server struct {
image *ImageContext image *ImageContext
// status for external health reporting - loading, ready to serve, etc. // status for external health reporting - loading, ready to serve, etc.
status llm.ServerStatus status ServerStatus
// current progress on loading the model // current progress on loading the model
progress float32 progress float32
@@ -542,18 +541,75 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
return nil return nil
} }
// TODO (jmorganca): use structs from the api package to avoid duplication
// this way the api acts as a proxy instead of using a different api for the
// runner

// Options holds the per-request generation settings decoded from the
// request body. Most sampling fields are copied directly into
// llama.SamplingParams by the completion handler; NumPredict, Stop, and
// NumKeep are passed to NewSequence to control generation length, stop
// strings, and how many inputs survive a context shift. The JSON tags
// match the llama.cpp server option names.
type Options struct {
	api.Runner

	NumKeep          int      `json:"n_keep"`
	Seed             int      `json:"seed"`
	NumPredict       int      `json:"n_predict"`
	TopK             int      `json:"top_k"`
	TopP             float32  `json:"top_p"`
	MinP             float32  `json:"min_p"`
	TypicalP         float32  `json:"typical_p"`
	RepeatLastN      int      `json:"repeat_last_n"`
	Temperature      float32  `json:"temperature"`
	RepeatPenalty    float32  `json:"repeat_penalty"`
	PresencePenalty  float32  `json:"presence_penalty"`
	FrequencyPenalty float32  `json:"frequency_penalty"`
	Mirostat         int      `json:"mirostat"`
	MirostatTau      float32  `json:"mirostat_tau"`
	MirostatEta      float32  `json:"mirostat_eta"`
	Stop             []string `json:"stop"`
}
// ImageData is one image attached to a completion request; the prompt
// references it by ID using an [img-<ID>] tag (see inputs()).
type ImageData struct {
	Data          []byte `json:"data"` // raw image payload bytes
	ID            int    `json:"id"`   // identifier matched against [img-<n>] tags in the prompt
	AspectRatioID int    `json:"aspect_ratio_id"` // aspect-ratio bucket; presumably consumed by image preprocessing — TODO confirm
}
// CompletionRequest is the JSON body accepted by the /completion endpoint.
// Embedded Options supplies the sampling parameters.
type CompletionRequest struct {
	Prompt      string      `json:"prompt"`       // prompt text, may contain [img-<n>] tags referencing Images
	Images      []ImageData `json:"image_data"`   // images referenced from the prompt
	Grammar     string      `json:"grammar"`      // optional grammar, forwarded to the sampler
	CachePrompt bool        `json:"cache_prompt"` // allow reuse of a cached prompt prefix (passed to LoadCacheSlot)
	Options
}
// Timings reports prompt-processing and token-generation statistics for
// a completed request. Durations are wall-clock milliseconds.
type Timings struct {
	PredictedN  int     `json:"predicted_n"`  // number of tokens generated
	PredictedMS float64 `json:"predicted_ms"` // generation time in milliseconds
	PromptN     int     `json:"prompt_n"`     // number of prompt inputs processed
	PromptMS    float64 `json:"prompt_ms"`    // prompt processing time in milliseconds
}
// CompletionResponse is one streamed JSON chunk from the /completion
// endpoint: intermediate chunks carry Content, and the final chunk has
// Stop set along with Timings.
type CompletionResponse struct {
	Content string `json:"content"` // newly generated text for this chunk
	Stop    bool   `json:"stop"`    // true on the final chunk of the stream
	// Model and Prompt, and the flat timing fields below, are not
	// populated by the handlers visible here — presumably retained for
	// llama.cpp server API compatibility; TODO confirm.
	Model        string  `json:"model,omitempty"`
	Prompt       string  `json:"prompt,omitempty"`
	StoppedLimit bool    `json:"stopped_limit,omitempty"` // true when generation ended at the predict limit rather than a stop condition
	PredictedN   int     `json:"predicted_n,omitempty"`
	PredictedMS  float64 `json:"predicted_ms,omitempty"`
	PromptN      int     `json:"prompt_n,omitempty"`
	PromptMS     float64 `json:"prompt_ms,omitempty"`
	Timings      Timings `json:"timings"` // populated only on the final chunk
}
func (s *Server) completion(w http.ResponseWriter, r *http.Request) { func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var req llm.CompletionRequest var req CompletionRequest
req.Options = Options(api.DefaultOptions())
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest) http.Error(w, "Bad request", http.StatusBadRequest)
return return
} }
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
}
// Set the headers to indicate streaming // Set the headers to indicate streaming
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Header().Set("Transfer-Encoding", "chunked") w.Header().Set("Transfer-Encoding", "chunked")
@@ -564,28 +620,26 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
} }
// Extract options from the CompletionRequest var samplingParams llama.SamplingParams
samplingParams := llama.SamplingParams{ samplingParams.TopK = req.TopK
TopK: req.Options.TopK, samplingParams.TopP = req.TopP
TopP: req.Options.TopP, samplingParams.MinP = req.MinP
MinP: req.Options.MinP, samplingParams.TypicalP = req.TypicalP
TypicalP: req.Options.TypicalP, samplingParams.Temp = req.Temperature
Temp: req.Options.Temperature, samplingParams.RepeatLastN = req.RepeatLastN
RepeatLastN: req.Options.RepeatLastN, samplingParams.PenaltyRepeat = req.RepeatPenalty
PenaltyRepeat: req.Options.RepeatPenalty, samplingParams.PenaltyFreq = req.FrequencyPenalty
PenaltyFreq: req.Options.FrequencyPenalty, samplingParams.PenaltyPresent = req.PresencePenalty
PenaltyPresent: req.Options.PresencePenalty, samplingParams.Mirostat = req.Mirostat
Mirostat: req.Options.Mirostat, samplingParams.MirostatTau = req.MirostatTau
MirostatTau: req.Options.MirostatTau, samplingParams.MirostatEta = req.MirostatEta
MirostatEta: req.Options.MirostatEta, samplingParams.Seed = uint32(req.Seed)
Seed: uint32(req.Options.Seed), samplingParams.Grammar = req.Grammar
Grammar: req.Grammar,
}
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.Options.NumPredict, numPredict: req.NumPredict,
stop: req.Options.Stop, stop: req.Stop,
numKeep: req.Options.NumKeep, numKeep: req.NumKeep,
samplingParams: &samplingParams, samplingParams: &samplingParams,
embedding: false, embedding: false,
}) })
@@ -608,7 +662,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
found := false found := false
for i, sq := range s.seqs { for i, sq := range s.seqs {
if sq == nil { if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true) seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil { if err != nil {
s.mu.Unlock() s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -637,7 +691,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
case content, ok := <-seq.responses: case content, ok := <-seq.responses:
if ok { if ok {
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ if err := json.NewEncoder(w).Encode(&CompletionResponse{
Content: content, Content: content,
}); err != nil { }); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -648,17 +702,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush() flusher.Flush()
} else { } else {
// Send the final response // Send the final response
doneReason := "stop" if err := json.NewEncoder(w).Encode(&CompletionResponse{
if seq.doneReason == "limit" { Stop: true,
doneReason = "length" StoppedLimit: seq.doneReason == "limit",
} Timings: Timings{
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ PromptN: seq.numPromptInputs,
Done: true, PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
DoneReason: doneReason, PredictedN: seq.numDecoded,
PromptEvalCount: seq.numPromptInputs, PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime), },
EvalCount: seq.numDecoded,
EvalDuration: time.Since(seq.startGenerationTime),
}); err != nil { }); err != nil {
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
} }
@@ -669,8 +721,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
} }
} }
// EmbeddingRequest is the JSON body accepted by the /embedding endpoint.
type EmbeddingRequest struct {
	Content string `json:"content"` // text to embed
	// CachePrompt controls whether a previously cached prompt prefix may
	// be reused for this request (passed to LoadCacheSlot).
	CachePrompt bool `json:"cache_prompt"`
}

// EmbeddingResponse is the JSON body returned by the /embedding endpoint.
type EmbeddingResponse struct {
	Embedding []float32 `json:"embedding"`
}
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
var req llm.EmbeddingRequest var req EmbeddingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest) http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
return return
@@ -700,7 +761,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
found := false found := false
for i, sq := range s.seqs { for i, sq := range s.seqs {
if sq == nil { if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false) seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil { if err != nil {
s.mu.Unlock() s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -721,17 +782,41 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
embedding := <-seq.embedding embedding := <-seq.embedding
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{ if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
Embedding: embedding, Embedding: embedding,
}); err != nil { }); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
} }
} }
// HealthResponse is the JSON body returned by the /health endpoint.
type HealthResponse struct {
	Status   string  `json:"status"`   // human-readable state (see ServerStatus.ToString)
	Progress float32 `json:"progress"` // current model load progress
}
// ServerStatus reports the runner's lifecycle state, exposed through
// the /health endpoint.
type ServerStatus int

const (
	// ServerStatusReady means the model is loaded and requests can be served.
	ServerStatusReady ServerStatus = iota
	// ServerStatusLoadingModel means the model is still being loaded.
	ServerStatusLoadingModel
	// ServerStatusError means the runner encountered an error.
	ServerStatusError
)

// ToString returns the human-readable status string reported to
// clients. Any unrecognized value maps to "server error".
func (s ServerStatus) ToString() string {
	switch s {
	case ServerStatusReady:
		return "ok"
	case ServerStatusLoadingModel:
		return "loading model"
	default:
		return "server error"
	}
}

// String implements fmt.Stringer so a ServerStatus formats usefully
// with %v/%s. It is equivalent to ToString, which is retained for
// existing callers.
func (s ServerStatus) String() string {
	return s.ToString()
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) { func (s *Server) health(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{ if err := json.NewEncoder(w).Encode(&HealthResponse{
Status: s.status, Status: s.status.ToString(),
Progress: s.progress, Progress: s.progress,
}); err != nil { }); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -794,7 +879,7 @@ func (s *Server) loadModel(
panic(err) panic(err)
} }
s.status = llm.ServerStatusReady s.status = ServerStatusReady
s.ready.Done() s.ready.Done()
} }
@@ -852,7 +937,7 @@ func Execute(args []string) error {
parallel: *parallel, parallel: *parallel,
seqs: make([]*Sequence, *parallel), seqs: make([]*Sequence, *parallel),
seqsSem: semaphore.NewWeighted(int64(*parallel)), seqsSem: semaphore.NewWeighted(int64(*parallel)),
status: llm.ServerStatusLoadingModel, status: ServerStatusLoadingModel,
} }
var tensorSplitFloats []float32 var tensorSplitFloats []float32

View File

@@ -89,7 +89,7 @@ type InputCacheSlot struct {
lastUsed time.Time lastUsed time.Time
} }
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) { func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
var slot *InputCacheSlot var slot *InputCacheSlot
var numPast int32 var numPast int32
var err error var err error
@@ -107,6 +107,10 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
return nil, nil, err return nil, nil, err
} }
if !cachePrompt {
numPast = 0
}
slot.InUse = true slot.InUse = true
slot.lastUsed = time.Now() slot.lastUsed = time.Now()

View File

@@ -297,131 +297,3 @@ func TestShiftDiscard(t *testing.T) {
}) })
} }
} }
// TestLoadCacheSlot exercises InputCache.LoadCacheSlot across cache-hit
// (single- and multi-user), exact-match, and no-free-slot scenarios,
// verifying the chosen slot ID, that the slot is marked InUse, and how
// much of the prompt remains to be processed after the cached prefix is
// accounted for.
func TestLoadCacheSlot(t *testing.T) {
	tests := []struct {
		name           string
		cache          InputCache
		prompt         []input.Input
		wantErr        bool
		expectedSlotId int
		expectedPrompt int // expected length of remaining prompt
	}{
		{
			// Slot 0 already holds the first two tokens, so only the
			// third needs processing.
			name: "Basic cache hit - single user",
			cache: InputCache{
				multiUserCache: false,
				slots: []InputCacheSlot{
					{
						Id:       0,
						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
						InUse:    false,
						lastUsed: time.Now().Add(-time.Second),
					},
					{
						Id:       1,
						Inputs:   []input.Input{},
						InUse:    false,
						lastUsed: time.Now().Add(-2 * time.Second),
					},
				},
			},
			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
			wantErr:        false,
			expectedSlotId: 0,
			expectedPrompt: 1, // Only token 3 remains
		},
		{
			// Same layout as above but with the multi-user slot
			// selection policy enabled.
			name: "Basic cache hit - multi user",
			cache: InputCache{
				multiUserCache: true,
				slots: []InputCacheSlot{
					{
						Id:       0,
						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
						InUse:    false,
						lastUsed: time.Now().Add(-time.Second),
					},
					{
						Id:       1,
						Inputs:   []input.Input{},
						InUse:    false,
						lastUsed: time.Now().Add(-2 * time.Second),
					},
				},
			},
			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
			wantErr:        false,
			expectedSlotId: 0,
			expectedPrompt: 1, // Only token 3 remains
		},
		{
			// When the prompt exactly matches the cached inputs, one
			// input must still be left over so there is something to
			// sample from.
			name: "Exact match - leave one input",
			cache: InputCache{
				multiUserCache: false,
				slots: []InputCacheSlot{
					{
						Id:       0,
						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
						InUse:    false,
						lastUsed: time.Now().Add(-time.Second),
					},
				},
			},
			prompt:         []input.Input{{Token: 1}, {Token: 2}},
			wantErr:        false,
			expectedSlotId: 0,
			expectedPrompt: 1, // Should leave 1 token for sampling
		},
		{
			// Every slot is busy, so the load must fail.
			name: "No available slots",
			cache: InputCache{
				multiUserCache: false,
				slots: []InputCacheSlot{
					{
						Id:       0,
						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
						InUse:    true,
						lastUsed: time.Now().Add(-time.Second),
					},
				},
			},
			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
			wantErr:        true,
			expectedSlotId: -1,
			expectedPrompt: -1,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)

			// Check error state
			if (err != nil) != tt.wantErr {
				t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
				return
			}

			if tt.wantErr {
				return // Skip further checks if we expected an error
			}

			// Verify slot ID
			if slot.Id != tt.expectedSlotId {
				t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
			}

			// Verify slot is now marked in use
			if !slot.InUse {
				t.Errorf("LoadCacheSlot() slot not marked InUse")
			}

			// Verify remaining prompt length
			if len(remainingPrompt) != tt.expectedPrompt {
				t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
					len(remainingPrompt), tt.expectedPrompt)
			}
		})
	}
}

View File

@@ -24,7 +24,6 @@ import (
"golang.org/x/sync/semaphore" "golang.org/x/sync/semaphore"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input" "github.com/ollama/ollama/model/input"
@@ -34,14 +33,10 @@ import (
_ "github.com/ollama/ollama/model/models" _ "github.com/ollama/ollama/model/models"
) )
// contextList groups the ml.Contexts used to allocate tensors that last
// the lifetime of a sequence (see Sequence.ctxs), e.g. multimodal
// embeddings; presumably kept together so they can be released when the
// sequence completes — TODO confirm against the release path.
type contextList struct {
	list []ml.Context
}
type Sequence struct { type Sequence struct {
// ctxs are used for allocating tensors that last the lifetime of the sequence, such as // ctx for allocating tensors that last the lifetime of the sequence, such as
// multimodal embeddings // multimodal embeddings
ctxs *contextList ctx ml.Context
// batch index // batch index
iBatch int iBatch int
@@ -99,12 +94,13 @@ type NewSequenceParams struct {
embedding bool embedding bool
} }
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) { func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait() s.ready.Wait()
startTime := time.Now() startTime := time.Now()
ctx := s.model.Backend().NewContext()
inputs, ctxs, err := s.inputs(prompt, images) inputs, err := s.inputs(ctx, prompt, images)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err) return nil, fmt.Errorf("failed to process inputs: %w", err)
} else if len(inputs) == 0 { } else if len(inputs) == 0 {
@@ -115,9 +111,6 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
params.numKeep = int32(len(inputs)) params.numKeep = int32(len(inputs))
} }
// TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
// otherwise we might truncate or split the batch against the model's wishes
// Ensure that at least 1 input can be discarded during shift // Ensure that at least 1 input can be discarded during shift
params.numKeep = min(params.numKeep, s.cache.numCtx-1) params.numKeep = min(params.numKeep, s.cache.numCtx-1)
@@ -133,7 +126,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// TODO(jessegross): Ingest cached history for grammar // TODO(jessegross): Ingest cached history for grammar
return &Sequence{ return &Sequence{
ctxs: ctxs, ctx: ctx,
inputs: inputs, inputs: inputs,
numPromptInputs: len(inputs), numPromptInputs: len(inputs),
startProcessingTime: startTime, startProcessingTime: startTime,
@@ -152,7 +145,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// inputs processes the prompt and images into a list of inputs // inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and // by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images // decoding images
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) { func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
var inputs []input.Input var inputs []input.Input
var parts []string var parts []string
var matches [][]string var matches [][]string
@@ -167,19 +160,12 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
parts = []string{prompt} parts = []string{prompt}
} }
var contexts contextList
runtime.AddCleanup(&contexts, func(ctxs []ml.Context) {
for _, ctx := range ctxs {
ctx.Close()
}
}, contexts.list)
postTokenize := false postTokenize := false
for i, part := range parts { for i, part := range parts {
// text - tokenize // text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0) tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
for _, t := range tokens { for _, t := range tokens {
@@ -199,14 +185,12 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
} }
if imageIndex < 0 { if imageIndex < 0 {
return nil, nil, fmt.Errorf("invalid image index: %d", n) return nil, fmt.Errorf("invalid image index: %d", n)
} }
ctx := s.model.Backend().NewContext()
contexts.list = append(contexts.list, ctx)
imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data) imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
s.multimodalHash.Reset() s.multimodalHash.Reset()
@@ -220,13 +204,13 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
if visionModel && postTokenize { if visionModel && postTokenize {
var err error var err error
inputs, err = multimodalProcessor.PostTokenize(inputs) inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
} }
return inputs, &contexts, nil return inputs, nil
} }
type Server struct { type Server struct {
@@ -238,7 +222,7 @@ type Server struct {
model model.Model model model.Model
// status for external health reporting - loading, ready to serve, etc. // status for external health reporting - loading, ready to serve, etc.
status llm.ServerStatus status ServerStatus
// current progress on loading the model // current progress on loading the model
progress float32 progress float32
@@ -321,6 +305,7 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
close(seq.responses) close(seq.responses)
close(seq.embedding) close(seq.embedding)
seq.cache.InUse = false seq.cache.InUse = false
seq.ctx.Close()
s.seqs[seqIndex] = nil s.seqs[seqIndex] = nil
s.seqsSem.Release(1) s.seqsSem.Release(1)
} }
@@ -366,33 +351,20 @@ func (s *Server) processBatch() error {
seq.cache.Inputs = []input.Input{} seq.cache.Inputs = []input.Input{}
} }
batchSize := s.batchSize
for j, inp := range seq.inputs { for j, inp := range seq.inputs {
// If we are required to put following inputs into a single batch then extend the if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
// batch size. Since we are only extending the size the minimum amount possible, this if len(seq.pendingInputs) == 0 {
// will cause a break if we have pending inputs. err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
minBatch := 1 + inp.SameBatch if err != nil {
if minBatch > batchSize { return err
batchSize = minBatch }
} } else {
if len(seq.pendingInputs)+minBatch > batchSize {
break
}
// If the sum of our working set (already processed tokens, tokens we added to this
// batch, required following tokens) exceeds the context size, then trigger a shift
// now so we don't have to do one later when we can't break the batch.
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
if len(seq.pendingInputs) != 0 {
break break
} }
}
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) if j >= s.batchSize {
if err != nil { break
return err
}
} }
options.Inputs = append(options.Inputs, inp.Token) options.Inputs = append(options.Inputs, inp.Token)
@@ -529,18 +501,75 @@ func (s *Server) processBatch() error {
return nil return nil
} }
// TODO (jmorganca): use structs from the api package to avoid duplication
// this way the api acts as a proxy instead of using a different api for the
// runner
type Options struct {
api.Runner
NumKeep int `json:"n_keep"`
Seed int `json:"seed"`
NumPredict int `json:"n_predict"`
TopK int `json:"top_k"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TypicalP float32 `json:"typical_p"`
RepeatLastN int `json:"repeat_last_n"`
Temperature float32 `json:"temperature"`
RepeatPenalty float32 `json:"repeat_penalty"`
PresencePenalty float32 `json:"presence_penalty"`
FrequencyPenalty float32 `json:"frequency_penalty"`
Mirostat int `json:"mirostat"`
MirostatTau float32 `json:"mirostat_tau"`
MirostatEta float32 `json:"mirostat_eta"`
Stop []string `json:"stop"`
}
type ImageData struct {
Data []byte `json:"data"`
ID int `json:"id"`
AspectRatioID int `json:"aspect_ratio_id"`
}
type CompletionRequest struct {
Prompt string `json:"prompt"`
Images []ImageData `json:"image_data"`
Grammar string `json:"grammar"`
CachePrompt bool `json:"cache_prompt"`
Options
}
type Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
}
type CompletionResponse struct {
Content string `json:"content"`
Stop bool `json:"stop"`
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`
StoppedLimit bool `json:"stopped_limit,omitempty"`
PredictedN int `json:"predicted_n,omitempty"`
PredictedMS float64 `json:"predicted_ms,omitempty"`
PromptN int `json:"prompt_n,omitempty"`
PromptMS float64 `json:"prompt_ms,omitempty"`
Timings Timings `json:"timings"`
}
func (s *Server) completion(w http.ResponseWriter, r *http.Request) { func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var req llm.CompletionRequest var req CompletionRequest
req.Options = Options(api.DefaultOptions())
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest) http.Error(w, "Bad request", http.StatusBadRequest)
return return
} }
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
}
// Set the headers to indicate streaming // Set the headers to indicate streaming
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Header().Set("Transfer-Encoding", "chunked") w.Header().Set("Transfer-Encoding", "chunked")
@@ -562,18 +591,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
} }
sampler := sample.NewSampler( sampler := sample.NewSampler(
req.Options.Temperature, req.Temperature,
req.Options.TopK, req.TopK,
req.Options.TopP, req.TopP,
req.Options.MinP, req.MinP,
req.Options.Seed, req.Seed,
grammar, grammar,
) )
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.Options.NumPredict, numPredict: req.NumPredict,
stop: req.Options.Stop, stop: req.Stop,
numKeep: int32(req.Options.NumKeep), numKeep: int32(req.NumKeep),
sampler: sampler, sampler: sampler,
embedding: false, embedding: false,
}) })
@@ -596,7 +625,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
found := false found := false
for i, sq := range s.seqs { for i, sq := range s.seqs {
if sq == nil { if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs) seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil { if err != nil {
s.mu.Unlock() s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -623,7 +652,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
case content, ok := <-seq.responses: case content, ok := <-seq.responses:
if ok { if ok {
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ if err := json.NewEncoder(w).Encode(&CompletionResponse{
Content: content, Content: content,
}); err != nil { }); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -634,17 +663,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush() flusher.Flush()
} else { } else {
// Send the final response // Send the final response
doneReason := "stop" if err := json.NewEncoder(w).Encode(&CompletionResponse{
if seq.doneReason == "limit" { Stop: true,
doneReason = "length" StoppedLimit: seq.doneReason == "limit",
} Timings: Timings{
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ PromptN: seq.numPromptInputs,
Done: true, PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
DoneReason: doneReason, PredictedN: seq.numPredicted,
PromptEvalCount: seq.numPromptInputs, PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime), },
EvalCount: seq.numPredicted,
EvalDuration: time.Since(seq.startGenerationTime),
}); err != nil { }); err != nil {
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
} }
@@ -655,10 +682,43 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
} }
} }
type EmbeddingRequest struct {
Content string `json:"content"`
CachePrompt bool `json:"cache_prompt"`
}
type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"`
}
type HealthResponse struct {
Status string `json:"status"`
Progress float32 `json:"progress"`
}
type ServerStatus int
const (
ServerStatusReady ServerStatus = iota
ServerStatusLoadingModel
ServerStatusError
)
func (s ServerStatus) ToString() string {
switch s {
case ServerStatusReady:
return "ok"
case ServerStatusLoadingModel:
return "loading model"
default:
return "server error"
}
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) { func (s *Server) health(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{ if err := json.NewEncoder(w).Encode(&HealthResponse{
Status: s.status, Status: s.status.ToString(),
Progress: s.progress, Progress: s.progress,
}); err != nil { }); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -712,7 +772,7 @@ func (s *Server) loadModel(
s.seqs = make([]*Sequence, s.parallel) s.seqs = make([]*Sequence, s.parallel)
s.seqsSem = semaphore.NewWeighted(int64(s.parallel)) s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
s.status = llm.ServerStatusReady s.status = ServerStatusReady
s.ready.Done() s.ready.Done()
} }
@@ -764,7 +824,7 @@ func Execute(args []string) error {
server := &Server{ server := &Server{
batchSize: *batchSize, batchSize: *batchSize,
status: llm.ServerStatusLoadingModel, status: ServerStatusLoadingModel,
} }
// TODO(jessegross): Parameters that need to be implemented: // TODO(jessegross): Parameters that need to be implemented:

View File

@@ -87,9 +87,8 @@ func (s *Sampler) sample(tokens []token) (token, error) {
// topK also sorts the tokens in descending order of logits // topK also sorts the tokens in descending order of logits
tokens = topK(tokens, s.topK) tokens = topK(tokens, s.topK)
// scale and normalize the tokens in place tokens = temperature(tokens, s.temperature)
temperature(tokens, s.temperature) tokens = softmax(tokens)
softmax(tokens)
tokens = topP(tokens, s.topP) tokens = topP(tokens, s.topP)
tokens = minP(tokens, s.minP) tokens = minP(tokens, s.minP)

View File

@@ -26,16 +26,17 @@ func (h *tokenHeap) Pop() any {
} }
// temperature applies scaling to the logits // temperature applies scaling to the logits
func temperature(ts []token, temp float32) { func temperature(ts []token, temp float32) []token {
// Ensure temperature clipping near 0 to avoid numerical instability // Ensure temperature clipping near 0 to avoid numerical instability
temp = max(temp, 1e-7) temp = max(temp, 1e-7)
for i := range ts { for i := range ts {
ts[i].value = ts[i].value / temp ts[i].value = ts[i].value / temp
} }
return ts
} }
// softmax applies normalization to the logits // softmax applies normalization to the logits
func softmax(ts []token) { func softmax(ts []token) []token {
// Find max logit for numerical stability // Find max logit for numerical stability
maxLogit := float32(math.Inf(-1)) maxLogit := float32(math.Inf(-1))
for _, t := range ts { for _, t := range ts {
@@ -55,6 +56,8 @@ func softmax(ts []token) {
for i := range ts { for i := range ts {
ts[i].value /= sum ts[i].value /= sum
} }
return ts
} }
// topK limits the number of tokens considered to the k highest logits // topK limits the number of tokens considered to the k highest logits
@@ -96,7 +99,6 @@ func topK(ts []token, k int) []token {
} }
// topP limits tokens to those with cumulative probability p // topP limits tokens to those with cumulative probability p
// requires ts to be sorted in descending order of probabilities
func topP(ts []token, p float32) []token { func topP(ts []token, p float32) []token {
if p == 1.0 { if p == 1.0 {
return ts return ts
@@ -107,24 +109,37 @@ func topP(ts []token, p float32) []token {
for i, t := range ts { for i, t := range ts {
sum += t.value sum += t.value
if sum > float32(p) { if sum > float32(p) {
return ts[:i+1] ts = ts[:i+1]
return ts
} }
} }
return ts return ts
} }
// minP filters tokens with probabilities >= p * max_prob // minP limits tokens to those with cumulative probability p
// requires ts to be sorted in descending order of probabilities
func minP(ts []token, p float32) []token { func minP(ts []token, p float32) []token {
maxProb := ts[0].value if p == 1.0 {
return ts
}
threshold := maxProb * p maxProb := float32(math.Inf(-1))
for _, token := range ts {
for i, t := range ts { if token.value > maxProb {
if t.value < threshold { maxProb = token.value
return ts[:i]
} }
} }
threshold := maxProb * float32(p)
// Filter tokens in-place
validTokens := ts[:0]
for i, token := range ts {
if token.value >= threshold {
validTokens = append(validTokens, ts[i])
}
}
ts = validTokens
return ts return ts
} }

View File

@@ -34,22 +34,17 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {
func TestTemperature(t *testing.T) { func TestTemperature(t *testing.T) {
input := []float32{1.0, 4.0, -2.0, 0.0} input := []float32{1.0, 4.0, -2.0, 0.0}
tokens := toTokens(input) got := temperature(toTokens(input), 0.5)
temperature(tokens, 0.5)
want := []float32{2.0, 8.0, -4.0, 0.0} want := []float32{2.0, 8.0, -4.0, 0.0}
compareLogits(t, "temperature(0.5)", want, tokens) compareLogits(t, "temperature(0.5)", want, got)
input = []float32{1.0, 4.0, -2.0, 0.0} got = temperature(toTokens(input), 1.0)
tokens = toTokens(input)
temperature(tokens, 1.0)
want = []float32{1.0, 4.0, -2.0, 0.0} want = []float32{1.0, 4.0, -2.0, 0.0}
compareLogits(t, "temperature(1)", want, tokens) compareLogits(t, "temperature(1)", want, got)
input = []float32{1.0, 4.0, -2.0, 0.0} got = temperature(toTokens(input), 0.0)
tokens = toTokens(input)
temperature(tokens, 0.0)
want = []float32{1e7, 4e7, -2e7, 0.0} want = []float32{1e7, 4e7, -2e7, 0.0}
compareLogits(t, "temperature(0)", want, tokens) compareLogits(t, "temperature(0)", want, got)
} }
func TestSoftmax(t *testing.T) { func TestSoftmax(t *testing.T) {
@@ -95,17 +90,16 @@ func TestSoftmax(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tokens := toTokens(tt.input) got := softmax(toTokens(tt.input))
softmax(tokens)
if tt.expected != nil { if tt.expected != nil {
compareLogits(t, tt.name, tt.expected, tokens) compareLogits(t, tt.name, tt.expected, got)
return return
} }
// Check probabilities sum to 1 // Check probabilities sum to 1
var sum float32 var sum float32
for _, token := range tokens { for _, token := range got {
sum += token.value sum += token.value
if token.value < 0 || token.value > 1 { if token.value < 0 || token.value > 1 {
t.Errorf("probability out of range [0,1]: got %f", token.value) t.Errorf("probability out of range [0,1]: got %f", token.value)
@@ -120,44 +114,38 @@ func TestSoftmax(t *testing.T) {
func TestTopK(t *testing.T) { func TestTopK(t *testing.T) {
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
tokens := toTokens(input)
tokens = topK(tokens, 5) // Test k=5
if len(tokens) != 5 { got := topK(toTokens(input), 5)
t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens)) if len(got) != 5 {
t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
} }
// Should keep highest 3 values in descending order
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154} want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
compareLogits(t, "topK(3)", want, tokens) compareLogits(t, "topK(3)", want, got)
tokens = toTokens(input) got = topK(toTokens(input), 20)
tokens = topK(tokens, 20) if len(got) != len(input) {
if len(tokens) != len(input) { t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens))
} }
// Test k=-1
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839} want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
tokens = toTokens(input) got = topK(toTokens(input), -1)
tokens = topK(tokens, -1) if len(got) != len(input) {
if len(tokens) != len(input) { t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
} }
compareLogits(t, "topK(-1)", want, tokens) compareLogits(t, "topK(-1)", want, got)
// Test k=0
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839} want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
tokens = toTokens(input) got = topK(toTokens(input), 0)
tokens = topK(tokens, 0) if len(got) != len(input) {
if len(tokens) != len(input) { t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
}
compareLogits(t, "topK(-1)", want, tokens)
input = []float32{-1e7, -2e7, -3e7, -4e7}
tokens = toTokens(input)
tokens = topK(tokens, 1)
if len(tokens) < 1 {
t.Error("topK should keep at least one token")
} }
compareLogits(t, "topK(-1)", want, got)
} }
func TestTopP(t *testing.T) { func TestTopP(t *testing.T) {
@@ -165,25 +153,16 @@ func TestTopP(t *testing.T) {
tokens := toTokens(input) tokens := toTokens(input)
// First apply temperature and softmax to get probabilities // First apply temperature and softmax to get probabilities
softmax(tokens) tokens = softmax(tokens)
tokens = topK(tokens, 20) tokens = topK(tokens, 20)
// Then apply topP // Then apply topP
tokens = topP(tokens, 0.95) got := topP(tokens, 0.95)
// Should keep tokens until cumsum > 0.95 // Should keep tokens until cumsum > 0.95
if len(tokens) > 3 { if len(got) > 3 {
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens)) t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
t.Logf("got: %v", tokens) t.Logf("got: %v", got)
}
// Test edge case - ensure at least one token remains
input = []float32{-1e6, -1e6, -1e6} // One dominant token
tokens = toTokens(input)
softmax(tokens)
tokens = topP(tokens, 0.0) // Very small p
if len(tokens) < 1 {
t.Error("topP should keep at least one token")
} }
} }
@@ -192,45 +171,14 @@ func TestMinP(t *testing.T) {
tokens := toTokens(input) tokens := toTokens(input)
// First apply temperature and softmax // First apply temperature and softmax
tokens = topK(tokens, 20) tokens = softmax(tokens)
softmax(tokens)
tokens = minP(tokens, 1.0) // Then apply minP
got := minP(tokens, 0.2)
if len(tokens) != 1 {
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens))
}
// Test with normal p value
tokens = toTokens(input) // Reset tokens
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.2)
// Should keep tokens with prob >= 0.2 * max_prob // Should keep tokens with prob >= 0.2 * max_prob
if len(tokens) > 3 { if len(got) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens)) t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
t.Logf("got: %v", tokens)
}
// Test with zero p value
tokens = toTokens(input) // Reset tokens
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.0)
// Should keep only the highest probability token
if len(tokens) != len(input) {
t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens))
t.Logf("got: %v", tokens)
}
input = []float32{1e-10, 1e-10, 1e-10}
tokens = toTokens(input)
softmax(tokens)
tokens = minP(tokens, 1.0)
if len(tokens) < 1 {
t.Error("minP should keep at least one token even with extreme probabilities")
} }
} }
@@ -283,7 +231,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
copy(tokensCopy, tokens) copy(tokensCopy, tokens)
tokens = topK(tokensCopy, 10) topK(tokensCopy, 10)
} }
}) })
@@ -291,7 +239,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
copy(tokensCopy, tokens) copy(tokensCopy, tokens)
tokens = topP(tokensCopy, 0.9) topP(tokensCopy, 0.9)
} }
}) })
@@ -299,7 +247,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
copy(tokensCopy, tokens) copy(tokensCopy, tokens)
tokens = minP(tokensCopy, 0.2) minP(tokensCopy, 0.2)
} }
}) })
@@ -307,7 +255,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
copy(tokensCopy, tokens) copy(tokensCopy, tokens)
tokens = topK(tokensCopy, 200000) topK(tokensCopy, 200000)
} }
}) })
} }

View File

@@ -8,7 +8,7 @@ usage() {
exit 1 exit 1
} }
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")} export VERSION=${VERSION:-$(git describe --tags --dirty)}
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'"
export CGO_CPPFLAGS='-mmacosx-version-min=11.3' export CGO_CPPFLAGS='-mmacosx-version-min=11.3'

View File

@@ -146,7 +146,7 @@ func debugger(err *error) func(step string) {
// be in either of the following forms: // be in either of the following forms:
// //
// @<digest> // @<digest>
// <name>@<digest> // <name>
// <name> // <name>
// //
// If a digest is provided, it is returned as is and nothing else happens. // If a digest is provided, it is returned as is and nothing else happens.
@@ -160,6 +160,8 @@ func debugger(err *error) func(step string) {
// hashed is passed to a PutBytes call to ensure that the manifest is in the // hashed is passed to a PutBytes call to ensure that the manifest is in the
// blob store. This is done to ensure that future calls to [Get] succeed in // blob store. This is done to ensure that future calls to [Get] succeed in
// these cases. // these cases.
//
// TODO(bmizerany): Move Links/Resolve/etc. out of this package.
func (c *DiskCache) Resolve(name string) (Digest, error) { func (c *DiskCache) Resolve(name string) (Digest, error) {
name, digest := splitNameDigest(name) name, digest := splitNameDigest(name)
if digest != "" { if digest != "" {
@@ -277,6 +279,18 @@ func (c *DiskCache) Get(d Digest) (Entry, error) {
// It returns an error if either the name or digest is invalid, or if link // It returns an error if either the name or digest is invalid, or if link
// creation encounters any issues. // creation encounters any issues.
func (c *DiskCache) Link(name string, d Digest) error { func (c *DiskCache) Link(name string, d Digest) error {
// TODO(bmizerany): Move link handling from cache to registry.
//
// We originally placed links in the cache due to its storage
// knowledge. However, the registry likely offers better context for
// naming concerns, and our API design shouldn't be tightly coupled to
// our on-disk format.
//
// Links work effectively when independent from physical location -
// they can reference content with matching SHA regardless of storage
// location. In an upcoming change, we plan to shift this
// responsibility to the registry where it better aligns with the
// system's conceptual model.
manifest, err := c.manifestPath(name) manifest, err := c.manifestPath(name)
if err != nil { if err != nil {
return err return err
@@ -327,9 +341,7 @@ func (c *DiskCache) GetFile(d Digest) string {
return absJoin(c.dir, "blobs", filename) return absJoin(c.dir, "blobs", filename)
} }
// Links returns a sequence of link names. The sequence is in lexical order. // Links returns a sequence of links in the cache in lexical order.
// Names are converted from their relative path form to their name form but are
// not guaranteed to be valid. Callers should validate the names before using.
func (c *DiskCache) Links() iter.Seq2[string, error] { func (c *DiskCache) Links() iter.Seq2[string, error] {
return func(yield func(string, error) bool) { return func(yield func(string, error) bool) {
for path, err := range c.links() { for path, err := range c.links() {
@@ -402,14 +414,12 @@ func (c *DiskCache) links() iter.Seq2[string, error] {
} }
type checkWriter struct { type checkWriter struct {
size int64
d Digest d Digest
f *os.File size int64
n int64
h hash.Hash h hash.Hash
f *os.File
w io.Writer // underlying writer; set by creator err error
n int64
err error
testHookBeforeFinalWrite func(*os.File) testHookBeforeFinalWrite func(*os.File)
} }
@@ -425,10 +435,6 @@ func (w *checkWriter) seterr(err error) error {
// underlying writer is guaranteed to be the last byte of p as verified by the // underlying writer is guaranteed to be the last byte of p as verified by the
// hash. // hash.
func (w *checkWriter) Write(p []byte) (int, error) { func (w *checkWriter) Write(p []byte) (int, error) {
if w.err != nil {
return 0, w.err
}
_, err := w.h.Write(p) _, err := w.h.Write(p)
if err != nil { if err != nil {
return 0, w.seterr(err) return 0, w.seterr(err)
@@ -447,7 +453,7 @@ func (w *checkWriter) Write(p []byte) (int, error) {
if nextSize > w.size { if nextSize > w.size {
return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size)) return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size))
} }
n, err := w.w.Write(p) n, err := w.f.Write(p)
w.n += int64(n) w.n += int64(n)
return n, w.seterr(err) return n, w.seterr(err)
} }
@@ -487,12 +493,10 @@ func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size
// Copy file to f, but also into h to double-check hash. // Copy file to f, but also into h to double-check hash.
cw := &checkWriter{ cw := &checkWriter{
d: out, d: out,
size: size, size: size,
h: sha256.New(), h: sha256.New(),
f: f, f: f,
w: f,
testHookBeforeFinalWrite: c.testHookBeforeFinalWrite, testHookBeforeFinalWrite: c.testHookBeforeFinalWrite,
} }
n, err := io.Copy(cw, file) n, err := io.Copy(cw, file)
@@ -528,6 +532,11 @@ func splitNameDigest(s string) (name, digest string) {
var errInvalidName = errors.New("invalid name") var errInvalidName = errors.New("invalid name")
func nameToPath(name string) (_ string, err error) { func nameToPath(name string) (_ string, err error) {
if strings.Contains(name, "@") {
// TODO(bmizerany): HACK: Fix names.Parse to validate.
// TODO(bmizerany): merge with default parts (maybe names.Merge(a, b))
return "", errInvalidName
}
n := names.Parse(name) n := names.Parse(name)
if !n.IsFullyQualified() { if !n.IsFullyQualified() {
return "", errInvalidName return "", errInvalidName
@@ -538,7 +547,8 @@ func nameToPath(name string) (_ string, err error) {
func absJoin(pp ...string) string { func absJoin(pp ...string) string {
abs, err := filepath.Abs(filepath.Join(pp...)) abs, err := filepath.Abs(filepath.Join(pp...))
if err != nil { if err != nil {
panic(err) // this should never happen // Likely a bug bug or a bad OS problem. Just panic.
panic(err)
} }
return abs return abs
} }

View File

@@ -1,73 +0,0 @@
package blob
import (
"crypto/sha256"
"errors"
"io"
"os"
)
// Chunk represents a range of bytes in a blob.
type Chunk struct {
Start int64
End int64
}
// Size returns end minus start plus one.
func (c Chunk) Size() int64 {
return c.End - c.Start + 1
}
// Chunker writes to a blob in chunks.
// Its zero value is invalid. Use [DiskCache.Chunked] to create a new Chunker.
type Chunker struct {
digest Digest
size int64
f *os.File // nil means pre-validated
}
// Chunked returns a new Chunker, ready for use storing a blob of the given
// size in chunks.
//
// Use [Chunker.Put] to write data to the blob at specific offsets.
func (c *DiskCache) Chunked(d Digest, size int64) (*Chunker, error) {
name := c.GetFile(d)
info, err := os.Stat(name)
if err == nil && info.Size() == size {
return &Chunker{}, nil
}
f, err := os.OpenFile(name, os.O_CREATE|os.O_WRONLY, 0o666)
if err != nil {
return nil, err
}
return &Chunker{digest: d, size: size, f: f}, nil
}
// Put copies chunk.Size() bytes from r to the blob at the given offset,
// merging the data with the existing blob. It returns an error if any. As a
// special case, if r has less than chunk.Size() bytes, Put returns
// io.ErrUnexpectedEOF.
func (c *Chunker) Put(chunk Chunk, d Digest, r io.Reader) error {
if c.f == nil {
return nil
}
cw := &checkWriter{
d: d,
size: chunk.Size(),
h: sha256.New(),
f: c.f,
w: io.NewOffsetWriter(c.f, chunk.Start),
}
_, err := io.CopyN(cw, r, chunk.Size())
if err != nil && errors.Is(err, io.EOF) {
return io.ErrUnexpectedEOF
}
return err
}
// Close closes the underlying file.
func (c *Chunker) Close() error {
return c.f.Close()
}

View File

@@ -63,10 +63,6 @@ func (d Digest) Short() string {
return fmt.Sprintf("%x", d.sum[:4]) return fmt.Sprintf("%x", d.sum[:4])
} }
func (d Digest) Sum() [32]byte {
return d.sum
}
func (d Digest) Compare(other Digest) int { func (d Digest) Compare(other Digest) int {
return slices.Compare(d.sum[:], other.sum[:]) return slices.Compare(d.sum[:], other.sum[:])
} }

View File

@@ -0,0 +1,78 @@
package chunks
import (
"fmt"
"iter"
"strconv"
"strings"
)
type Chunk struct {
Start, End int64
}
func New(start, end int64) Chunk {
return Chunk{start, end}
}
// ParseRange parses a string in the form "unit=range" where unit is a string
// and range is a string in the form "start-end". It returns the unit and the
// range as a Chunk.
func ParseRange(s string) (unit string, _ Chunk, _ error) {
unit, r, _ := strings.Cut(s, "=")
if r == "" {
return unit, Chunk{}, nil
}
c, err := Parse(r)
if err != nil {
return "", Chunk{}, err
}
return unit, c, err
}
// Parse parses a string in the form "start-end" and returns the Chunk.
func Parse(s string) (Chunk, error) {
startStr, endStr, _ := strings.Cut(s, "-")
start, err := strconv.ParseInt(startStr, 10, 64)
if err != nil {
return Chunk{}, fmt.Errorf("invalid start: %v", err)
}
end, err := strconv.ParseInt(endStr, 10, 64)
if err != nil {
return Chunk{}, fmt.Errorf("invalid end: %v", err)
}
if start > end {
return Chunk{}, fmt.Errorf("invalid range %d-%d: start > end", start, end)
}
return Chunk{start, end}, nil
}
// Of returns a sequence of contiguous Chunks of size chunkSize that cover
// the range [0, size), in order.
func Of(size, chunkSize int64) iter.Seq[Chunk] {
return func(yield func(Chunk) bool) {
for start := int64(0); start < size; start += chunkSize {
end := min(start+chunkSize-1, size-1)
if !yield(Chunk{start, end}) {
break
}
}
}
}
// Count returns the number of Chunks of size chunkSize needed to cover the
// range [0, size).
func Count(size, chunkSize int64) int64 {
return (size + chunkSize - 1) / chunkSize
}
// Size returns end minus start plus one.
func (c Chunk) Size() int64 {
return c.End - c.Start + 1
}
// String returns the string representation of the Chunk in the form
// "{start}-{end}".
func (c Chunk) String() string {
return fmt.Sprintf("%d-%d", c.Start, c.End)
}

View File

@@ -0,0 +1,65 @@
package chunks
import (
"slices"
"testing"
)
func TestOf(t *testing.T) {
cases := []struct {
total int64
chunkSize int64
want []Chunk
}{
{0, 1, nil},
{1, 1, []Chunk{{0, 0}}},
{1, 2, []Chunk{{0, 0}}},
{2, 1, []Chunk{{0, 0}, {1, 1}}},
{10, 9, []Chunk{{0, 8}, {9, 9}}},
}
for _, tt := range cases {
got := slices.Collect(Of(tt.total, tt.chunkSize))
if !slices.Equal(got, tt.want) {
t.Errorf("[%d/%d]: got %v; want %v", tt.total, tt.chunkSize, got, tt.want)
}
}
}
func TestSize(t *testing.T) {
cases := []struct {
c Chunk
want int64
}{
{Chunk{0, 0}, 1},
{Chunk{0, 1}, 2},
{Chunk{3, 4}, 2},
}
for _, tt := range cases {
got := tt.c.Size()
if got != tt.want {
t.Errorf("%v: got %d; want %d", tt.c, got, tt.want)
}
}
}
func TestCount(t *testing.T) {
cases := []struct {
total int64
chunkSize int64
want int64
}{
{0, 1, 0},
{1, 1, 1},
{1, 2, 1},
{2, 1, 2},
{10, 9, 2},
}
for _, tt := range cases {
got := Count(tt.total, tt.chunkSize)
if got != tt.want {
t.Errorf("[%d/%d]: got %d; want %d", tt.total, tt.chunkSize, got, tt.want)
}
}
}

View File

@@ -19,13 +19,11 @@ import (
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
"iter"
"log/slog" "log/slog"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"runtime/debug"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
@@ -37,8 +35,10 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/internal/backoff" "github.com/ollama/ollama/server/internal/internal/backoff"
"github.com/ollama/ollama/server/internal/internal/names" "github.com/ollama/ollama/server/internal/internal/names"
"github.com/ollama/ollama/server/internal/internal/syncs"
_ "embed" _ "embed"
) )
@@ -66,7 +66,12 @@ var (
const ( const (
// DefaultChunkingThreshold is the threshold at which a layer should be // DefaultChunkingThreshold is the threshold at which a layer should be
// split up into chunks when downloading. // split up into chunks when downloading.
DefaultChunkingThreshold = 64 << 20 DefaultChunkingThreshold = 128 << 20
// DefaultMaxChunkSize is the default maximum size of a chunk to
// download. It is configured based on benchmarks and aims to strike a
// balance between download speed and memory usage.
DefaultMaxChunkSize = 8 << 20
) )
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) { var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
@@ -206,7 +211,8 @@ type Registry struct {
// pushing or pulling models. If zero, the number of streams is // pushing or pulling models. If zero, the number of streams is
// determined by [runtime.GOMAXPROCS]. // determined by [runtime.GOMAXPROCS].
// //
// A negative value means no limit. // Clients that want "unlimited" streams should set this to a large
// number.
MaxStreams int MaxStreams int
// ChunkingThreshold is the maximum size of a layer to download in a single // ChunkingThreshold is the maximum size of a layer to download in a single
@@ -260,7 +266,6 @@ func DefaultRegistry() (*Registry, error) {
} }
var rc Registry var rc Registry
rc.UserAgent = UserAgent()
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM) rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -276,24 +281,25 @@ func DefaultRegistry() (*Registry, error) {
return &rc, nil return &rc, nil
} }
func UserAgent() string {
buildinfo, _ := debug.ReadBuildInfo()
return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
buildinfo.Main.Version,
runtime.GOARCH,
runtime.GOOS,
runtime.Version(),
)
}
func (r *Registry) maxStreams() int { func (r *Registry) maxStreams() int {
return cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0)) n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
// Large downloads require a writter stream, so ensure we have at least
// two streams to avoid a deadlock.
return max(n, 2)
} }
func (r *Registry) maxChunkingThreshold() int64 { func (r *Registry) maxChunkingThreshold() int64 {
return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold) return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
} }
// chunkSizeFor returns the chunk size for a layer of the given size. If the
// size is less than or equal to the max chunking threshold, the size is
// returned; otherwise, the max chunk size is returned.
func (r *Registry) maxChunkSize() int64 {
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
}
type PushParams struct { type PushParams struct {
// From is an optional destination name for the model. If empty, the // From is an optional destination name for the model. If empty, the
// destination name is the same as the source name. // destination name is the same as the source name.
@@ -420,21 +426,6 @@ func canRetry(err error) bool {
return re.Status >= 500 return re.Status >= 500
} }
// trackingReader is an io.Reader that tracks the number of bytes read and
// calls the update function with the layer, the number of bytes read.
//
// It always calls update with a nil error.
type trackingReader struct {
r io.Reader
n *atomic.Int64
}
func (r *trackingReader) Read(p []byte) (n int, err error) {
n, err = r.r.Read(p)
r.n.Add(int64(n))
return
}
// Pull pulls the model with the given name from the remote registry into the // Pull pulls the model with the given name from the remote registry into the
// cache. // cache.
// //
@@ -443,6 +434,11 @@ func (r *trackingReader) Read(p []byte) (n int, err error) {
// typically slower than splitting the model up across layers, and is mostly // typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image". // utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, name string) error { func (r *Registry) Pull(ctx context.Context, name string) error {
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
return err
}
m, err := r.Resolve(ctx, name) m, err := r.Resolve(ctx, name)
if err != nil { if err != nil {
return err return err
@@ -461,95 +457,126 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
return err == nil && info.Size == l.Size return err == nil && info.Size == l.Size
} }
t := traceFromContext(ctx)
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(r.maxStreams())
layers := m.Layers layers := m.Layers
if m.Config != nil && m.Config.Digest.IsValid() { if m.Config != nil && m.Config.Digest.IsValid() {
layers = append(layers, m.Config) layers = append(layers, m.Config)
} }
// Send initial layer trace events to allow clients to have an for _, l := range layers {
// understanding of work to be done before work starts.
t := traceFromContext(ctx)
skip := make([]bool, len(layers))
for i, l := range layers {
t.update(l, 0, nil)
if exists(l) { if exists(l) {
skip[i] = true
t.update(l, l.Size, ErrCached) t.update(l, l.Size, ErrCached)
}
}
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(r.maxStreams())
for i, l := range layers {
if skip[i] {
continue continue
} }
chunked, err := c.Chunked(l.Digest, l.Size) blobURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), l.Digest)
req, err := r.newRequest(ctx, "GET", blobURL, nil)
if err != nil { if err != nil {
t.update(l, 0, err) t.update(l, 0, err)
continue continue
} }
defer chunked.Close()
var progress atomic.Int64 t.update(l, 0, nil)
for cs, err := range r.chunksums(ctx, name, l) {
if err != nil { if l.Size <= r.maxChunkingThreshold() {
t.update(l, progress.Load(), err) g.Go(func() error {
break // TODO(bmizerany): retry/backoff like below in
} // the chunking case
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
err = c.Put(l.Digest, res.Body, l.Size)
if err == nil {
t.update(l, l.Size, nil)
}
return err
})
} else {
q := syncs.NewRelayReader()
g.Go(func() (err error) { g.Go(func() (err error) {
defer func() { t.update(l, progress.Load(), err) }() defer func() { q.CloseWithError(err) }()
return c.Put(l.Digest, q, l.Size)
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil {
return err
}
err := func() error {
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
// Count bytes towards
// progress, as they arrive, so
// that our bytes piggyback
// other chunk updates on
// completion.
//
// This tactic is enough to
// show "smooth" progress given
// the current CLI client. In
// the near future, the server
// should report download rate
// since it knows better than
// a client that is measuring
// rate based on wall-clock
// time-since-last-update.
body := &trackingReader{r: res.Body, n: &progress}
err = chunked.Put(cs.Chunk, cs.Digest, body)
if err != nil {
return err
}
return nil
}()
if !canRetry(err) {
return err
}
}
return nil
}) })
var progress atomic.Int64
// We want to avoid extra round trips per chunk due to
// redirects from the registry to the blob store, so
// fire an initial request to get the final URL and
// then use that URL for the chunk requests.
req.Header.Set("Range", "bytes=0-0")
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
res.Body.Close()
req = res.Request.WithContext(req.Context())
wp := writerPool{size: r.maxChunkSize()}
for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
if ctx.Err() != nil {
break
}
ticket := q.Take()
g.Go(func() (err error) {
defer func() {
if err != nil {
q.CloseWithError(err)
}
ticket.Close()
t.update(l, progress.Load(), err)
}()
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil {
return err
}
err := func() error {
req := req.Clone(req.Context())
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
tw := wp.get()
tw.Reset(ticket)
defer wp.put(tw)
_, err = io.CopyN(tw, res.Body, chunk.Size())
if err != nil {
return maybeUnexpectedEOF(err)
}
if err := tw.Flush(); err != nil {
return err
}
total := progress.Add(chunk.Size())
if total >= l.Size {
q.Close()
}
return nil
}()
if !canRetry(err) {
return err
}
}
return nil
})
}
} }
} }
if err := g.Wait(); err != nil { if err := g.Wait(); err != nil {
return err return err
} }
@@ -588,6 +615,8 @@ type Manifest struct {
Config *Layer `json:"config"` Config *Layer `json:"config"`
} }
var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
// Layer returns the layer with the given // Layer returns the layer with the given
// digest, or nil if not found. // digest, or nil if not found.
func (m *Manifest) Layer(d blob.Digest) *Layer { func (m *Manifest) Layer(d blob.Digest) *Layer {
@@ -614,9 +643,10 @@ func (m Manifest) MarshalJSON() ([]byte, error) {
// last phase of the commit which expects it, but does nothing // last phase of the commit which expects it, but does nothing
// with it. This will be fixed in a future release of // with it. This will be fixed in a future release of
// ollama.com. // ollama.com.
Config Layer `json:"config"` Config *Layer `json:"config"`
}{ }{
M: M(m), M: M(m),
Config: &Layer{Digest: emptyDigest},
} }
return json.Marshal(v) return json.Marshal(v)
} }
@@ -706,123 +736,6 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error)
return m, nil return m, nil
} }
type chunksum struct {
URL string
Chunk blob.Chunk
Digest blob.Digest
}
// chunksums returns a sequence of chunksums for the given layer. If the layer is under the
// chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
// is over the chunking threshold, the chunksums are read from the chunksums endpoint.
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
return func(yield func(chunksum, error) bool) {
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
yield(chunksum{}, err)
return
}
if l.Size < r.maxChunkingThreshold() {
// any layer under the threshold should be downloaded
// in one go.
cs := chunksum{
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
),
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
Digest: l.Digest,
}
yield(cs, nil)
return
}
// A chunksums response is a sequence of chunksums in a
// simple, easy to parse line-oriented format.
//
// Example:
//
// >> GET /v2/<namespace>/<model>/chunksums/<digest>
//
// << HTTP/1.1 200 OK
// << Content-Location: <blobURL>
// <<
// << <digest> <start>-<end>
// << ...
//
// The blobURL is the URL to download the chunks from.
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
)
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
if err != nil {
yield(chunksum{}, err)
return
}
res, err := sendRequest(r.client(), req)
if err != nil {
yield(chunksum{}, err)
return
}
defer res.Body.Close()
if res.StatusCode != 200 {
err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
yield(chunksum{}, err)
return
}
blobURL := res.Header.Get("Content-Location")
s := bufio.NewScanner(res.Body)
s.Split(bufio.ScanWords)
for {
if !s.Scan() {
if s.Err() != nil {
yield(chunksum{}, s.Err())
}
return
}
d, err := blob.ParseDigest(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
return
}
if !s.Scan() {
err := s.Err()
if err == nil {
err = fmt.Errorf("missing chunk range for digest %s", d)
}
yield(chunksum{}, err)
return
}
chunk, err := parseChunk(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
return
}
cs := chunksum{
URL: blobURL,
Chunk: chunk,
Digest: d,
}
if !yield(cs, nil) {
return
}
}
}
}
func (r *Registry) client() *http.Client { func (r *Registry) client() *http.Client {
if r.HTTPClient != nil { if r.HTTPClient != nil {
return r.HTTPClient return r.HTTPClient
@@ -985,6 +898,13 @@ func checkData(url string) string {
return fmt.Sprintf("GET,%s,%s", url, zeroSum) return fmt.Sprintf("GET,%s,%s", url, zeroSum)
} }
func maybeUnexpectedEOF(err error) error {
if errors.Is(err, io.EOF) {
return io.ErrUnexpectedEOF
}
return err
}
type publicError struct { type publicError struct {
wrapped error wrapped error
message string message string
@@ -1071,22 +991,27 @@ func splitExtended(s string) (scheme, name, digest string) {
return scheme, s, digest return scheme, s, digest
} }
// parseChunk parses a string in the form "start-end" and returns the Chunk. type writerPool struct {
func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) { size int64 // set by the caller
startPart, endPart, found := strings.Cut(string(s), "-")
if !found { mu sync.Mutex
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s) ws []*bufio.Writer
} }
start, err := strconv.ParseInt(startPart, 10, 64)
if err != nil { func (p *writerPool) get() *bufio.Writer {
return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err) p.mu.Lock()
} defer p.mu.Unlock()
end, err := strconv.ParseInt(endPart, 10, 64) if len(p.ws) == 0 {
if err != nil { return bufio.NewWriterSize(nil, int(p.size))
return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err) }
} w := p.ws[len(p.ws)-1]
if start > end { p.ws = p.ws[:len(p.ws)-1]
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s) return w
} }
return blob.Chunk{Start: start, End: end}, nil
func (p *writerPool) put(w *bufio.Writer) {
p.mu.Lock()
defer p.mu.Unlock()
w.Reset(nil)
p.ws = append(p.ws, w)
} }

View File

@@ -21,6 +21,7 @@ import (
"time" "time"
"github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/testutil" "github.com/ollama/ollama/server/internal/testutil"
) )
@@ -427,7 +428,7 @@ func TestRegistryPullCached(t *testing.T) {
err := rc.Pull(ctx, "single") err := rc.Pull(ctx, "single")
testutil.Check(t, err) testutil.Check(t, err)
want := []int64{0, 6} want := []int64{6}
if !errors.Is(errors.Join(errs...), ErrCached) { if !errors.Is(errors.Join(errs...), ErrCached) {
t.Errorf("errs = %v; want %v", errs, ErrCached) t.Errorf("errs = %v; want %v", errs, ErrCached)
} }
@@ -530,6 +531,54 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
} }
} }
func TestRegistryPullChunking(t *testing.T) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
if r.URL.Host != "blob.store" {
// The production registry redirects to the blob store.
http.Redirect(w, r, "http://blob.store"+r.URL.Path, http.StatusFound)
return
}
if strings.Contains(r.URL.Path, "/blobs/") {
rng := r.Header.Get("Range")
if rng == "" {
http.Error(w, "missing range", http.StatusBadRequest)
return
}
_, c, err := chunks.ParseRange(r.Header.Get("Range"))
if err != nil {
panic(err)
}
io.WriteString(w, "remote"[c.Start:c.End+1])
return
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, blob.DigestFromBytes("remote"))
})
// Force chunking by setting the threshold to less than the size of the
// layer.
rc.ChunkingThreshold = 3
rc.MaxChunkSize = 3
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(d *Layer, n int64, err error) {
if err != nil {
t.Errorf("update %v %d %v", d, n, err)
}
reads = append(reads, n)
},
})
err := rc.Pull(ctx, "remote")
testutil.Check(t, err)
want := []int64{0, 3, 6}
if !slices.Equal(reads, want) {
t.Errorf("reads = %v; want %v", reads, want)
}
}
func TestRegistryResolveByDigest(t *testing.T) { func TestRegistryResolveByDigest(t *testing.T) {
check := testutil.Checker(t) check := testutil.Checker(t)

View File

@@ -0,0 +1,11 @@
package main
import (
"fmt"
"os"
)
func main() {
fmt.Println("Run as 'go test -bench=.' to run the benchmarks")
os.Exit(1)
}

View File

@@ -0,0 +1,107 @@
package main
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"sync/atomic"
"testing"
"time"
"github.com/ollama/ollama/server/internal/chunks"
"golang.org/x/sync/errgroup"
)
func BenchmarkDownload(b *testing.B) {
run := func(fileSize, chunkSize int64) {
name := fmt.Sprintf("size=%d/chunksize=%d", fileSize, chunkSize)
b.Run(name, func(b *testing.B) { benchmarkDownload(b, fileSize, chunkSize) })
}
run(100<<20, 8<<20)
run(100<<20, 16<<20)
run(100<<20, 32<<20)
run(100<<20, 64<<20)
run(100<<20, 128<<20) // 1 chunk
}
func run(ctx context.Context, c *http.Client, chunk chunks.Chunk) error {
const blobURL = "https://ollama.com/v2/x/x/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d"
req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
res, err := c.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
_, err = io.CopyN(io.Discard, res.Body, chunk.Size()) // will io.EOF on short read
return err
}
var sleepTime atomic.Int64
func benchmarkDownload(b *testing.B, fileSize, chunkSize int64) {
client := &http.Client{
Transport: func() http.RoundTripper {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DisableKeepAlives = true
return tr
}(),
}
defer client.CloseIdleConnections()
// warm up the client
run(context.Background(), client, chunks.New(0, 1<<20))
b.SetBytes(fileSize)
b.ReportAllocs()
// Give our CDN a min to breathe between benchmarks.
time.Sleep(time.Duration(sleepTime.Swap(3)))
for b.Loop() {
g, ctx := errgroup.WithContext(b.Context())
g.SetLimit(runtime.GOMAXPROCS(0))
for chunk := range chunks.Of(fileSize, chunkSize) {
g.Go(func() error { return run(ctx, client, chunk) })
}
if err := g.Wait(); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkWrite(b *testing.B) {
b.Run("chunksize=1MiB", func(b *testing.B) { benchmarkWrite(b, 1<<20) })
}
func benchmarkWrite(b *testing.B, chunkSize int) {
b.ReportAllocs()
dir := b.TempDir()
f, err := os.Create(filepath.Join(dir, "write-single"))
if err != nil {
b.Fatal(err)
}
defer f.Close()
data := make([]byte, chunkSize)
b.SetBytes(int64(chunkSize))
r := bytes.NewReader(data)
for b.Loop() {
r.Reset(data)
_, err := io.Copy(f, r)
if err != nil {
b.Fatal(err)
}
}
}

View File

@@ -1,5 +1,6 @@
// Package registry implements an http.Handler for handling local Ollama API // Package registry provides an http.Handler for handling local Ollama API
// model management requests. See [Local] for details. // requests for performing tasks related to the ollama.com model registry and
// the local disk cache.
package registry package registry
import ( import (
@@ -9,7 +10,6 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"maps"
"net/http" "net/http"
"sync" "sync"
"time" "time"
@@ -18,11 +18,16 @@ import (
"github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/client/ollama"
) )
// Local implements an http.Handler for handling local Ollama API model // Local is an http.Handler for handling local Ollama API requests for
// management requests, such as pushing, pulling, and deleting models. // performing tasks related to the ollama.com model registry combined with the
// local disk cache.
// //
// It can be arranged for all unknown requests to be passed through to a // It is not concern of Local, or this package, to handle model creation, which
// fallback handler, if one is provided. // proceeds any registry operations for models it produces.
//
// NOTE: The package built for dealing with model creation should use
// [DefaultCache] to access the blob store and not attempt to read or write
// directly to the blob disk cache.
type Local struct { type Local struct {
Client *ollama.Registry // required Client *ollama.Registry // required
Logger *slog.Logger // required Logger *slog.Logger // required
@@ -58,7 +63,6 @@ func (e serverError) Error() string {
var ( var (
errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"} errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
errNotFound = &serverError{404, "not_found", "not found"} errNotFound = &serverError{404, "not_found", "not found"}
errModelNotFound = &serverError{404, "not_found", "model not found"}
errInternalError = &serverError{500, "internal_error", "internal server error"} errInternalError = &serverError{500, "internal_error", "internal server error"}
) )
@@ -171,16 +175,8 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
} }
type params struct { type params struct {
// DeprecatedName is the name of the model to push, pull, or delete, DeprecatedName string `json:"name"` // Use [params.model]
// but is deprecated. New clients should use [Model] instead. Model string `json:"model"` // Use [params.model]
//
// Use [model()] to get the model name for both old and new API requests.
DeprecatedName string `json:"name"`
// Model is the name of the model to push, pull, or delete.
//
// Use [model()] to get the model name for both old and new API requests.
Model string `json:"model"`
// AllowNonTLS is a flag that indicates a client using HTTP // AllowNonTLS is a flag that indicates a client using HTTP
// is doing so, deliberately. // is doing so, deliberately.
@@ -193,18 +189,9 @@ type params struct {
// confusing flags such as this. // confusing flags such as this.
AllowNonTLS bool `json:"insecure"` AllowNonTLS bool `json:"insecure"`
// Stream, if true, will make the server send progress updates in a // ProgressStream is a flag that indicates the client is expecting a stream of
// streaming of JSON objects. If false, the server will send a single // progress updates.
// JSON object with the final status as "success", or an error object ProgressStream bool `json:"stream"`
// if an error occurred.
//
// Unfortunately, this API was designed to be a bit awkward. Stream is
// defined to default to true if not present, so we need a way to check
// if the client decisively it to false. So, we use a pointer to a
// bool. Gross.
//
// Use [stream()] to get the correct value for this field.
Stream *bool `json:"stream"`
} }
// model returns the model name for both old and new API requests. // model returns the model name for both old and new API requests.
@@ -212,13 +199,6 @@ func (p params) model() string {
return cmp.Or(p.Model, p.DeprecatedName) return cmp.Or(p.Model, p.DeprecatedName)
} }
func (p params) stream() bool {
if p.Stream == nil {
return true
}
return *p.Stream
}
func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error { func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
if r.Method != "DELETE" { if r.Method != "DELETE" {
return errMethodNotAllowed return errMethodNotAllowed
@@ -232,16 +212,16 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
return err return err
} }
if !ok { if !ok {
return errModelNotFound return &serverError{404, "not_found", "model not found"}
} }
if s.Prune != nil { if s.Prune == nil {
return s.Prune() return nil
} }
return nil return s.Prune()
} }
type progressUpdateJSON struct { type progressUpdateJSON struct {
Status string `json:"status,omitempty,omitzero"` Status string `json:"status"`
Digest blob.Digest `json:"digest,omitempty,omitzero"` Digest blob.Digest `json:"digest,omitempty,omitzero"`
Total int64 `json:"total,omitempty,omitzero"` Total int64 `json:"total,omitempty,omitzero"`
Completed int64 `json:"completed,omitempty,omitzero"` Completed int64 `json:"completed,omitempty,omitzero"`
@@ -257,17 +237,6 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
return err return err
} }
enc := json.NewEncoder(w)
if !p.stream() {
if err := s.Client.Pull(r.Context(), p.model()); err != nil {
if errors.Is(err, ollama.ErrModelNotFound) {
return errModelNotFound
}
return err
}
return enc.Encode(progressUpdateJSON{Status: "success"})
}
maybeFlush := func() { maybeFlush := func() {
fl, _ := w.(http.Flusher) fl, _ := w.(http.Flusher)
if fl != nil { if fl != nil {
@@ -277,67 +246,69 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
defer maybeFlush() defer maybeFlush()
var mu sync.Mutex var mu sync.Mutex
progress := make(map[*ollama.Layer]int64) enc := json.NewEncoder(w)
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
progressCopy := make(map[*ollama.Layer]int64, len(progress)) ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
pushUpdate := func() { Update: func(l *ollama.Layer, n int64, err error) {
defer maybeFlush() mu.Lock()
defer mu.Unlock()
// TODO(bmizerany): This scales poorly with more layers due to // TODO(bmizerany): coalesce these updates; writing per
// needing to flush out them all in one big update. We _could_ // update is expensive
// just flush on the changed ones, or just track the whole
// download. Needs more thought. This is fine for now.
mu.Lock()
maps.Copy(progressCopy, progress)
mu.Unlock()
for l, n := range progress {
enc.Encode(progressUpdateJSON{ enc.Encode(progressUpdateJSON{
Digest: l.Digest, Digest: l.Digest,
Status: "pulling",
Total: l.Size, Total: l.Size,
Completed: n, Completed: n,
}) })
}
}
t := time.NewTicker(time.Hour) // "unstarted" timer
start := sync.OnceFunc(func() {
pushUpdate()
t.Reset(100 * time.Millisecond)
})
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
if n > 0 {
start() // flush initial state
}
mu.Lock()
progress[l] = n
mu.Unlock()
}, },
}) })
done := make(chan error, 1) done := make(chan error, 1)
go func() { go func() {
// TODO(bmizerany): continue to support non-streaming responses
done <- s.Client.Pull(ctx, p.model()) done <- s.Client.Pull(ctx, p.model())
}() }()
for { func() {
select { t := time.NewTicker(100 * time.Millisecond)
case <-t.C: defer t.Stop()
pushUpdate() for {
case err := <-done: select {
pushUpdate() case <-t.C:
if err != nil { mu.Lock()
var status string maybeFlush()
if errors.Is(err, ollama.ErrModelNotFound) { mu.Unlock()
status = fmt.Sprintf("error: model %q not found", p.model()) case err := <-done:
} else { if err != nil {
status = fmt.Sprintf("error: %v", err) var status string
if errors.Is(err, ollama.ErrModelNotFound) {
status = fmt.Sprintf("error: model %q not found", p.model())
enc.Encode(progressUpdateJSON{Status: status})
} else {
status = fmt.Sprintf("error: %v", err)
enc.Encode(progressUpdateJSON{Status: status})
}
return
} }
enc.Encode(progressUpdateJSON{Status: status})
// These final updates are not strictly necessary, because they have
// already happened at this point. Our pull handler code used to do
// these steps after, not during, the pull, and they were slow, so we
// wanted to provide feedback to users what was happening. For now, we
// keep them to not jar users who are used to seeing them. We can phase
// them out with a new and nicer UX later. One without progress bars
// and digests that no one cares about.
enc.Encode(progressUpdateJSON{Status: "verifying layers"})
enc.Encode(progressUpdateJSON{Status: "writing manifest"})
enc.Encode(progressUpdateJSON{Status: "success"})
return
} }
return nil
} }
} }()
return nil
} }
func decodeUserJSON[T any](r io.Reader) (T, error) { func decodeUserJSON[T any](r io.Reader) (T, error) {

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"io/fs" "io/fs"
"net" "net"
@@ -159,6 +160,7 @@ var registryFS = sync.OnceValue(func() fs.FS {
// to \n when parsing the txtar on Windows. // to \n when parsing the txtar on Windows.
data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n")) data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
a := txtar.Parse(data) a := txtar.Parse(data)
fmt.Printf("%q\n", a.Comment)
fsys, err := txtar.FS(a) fsys, err := txtar.FS(a)
if err != nil { if err != nil {
panic(err) panic(err)
@@ -177,7 +179,7 @@ func TestServerPull(t *testing.T) {
w.WriteHeader(404) w.WriteHeader(404)
io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`) io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
default: default:
t.Logf("serving blob: %s", r.URL.Path) t.Logf("serving file: %s", r.URL.Path)
modelsHandler.ServeHTTP(w, r) modelsHandler.ServeHTTP(w, r)
} }
}) })
@@ -186,7 +188,7 @@ func TestServerPull(t *testing.T) {
t.Helper() t.Helper()
if got.Code != 200 { if got.Code != 200 {
t.Errorf("Code = %d; want 200", got.Code) t.Fatalf("Code = %d; want 200", got.Code)
} }
gotlines := got.Body.String() gotlines := got.Body.String()
t.Logf("got:\n%s", gotlines) t.Logf("got:\n%s", gotlines)
@@ -195,29 +197,35 @@ func TestServerPull(t *testing.T) {
want, unwanted := strings.CutPrefix(want, "!") want, unwanted := strings.CutPrefix(want, "!")
want = strings.TrimSpace(want) want = strings.TrimSpace(want)
if !unwanted && !strings.Contains(gotlines, want) { if !unwanted && !strings.Contains(gotlines, want) {
t.Errorf("! missing %q in body", want) t.Fatalf("! missing %q in body", want)
} }
if unwanted && strings.Contains(gotlines, want) { if unwanted && strings.Contains(gotlines, want) {
t.Errorf("! unexpected %q in body", want) t.Fatalf("! unexpected %q in body", want)
} }
} }
} }
got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`) got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
checkResponse(got, ` checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"} {"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
`) `)
got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`) got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
checkResponse(got, ` checkResponse(got, `
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5} {"status":"pulling manifest"}
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3} {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5} {"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3} {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
{"status":"verifying layers"}
{"status":"writing manifest"}
{"status":"success"}
`) `)
got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`) got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
checkResponse(got, ` checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: model \"unknown\" not found"} {"status":"error: model \"unknown\" not found"}
`) `)
@@ -232,39 +240,19 @@ func TestServerPull(t *testing.T) {
got = s.send(t, "POST", "/api/pull", `{"model": "://"}`) got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
checkResponse(got, ` checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: invalid or missing name: \"\""} {"status":"error: invalid or missing name: \"\""}
`)
// Non-streaming pulls !verifying
got = s.send(t, "POST", "/api/pull", `{"model": "://", "stream": false}`) !writing
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name") !success
got = s.send(t, "POST", "/api/pull", `{"model": "smol", "stream": false}`)
checkResponse(got, `
{"status":"success"}
!digest
!total
!completed
`) `)
got = s.send(t, "POST", "/api/pull", `{"model": "unknown", "stream": false}`)
checkErrorResponse(t, got, 404, "not_found", "model not found")
} }
func TestServerUnknownPath(t *testing.T) { func TestServerUnknownPath(t *testing.T) {
s := newTestServer(t, nil) s := newTestServer(t, nil)
got := s.send(t, "DELETE", "/api/unknown", `{}`) got := s.send(t, "DELETE", "/api/unknown", `{}`)
checkErrorResponse(t, got, 404, "not_found", "not found") checkErrorResponse(t, got, 404, "not_found", "not found")
var fellback bool
s.Fallback = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fellback = true
})
got = s.send(t, "DELETE", "/api/unknown", `{}`)
if !fellback {
t.Fatal("expected Fallback to be called")
}
if got.Code != 200 {
t.Fatalf("Code = %d; want 200", got.Code)
}
} }
func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) { func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {

View File

@@ -26,6 +26,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
var system []api.Message var system []api.Message
isMllama := checkMllamaModelFamily(m) isMllama := checkMllamaModelFamily(m)
isGemma3 := checkGemma3ModelFamily(m)
var imageNumTokens int var imageNumTokens int
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent // TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@@ -40,7 +41,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
n := len(msgs) - 1 n := len(msgs) - 1
// in reverse, find all messages that fit into context window // in reverse, find all messages that fit into context window
for i := n; i >= 0; i-- { for i := n; i >= 0; i-- {
if isMllama && len(msgs[i].Images) > 1 { if (isMllama || isGemma3) && len(msgs[i].Images) > 1 {
return "", nil, errTooManyImages return "", nil, errTooManyImages
} }
@@ -157,3 +158,12 @@ func checkMllamaModelFamily(m *Model) bool {
} }
return false return false
} }
func checkGemma3ModelFamily(m *Model) bool {
for _, arch := range m.Config.ModelFamilies {
if arch == "gemma3" {
return true
}
}
return false
}