Compare commits
27 commits: v0.6.1-rc0...parth/samp

| SHA1 |
|---|
| 0de5bbd0fe |
| 42a14f7f63 |
| f8c3dbe5b5 |
| b078dd157c |
| 2ddacd7516 |
| da0e345200 |
| df94175a0f |
| 61a8825216 |
| 021dcf089d |
| bf24498b1e |
| 95e271d98f |
| 364629b8d6 |
| 108fe02165 |
| 4561fff36e |
| 50b5962042 |
| e27e4a3c1b |
| 088514bbd4 |
| 2c8b484643 |
| 8294676150 |
| ef378ad673 |
| 2d2247e59e |
| 7bf793a600 |
| 282bfaaa95 |
| 9679f40146 |
| 3892c3a703 |
| 4e320b8b90 |
| eb2b22b042 |
@@ -56,7 +56,7 @@
       "name": "ROCm 6",
       "inherits": [ "ROCm" ],
       "cacheVariables": {
-        "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
+        "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
       }
     }
   ],
@@ -392,6 +392,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
 - [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
 - [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
 - [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
 - [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history

 ### Cloud
cmd/cmd_test.go | 129
@@ -757,3 +757,132 @@ func TestCreateHandler(t *testing.T) {
        })
    }
}

func TestNewCreateRequest(t *testing.T) {
    tests := []struct {
        name     string
        from     string
        opts     runOptions
        expected *api.CreateRequest
    }{
        {
            "basic test",
            "newmodel",
            runOptions{
                Model:       "mymodel",
                ParentModel: "",
                Prompt:      "You are a fun AI agent",
                Messages:    []api.Message{},
                WordWrap:    true,
            },
            &api.CreateRequest{
                From:  "mymodel",
                Model: "newmodel",
            },
        },
        {
            "parent model test",
            "newmodel",
            runOptions{
                Model:       "mymodel",
                ParentModel: "parentmodel",
                Messages:    []api.Message{},
                WordWrap:    true,
            },
            &api.CreateRequest{
                From:  "parentmodel",
                Model: "newmodel",
            },
        },
        {
            "parent model as filepath test",
            "newmodel",
            runOptions{
                Model:       "mymodel",
                ParentModel: "/some/file/like/etc/passwd",
                Messages:    []api.Message{},
                WordWrap:    true,
            },
            &api.CreateRequest{
                From:  "mymodel",
                Model: "newmodel",
            },
        },
        {
            "parent model as windows filepath test",
            "newmodel",
            runOptions{
                Model:       "mymodel",
                ParentModel: "D:\\some\\file\\like\\etc\\passwd",
                Messages:    []api.Message{},
                WordWrap:    true,
            },
            &api.CreateRequest{
                From:  "mymodel",
                Model: "newmodel",
            },
        },
        {
            "options test",
            "newmodel",
            runOptions{
                Model:       "mymodel",
                ParentModel: "parentmodel",
                Options: map[string]any{
                    "temperature": 1.0,
                },
            },
            &api.CreateRequest{
                From:  "parentmodel",
                Model: "newmodel",
                Parameters: map[string]any{
                    "temperature": 1.0,
                },
            },
        },
        {
            "messages test",
            "newmodel",
            runOptions{
                Model:       "mymodel",
                ParentModel: "parentmodel",
                System:      "You are a fun AI agent",
                Messages: []api.Message{
                    {
                        Role:    "user",
                        Content: "hello there!",
                    },
                    {
                        Role:    "assistant",
                        Content: "hello to you!",
                    },
                },
                WordWrap: true,
            },
            &api.CreateRequest{
                From:   "parentmodel",
                Model:  "newmodel",
                System: "You are a fun AI agent",
                Messages: []api.Message{
                    {
                        Role:    "user",
                        Content: "hello there!",
                    },
                    {
                        Role:    "assistant",
                        Content: "hello to you!",
                    },
                },
            },
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            actual := NewCreateRequest(tt.from, tt.opts)
            if !cmp.Equal(actual, tt.expected) {
                t.Errorf("expected output %#v, got %#v", tt.expected, actual)
            }
        })
    }
}
@@ -18,6 +18,7 @@ import (
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/readline"
 	"github.com/ollama/ollama/types/errtypes"
+	"github.com/ollama/ollama/types/model"
 )

 type MultilineState int

@@ -459,9 +460,16 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 }

 func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
+	parentModel := opts.ParentModel
+
+	modelName := model.ParseName(parentModel)
+	if !modelName.IsValid() {
+		parentModel = ""
+	}
+
 	req := &api.CreateRequest{
-		Name: name,
-		From: cmp.Or(opts.ParentModel, opts.Model),
+		Model: name,
+		From:  cmp.Or(parentModel, opts.Model),
 	}

 	if opts.System != "" {
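The validation added above is what the new `TestNewCreateRequest` cases exercise: a `ParentModel` that does not parse as a valid model name (for example a filesystem path) is discarded, so `From` falls back to `opts.Model`. A minimal sketch of that selection logic, using the same `model.ParseName` and `cmp.Or` calls as the diff; the `pickFrom` helper itself is hypothetical:

```go
package main

import (
	"cmp"
	"fmt"

	"github.com/ollama/ollama/types/model"
)

// pickFrom reproduces the selection shown in NewCreateRequest: prefer the
// parent model only when it is a valid model name, otherwise fall back to
// the current model.
func pickFrom(parentModel, currentModel string) string {
	if !model.ParseName(parentModel).IsValid() {
		parentModel = "" // e.g. "/some/file/like/etc/passwd" is rejected
	}
	return cmp.Or(parentModel, currentModel)
}

func main() {
	fmt.Println(pickFrom("parentmodel", "mymodel"))                // parentmodel
	fmt.Println(pickFrom("/some/file/like/etc/passwd", "mymodel")) // mymodel
}
```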
@@ -201,7 +201,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 	case "CohereForCausalLM":
 		conv = &commandrModel{}
 	default:
-		return errors.New("unsupported architecture")
+		return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
 	}

 	if err := json.Unmarshal(bts, conv); err != nil {
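Switching from a fixed `errors.New` string to `fmt.Errorf` with the `%q` verb puts the offending architecture name in the error itself, which makes conversion failures easier to diagnose. A small illustration of the difference (the architecture string here is made up):

```go
package main

import (
	"errors"
	"fmt"
)

func main() {
	arch := "FooForCausalLM" // hypothetical unsupported architecture

	before := errors.New("unsupported architecture")
	after := fmt.Errorf("unsupported architecture %q", arch)

	fmt.Println(before) // unsupported architecture
	fmt.Println(after)  // unsupported architecture "FooForCausalLM"
}
```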
@@ -66,6 +66,35 @@ func TestIntegrationMllama(t *testing.T) {
    DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
}

func TestIntegrationSplitBatch(t *testing.T) {
    image, err := base64.StdEncoding.DecodeString(imageEncoding)
    require.NoError(t, err)
    req := api.GenerateRequest{
        Model: "gemma3:4b",
        // Fill up a chunk of the batch so the image will partially spill over into the next one
        System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
        Prompt: "what does the text in this image say?",
        Stream: &stream,
        Options: map[string]interface{}{
            "seed":        42,
            "temperature": 0.0,
        },
        Images: []api.ImageData{
            image,
        },
    }

    // Note: sometimes it returns "the ollamas" sometimes "the ollams"
    resp := "the ollam"
    ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
    defer cancel()
    client, _, cleanup := InitServerConnection(ctx, t)
    defer cleanup()
    require.NoError(t, PullIfMissing(ctx, client, req.Model))
    // llava models on CPU can be quite slow to start,
    DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
}

const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA
AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6
llama/llama.cpp/src/llama-arch.cpp (vendored) | 19
@@ -37,6 +37,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_MINICPM3,   "minicpm3"   },
     { LLM_ARCH_GEMMA,      "gemma"      },
     { LLM_ARCH_GEMMA2,     "gemma2"     },
+    { LLM_ARCH_GEMMA3,     "gemma3"     },
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA,      "mamba"      },
     { LLM_ARCH_XVERSE,     "xverse"     },
@@ -804,6 +805,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
         { LLM_TENSOR_FFN_POST_NORM,  "blk.%d.post_ffw_norm" },
     },
 },
+{
+    LLM_ARCH_GEMMA3,
+    {
+        { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+        { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+        { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
+        { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
+        { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
+        { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },
+        { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+        { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+        { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm" },
+        { LLM_TENSOR_FFN_GATE,       "blk.%d.ffn_gate" },
+        { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
+        { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
+        { LLM_TENSOR_FFN_POST_NORM,  "blk.%d.post_ffw_norm" },
+    },
+},
 {
     LLM_ARCH_STARCODER2,
     {
llama/llama.cpp/src/llama-arch.h (vendored) | 1
@@ -41,6 +41,7 @@ enum llm_arch {
     LLM_ARCH_MINICPM3,
     LLM_ARCH_GEMMA,
     LLM_ARCH_GEMMA2,
+    LLM_ARCH_GEMMA3,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
llama/llama.cpp/src/llama-model.cpp (vendored) | 7
@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 default: type = LLM_TYPE_UNKNOWN;
             }
         } break;
+        case LLM_ARCH_GEMMA3:
+            {
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
                 }
             } break;
+        case LLM_ARCH_GEMMA3:
+            {
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
         case LLM_ARCH_PHIMOE:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
+        case LLM_ARCH_GEMMA3:
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
llama/llama.cpp/src/llama-quant.cpp (vendored) | 9
@@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // This used to be a regex, but <regex> has an extreme cost to compile times.
         bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?

+        // don't quantize vision stuff
+        quantize &= name.find("v.blk.") == std::string::npos;
+
+        quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos;
+        quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos;
+        quantize &= name.find("v.patch_embedding.weight") == std::string::npos;
+        quantize &= name.find("v.position_embedding.weight") == std::string::npos;
+        quantize &= name.find("v.post_layernorm.weight") == std::string::npos;
+
         // quantize only 2D and 3D tensors (experts)
         quantize &= (ggml_n_dims(tensor) >= 2);
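The quantization filter above is a chain of substring checks rather than a regex (the existing comment notes that `<regex>` is too costly at compile time): only tensors whose names end in `weight` are candidates, and anything belonging to the vision tower or the multimodal projector is excluded so those tensors stay at full precision. A rough Go translation of the same predicate, purely to show the intent; the function name and structure are mine, the tensor-name checks are taken from the C++ above:

```go
package main

import (
	"fmt"
	"strings"
)

// shouldQuantize mirrors the checks in llama_model_quantize_impl:
// quantize "...weight" tensors, but never vision/projector tensors.
func shouldQuantize(name string, ndims int) bool {
	quantize := strings.HasSuffix(name, "weight")

	// don't quantize vision stuff
	for _, skip := range []string{
		"v.blk.",
		"mm.mm_input_projection.weight",
		"mm.mm_soft_emb_norm.weight",
		"v.patch_embedding.weight",
		"v.position_embedding.weight",
		"v.post_layernorm.weight",
	} {
		quantize = quantize && !strings.Contains(name, skip)
	}

	// quantize only 2D and 3D tensors (experts)
	return quantize && ndims >= 2
}

func main() {
	fmt.Println(shouldQuantize("blk.0.attn_q.weight", 2))      // true
	fmt.Println(shouldQuantize("v.blk.0.attn_q.weight", 2))    // false
	fmt.Println(shouldQuantize("v.patch_embedding.weight", 2)) // false
}
```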
llama/patches/0021-gemma3-quantization.patch (new file) | 113
@@ -0,0 +1,113 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Patrick Devine <patrick@infrahq.com>
Date: Fri, 14 Mar 2025 16:33:23 -0700
Subject: [PATCH] gemma3 quantization

---
 src/llama-arch.cpp  | 19 +++++++++++++++++++
 src/llama-arch.h    |  1 +
 src/llama-model.cpp |  7 +++++++
 src/llama-quant.cpp |  9 +++++++++
 4 files changed, 36 insertions(+)

diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index b6f20286..b443fcd3 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -37,6 +37,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_MINICPM3,   "minicpm3"   },
     { LLM_ARCH_GEMMA,      "gemma"      },
     { LLM_ARCH_GEMMA2,     "gemma2"     },
+    { LLM_ARCH_GEMMA3,     "gemma3"     },
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA,      "mamba"      },
     { LLM_ARCH_XVERSE,     "xverse"     },
@@ -804,6 +805,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
         { LLM_TENSOR_FFN_POST_NORM,  "blk.%d.post_ffw_norm" },
     },
 },
+{
+    LLM_ARCH_GEMMA3,
+    {
+        { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+        { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+        { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
+        { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
+        { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
+        { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },
+        { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+        { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+        { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm" },
+        { LLM_TENSOR_FFN_GATE,       "blk.%d.ffn_gate" },
+        { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
+        { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
+        { LLM_TENSOR_FFN_POST_NORM,  "blk.%d.post_ffw_norm" },
+    },
+},
 {
     LLM_ARCH_STARCODER2,
     {
diff --git a/src/llama-arch.h b/src/llama-arch.h
index ec742224..aad92a5d 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -41,6 +41,7 @@ enum llm_arch {
     LLM_ARCH_MINICPM3,
     LLM_ARCH_GEMMA,
     LLM_ARCH_GEMMA2,
+    LLM_ARCH_GEMMA3,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index ab1a07d1..70183041 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 default: type = LLM_TYPE_UNKNOWN;
             }
         } break;
+        case LLM_ARCH_GEMMA3:
+            {
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
                 }
             } break;
+        case LLM_ARCH_GEMMA3:
+            {
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
         case LLM_ARCH_PHIMOE:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
+        case LLM_ARCH_GEMMA3:
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 6eb1da08..d2f3a510 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // This used to be a regex, but <regex> has an extreme cost to compile times.
         bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?

+        // don't quantize vision stuff
+        quantize &= name.find("v.blk.") == std::string::npos;
+
+        quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos;
+        quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos;
+        quantize &= name.find("v.patch_embedding.weight") == std::string::npos;
+        quantize &= name.find("v.position_embedding.weight") == std::string::npos;
+        quantize &= name.find("v.post_layernorm.weight") == std::string::npos;
+
         // quantize only 2D and 3D tensors (experts)
         quantize &= (ggml_n_dims(tensor) >= 2);
llm/server.go | 134
@@ -402,7 +402,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 		s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
 	}

-	slog.Info("starting llama server", "cmd", s.cmd.String())
+	slog.Info("starting llama server", "cmd", s.cmd)
 	if envconfig.Debug() {
 		filteredEnv := []string{}
 		for _, ev := range s.cmd.Env {
@@ -470,7 +470,7 @@ const ( // iota is reset to 0
 	ServerStatusError
 )

-func (s ServerStatus) ToString() string {
+func (s ServerStatus) String() string {
 	switch s {
 	case ServerStatusReady:
 		return "llm server ready"
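Renaming `ToString()` to `String()` is more than cosmetic: it makes `ServerStatus` satisfy `fmt.Stringer`, so `slog` attributes, `%s`/`%v` verbs and wrapped errors format the status text automatically, which is why the explicit `.ToString()` calls elsewhere in this diff disappear. A minimal sketch of the pattern; the enum shape follows the diff, the exact message strings here are illustrative:

```go
package main

import (
	"fmt"
	"log/slog"
)

type ServerStatus int

const (
	ServerStatusReady ServerStatus = iota
	ServerStatusLoadingModel
	ServerStatusError
)

// String makes ServerStatus implement fmt.Stringer.
func (s ServerStatus) String() string {
	switch s {
	case ServerStatusReady:
		return "llm server ready"
	case ServerStatusLoadingModel:
		return "llm server loading model"
	default:
		return "llm server error"
	}
}

func main() {
	status := ServerStatusLoadingModel

	// Both of these pick up String() without an explicit conversion call.
	slog.Info("waiting for server to become available", "status", status)
	fmt.Println(fmt.Errorf("unexpected server status: %s", status))
}
```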
@@ -485,11 +485,8 @@ func (s ServerStatus) ToString() string {
 	}
 }

-type ServerStatusResp struct {
-	Status          string `json:"status"`
-	SlotsIdle       int    `json:"slots_idle"`
-	SlotsProcessing int    `json:"slots_processing"`
-	Error           string `json:"error"`
+type ServerStatusResponse struct {
+	Status   ServerStatus `json:"status"`
+	Progress float32      `json:"progress"`
 }

@@ -502,7 +499,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 		}
 		if s.cmd.ProcessState.ExitCode() == -1 {
 			// Most likely a signal killed it, log some more details to try to help troubleshoot
-			slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
+			slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState)
 		}
 		return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
 	}
@@ -527,21 +524,19 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 		return ServerStatusError, fmt.Errorf("read health request: %w", err)
 	}

-	var status ServerStatusResp
-	if err := json.Unmarshal(body, &status); err != nil {
+	var ssr ServerStatusResponse
+	if err := json.Unmarshal(body, &ssr); err != nil {
 		return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
 	}

-	switch status.Status {
-	case "ok":
-		return ServerStatusReady, nil
-	case "no slot available":
-		return ServerStatusNoSlotsAvailable, nil
-	case "loading model":
-		s.loadProgress = status.Progress
-		return ServerStatusLoadingModel, nil
+	switch ssr.Status {
+	case ServerStatusLoadingModel:
+		s.loadProgress = ssr.Progress
+		return ssr.Status, nil
+	case ServerStatusReady, ServerStatusNoSlotsAvailable:
+		return ssr.Status, nil
 	default:
-		return ServerStatusError, fmt.Errorf("server error: %+v", status)
+		return ssr.Status, fmt.Errorf("server error: %+v", ssr)
 	}
 }
@@ -616,7 +611,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 			status, _ := s.getServerStatus(ctx)
 			if lastStatus != status && status != ServerStatusReady {
 				// Only log on status changes
-				slog.Info("waiting for server to become available", "status", status.ToString())
+				slog.Info("waiting for server to become available", "status", status)
 			}
 			switch status {
 			case ServerStatusReady:
@@ -630,7 +625,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 					slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
 					stallTimer = time.Now().Add(stallDuration)
 				} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
-					slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
+					slog.Debug("model load completed, waiting for server to become available", "status", status)
 					stallTimer = time.Now().Add(stallDuration)
 					fullyLoaded = true
 				}
@@ -671,63 +666,26 @@ type ImageData struct {
 	AspectRatioID int `json:"aspect_ratio_id"`
 }

-type completion struct {
-	Content      string `json:"content"`
-	Model        string `json:"model"`
-	Prompt       string `json:"prompt"`
-	Stop         bool   `json:"stop"`
-	StoppedLimit bool   `json:"stopped_limit"`
-
-	Timings struct {
-		PredictedN  int     `json:"predicted_n"`
-		PredictedMS float64 `json:"predicted_ms"`
-		PromptN     int     `json:"prompt_n"`
-		PromptMS    float64 `json:"prompt_ms"`
-	}
-}
-
 type CompletionRequest struct {
 	Prompt  string
 	Format  json.RawMessage
 	Images  []ImageData
 	Options *api.Options
+
+	Grammar string // set before sending the request to the subprocess
 }

 type CompletionResponse struct {
-	Content            string
-	DoneReason         string
-	Done               bool
-	PromptEvalCount    int
-	PromptEvalDuration time.Duration
-	EvalCount          int
-	EvalDuration       time.Duration
+	Content            string        `json:"content"`
+	DoneReason         string        `json:"done_reason"`
+	Done               bool          `json:"done"`
+	PromptEvalCount    int           `json:"prompt_eval_count"`
+	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
+	EvalCount          int           `json:"eval_count"`
+	EvalDuration       time.Duration `json:"eval_duration"`
 }

 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
-	request := map[string]any{
-		"prompt":            req.Prompt,
-		"stream":            true,
-		"n_predict":         req.Options.NumPredict,
-		"n_keep":            req.Options.NumKeep,
-		"main_gpu":          req.Options.MainGPU,
-		"temperature":       req.Options.Temperature,
-		"top_k":             req.Options.TopK,
-		"top_p":             req.Options.TopP,
-		"min_p":             req.Options.MinP,
-		"typical_p":         req.Options.TypicalP,
-		"repeat_last_n":     req.Options.RepeatLastN,
-		"repeat_penalty":    req.Options.RepeatPenalty,
-		"presence_penalty":  req.Options.PresencePenalty,
-		"frequency_penalty": req.Options.FrequencyPenalty,
-		"mirostat":          req.Options.Mirostat,
-		"mirostat_tau":      req.Options.MirostatTau,
-		"mirostat_eta":      req.Options.MirostatEta,
-		"seed":              req.Options.Seed,
-		"stop":              req.Options.Stop,
-		"image_data":        req.Images,
-		"cache_prompt":      true,
-	}
-
 	if len(req.Format) > 0 {
 		switch string(req.Format) {
 		case `null`, `""`:
@@ -735,7 +693,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			// these as "not set".
 			break
 		case `"json"`:
-			request["grammar"] = grammarJSON
+			req.Grammar = grammarJSON
 		default:
 			if req.Format[0] != '{' {
 				return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
@@ -746,10 +704,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			if g == nil {
 				return fmt.Errorf("invalid JSON schema in format")
 			}
-			request["grammar"] = string(g)
+			req.Grammar = string(g)
 		}
 	}

+	if req.Options == nil {
+		opts := api.DefaultOptions()
+		req.Options = &opts
+	}
+
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 		if errors.Is(err, context.Canceled) {
 			slog.Info("aborting completion request due to client closing the connection")
@@ -770,7 +733,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	if err != nil {
 		return err
 	} else if status != ServerStatusReady {
-		return fmt.Errorf("unexpected server status: %s", status.ToString())
+		return fmt.Errorf("unexpected server status: %s", status)
 	}

 	// Handling JSON marshaling with special characters unescaped.
@@ -778,7 +741,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	enc := json.NewEncoder(buffer)
 	enc.SetEscapeHTML(false)

-	if err := enc.Encode(request); err != nil {
+	if err := enc.Encode(req); err != nil {
 		return fmt.Errorf("failed to marshal data: %v", err)
 	}
@@ -829,7 +792,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			evt = line
 		}

-		var c completion
+		var c CompletionResponse
 		if err := json.Unmarshal(evt, &c); err != nil {
 			return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
 		}
@@ -853,20 +816,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			})
 		}

-		if c.Stop {
-			doneReason := "stop"
-			if c.StoppedLimit {
-				doneReason = "length"
-			}
-
-			fn(CompletionResponse{
-				Done:               true,
-				DoneReason:         doneReason,
-				PromptEvalCount:    c.Timings.PromptN,
-				PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
-				EvalCount:          c.Timings.PredictedN,
-				EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
-			})
+		if c.Done {
+			fn(c)
 			return nil
 		}
 	}
@@ -914,7 +865,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
 	if err != nil {
 		return nil, err
 	} else if status != ServerStatusReady {
-		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
+		return nil, fmt.Errorf("unexpected server status: %s", status)
 	}

 	data, err := json.Marshal(EmbeddingRequest{Content: input})
@@ -1059,12 +1010,3 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
 	}
 	return 0
 }
-
-func parseDurationMs(ms float64) time.Duration {
-	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
-	if err != nil {
-		panic(err)
-	}
-
-	return dur
-}
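With JSON tags on `CompletionResponse`, the runner subprocess can stream the exact struct that the `llm` package consumes, so the intermediate `completion`/`Timings` types and the `parseDurationMs` helper removed above are no longer needed. One detail worth noting: `time.Duration` is an `int64`, so `encoding/json` serializes those fields as plain nanosecond counts. A small round-trip sketch under that assumption, reusing the field layout from the diff:

```go
package main

import (
	"encoding/json"
	"fmt"
	"time"
)

type CompletionResponse struct {
	Content            string        `json:"content"`
	DoneReason         string        `json:"done_reason"`
	Done               bool          `json:"done"`
	PromptEvalCount    int           `json:"prompt_eval_count"`
	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
	EvalCount          int           `json:"eval_count"`
	EvalDuration       time.Duration `json:"eval_duration"`
}

func main() {
	resp := CompletionResponse{
		Done:         true,
		DoneReason:   "stop",
		EvalCount:    42,
		EvalDuration: 1500 * time.Millisecond,
	}

	// Durations travel over the wire as integer nanoseconds.
	b, _ := json.Marshal(resp)
	fmt.Println(string(b))

	var decoded CompletionResponse
	_ = json.Unmarshal(b, &decoded)
	fmt.Println(decoded.EvalDuration) // 1.5s
}
```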
@@ -312,23 +312,25 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 				return fmt.Errorf("unassigned tensor: %s", t.Name)
 			}

-			bts := make([]byte, t.Size())
-			n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
-			if err != nil {
-				return err
+			bts := C.malloc(C.size_t(t.Size()))
+			if bts == nil {
+				return errors.New("failed to allocate tensor buffer")
 			}
+			defer C.free(bts)
+
+			buf := unsafe.Slice((*byte)(bts), t.Size())
+			n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf)
+			if err != nil || n != len(buf) {
+				return errors.New("read failed")
+			}

-			if n != len(bts) {
-				return errors.New("short read")
-			}
-
-			C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
+			C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size()))
 			return nil
 		})
 	}
 }

-	if g.Wait() != nil {
+	if err := g.Wait(); err != nil {
 		return nil, err
 	}

@@ -371,7 +373,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 			(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
 			C.int(len(schedBackends)),
 			C.size_t(maxGraphNodes),
-			true,
+			C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
 		),
 		input:  deviceBufferTypes[input.d],
 		output: deviceBufferTypes[output.d],
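The tensor-loading change above swaps a Go-allocated `[]byte` for a buffer obtained with `C.malloc`, wraps it with `unsafe.Slice` so `io.ReadFull` can fill it, and then hands the raw pointer straight to `ggml_backend_tensor_set`. The general cgo pattern, stripped of the ggml specifics (the data source here is a stand-in for the real section reader):

```go
package main

/*
#include <stdlib.h>
*/
import "C"

import (
	"bytes"
	"fmt"
	"io"
	"unsafe"
)

func main() {
	src := bytes.NewReader([]byte("tensor bytes go here"))
	size := src.Len()

	// Allocate the staging buffer on the C heap instead of the Go heap.
	bts := C.malloc(C.size_t(size))
	if bts == nil {
		panic("failed to allocate tensor buffer")
	}
	defer C.free(bts)

	// View the C memory as a Go slice so io.ReadFull can write into it.
	buf := unsafe.Slice((*byte)(bts), size)
	if n, err := io.ReadFull(src, buf); err != nil || n != len(buf) {
		panic("read failed")
	}

	// In the real code the pointer is passed on as
	// C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(size)).
	fmt.Println(string(buf))
}
```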
@@ -15,6 +15,12 @@ type Input struct {
 	// stored in Multimodal, used for caching and comparing
 	// equality.
 	MultimodalHash uint64
+
+	// SameBatch forces the following number of tokens to be processed
+	// in a single batch, breaking and extending batches as needed.
+	// Useful for things like images that must be processed in one
+	// shot.
+	SameBatch int
 }

 // MultimodalIndex is a multimodal element (such as an image)
@@ -60,7 +60,7 @@ type MultimodalProcessor interface {
 	// This function is also responsible for updating MultimodalHash for any Multimodal
 	// that is modified to ensure that there is a unique hash value that accurately
 	// represents the contents.
-	PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
+	PostTokenize([]input.Input) ([]input.Input, error)
 }

 // Base implements the common fields and methods for all models
@@ -179,7 +179,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}

-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}
@@ -211,8 +211,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 	// final logit softcap
 	hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
 	hiddenState = hiddenState.Tanh(ctx)
-	hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
-	return hiddenState.Rows(ctx, outputs), nil
+	return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil
 }

 func init() {
@@ -2,10 +2,9 @@ package gemma3

 import (
 	"bytes"
-	"encoding/binary"
-	"hash/fnv"
 	"image"
 	"math"
+	"slices"

 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
@@ -112,36 +111,23 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	return visionOutputs, nil
 }

-type imageToken struct {
-	embedding ml.Tensor
-	index     int
-}
-
-func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	var result []input.Input
-	fnvHash := fnv.New64a()

 	for _, inp := range inputs {
 		if inp.Multimodal == nil {
 			result = append(result, inp)
 		} else {
-			imageInputs := []input.Input{
-				{Token: 108},    // "\n\n"
-				{Token: 255999}, // "<start_of_image>""
-			}
-			result = append(result, imageInputs...)
-
-			// add image embeddings
 			inputMultimodal := inp.Multimodal.(ml.Tensor)

-			for i := range inputMultimodal.Dim(1) {
-				fnvHash.Reset()
-				binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
-				fnvHash.Write([]byte{byte(i)})
+			result = append(result,
+				input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3},               // "\n\n"
+				input.Input{Token: 255999},                                                   // "<start_of_image>""
+				input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
+			)

-				imageToken := imageToken{embedding: inputMultimodal, index: i}
-				result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
-			}
+			// add image token placeholders
+			result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)

 			result = append(result,
 				input.Input{Token: 256000}, // <end_of_image>
@@ -164,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}

-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}
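The rewritten `PostTokenize` above lays each image out as a fixed token pattern and uses the new `SameBatch` field (see the `input.Input` change earlier) to keep the whole run in one batch: a `"\n\n"` token carrying `SameBatch`, the `<start_of_image>` token, one input holding the multimodal tensor, `Dim(1)-1` zero-token placeholders, then `<end_of_image>`. A simplified sketch of that layout with the tensor replaced by a plain placeholder value; the token IDs are copied from the diff, everything else is illustrative:

```go
package main

import "fmt"

// Input is a trimmed-down stand-in for the runner's input.Input type.
type Input struct {
	Token          int32
	Multimodal     any
	MultimodalHash uint64
	SameBatch      int
}

// appendImage mimics the gemma3 PostTokenize layout for one image whose
// embedding tensor has numEmbeddings columns.
func appendImage(result []Input, embedding any, hash uint64, numEmbeddings int) []Input {
	result = append(result,
		Input{Token: 108, SameBatch: numEmbeddings + 3}, // "\n\n", keeps the whole image in one batch
		Input{Token: 255999},                            // "<start_of_image>"
		Input{Multimodal: embedding, MultimodalHash: hash},
	)
	// placeholders for the remaining image positions
	for i := 0; i < numEmbeddings-1; i++ {
		result = append(result, Input{Token: 0})
	}
	return append(result, Input{Token: 256000}) // "<end_of_image>"
}

func main() {
	inputs := appendImage(nil, "image-embedding-tensor", 0xfeed, 4)
	for _, in := range inputs {
		fmt.Printf("%+v\n", in)
	}
}
```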
@@ -171,53 +171,20 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 	return hiddenState.Add(ctx, residual)
 }

-func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
-	var embedding ml.Tensor
-	var src, dst, length int
-	var except []int
-
-	for _, image := range multimodal {
-		imageToken := image.Multimodal.(imageToken)
-		imageSrc := imageToken.index
-		imageDst := image.Index
-
-		if embedding == nil {
-			embedding = imageToken.embedding
-			src = imageSrc
-			dst = imageDst
-			length = 1
-		} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
-			src = imageSrc
-			dst = imageDst
-			length++
-		} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
-			length++
-		} else {
-			visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
-			ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
-
-			embedding = imageToken.embedding
-			src = imageSrc
-			dst = imageDst
-			length = 1
-		}
-
-		except = append(except, imageDst)
-	}
-
-	if embedding != nil {
-		visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
-		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
-	}
-
-	return except
-}
-
 func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))

-	except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
+	// set image embeddings
+	var except []int
+	for _, image := range opts.Multimodal {
+		visionOutputs := image.Multimodal.(ml.Tensor)
+		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
+
+		for i := range visionOutputs.Dim(1) {
+			except = append(except, image.Index+i)
+		}
+	}

 	for i, layer := range m.Layers {
 		// gemma alternates between the sliding window (local) and causal (global)
@@ -150,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}

-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}

@@ -106,17 +106,17 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	return m.Projector.Forward(ctx, crossAttentionStates), nil
 }

-func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	var images []input.Input
 	fnvHash := fnv.New64a()

 	for i := range inputs {
 		if inputs[i].Multimodal == nil {
 			if len(images) > 0 {
-				inputs[i].Multimodal = images[0].Multimodal
+				inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)}
 				inputs[i].MultimodalHash = images[0].MultimodalHash
 				for j := 1; j < len(images); j++ {
-					inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
+					inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[0].Multimodal.(ml.Tensor))
 					fnvHash.Reset()
 					binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
 					binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
@@ -138,7 +138,10 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Inpu
 func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 	var crossAttentionStates ml.Tensor
 	if len(opts.Multimodal) > 0 {
-		crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
+		images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
+		if len(images) > 0 {
+			crossAttentionStates = images[len(images)-1]
+		}
 	}

 	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
@@ -151,7 +154,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}

-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}
@@ -24,6 +24,7 @@ import (

 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llama"
+	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/runner/common"
 )

@@ -99,7 +100,7 @@ type NewSequenceParams struct {
 	embedding bool
 }

-func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
+func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 	s.ready.Wait()

 	startTime := time.Now()
@@ -163,7 +164,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // generating image embeddings for each image
-func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) {
 	var inputs []input
 	var parts []string
 	var matches [][]string
@@ -229,7 +230,7 @@ type Server struct {
 	image *ImageContext

 	// status for external health reporting - loading, ready to serve, etc.
-	status ServerStatus
+	status llm.ServerStatus

 	// current progress on loading the model
 	progress float32
@@ -541,75 +542,18 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 	return nil
 }

-// TODO (jmorganca): use structs from the api package to avoid duplication
-// this way the api acts as a proxy instead of using a different api for the
-// runner
-type Options struct {
-	api.Runner
-
-	NumKeep          int      `json:"n_keep"`
-	Seed             int      `json:"seed"`
-	NumPredict       int      `json:"n_predict"`
-	TopK             int      `json:"top_k"`
-	TopP             float32  `json:"top_p"`
-	MinP             float32  `json:"min_p"`
-	TypicalP         float32  `json:"typical_p"`
-	RepeatLastN      int      `json:"repeat_last_n"`
-	Temperature      float32  `json:"temperature"`
-	RepeatPenalty    float32  `json:"repeat_penalty"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	Mirostat         int      `json:"mirostat"`
-	MirostatTau      float32  `json:"mirostat_tau"`
-	MirostatEta      float32  `json:"mirostat_eta"`
-	Stop             []string `json:"stop"`
-}
-
-type ImageData struct {
-	Data          []byte `json:"data"`
-	ID            int    `json:"id"`
-	AspectRatioID int    `json:"aspect_ratio_id"`
-}
-
-type CompletionRequest struct {
-	Prompt      string      `json:"prompt"`
-	Images      []ImageData `json:"image_data"`
-	Grammar     string      `json:"grammar"`
-	CachePrompt bool        `json:"cache_prompt"`
-
-	Options
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}
-
-type CompletionResponse struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
-
-	Model        string  `json:"model,omitempty"`
-	Prompt       string  `json:"prompt,omitempty"`
-	StoppedLimit bool    `json:"stopped_limit,omitempty"`
-	PredictedN   int     `json:"predicted_n,omitempty"`
-	PredictedMS  float64 `json:"predicted_ms,omitempty"`
-	PromptN      int     `json:"prompt_n,omitempty"`
-	PromptMS     float64 `json:"prompt_ms,omitempty"`
-
-	Timings Timings `json:"timings"`
-}
-
 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
-	var req CompletionRequest
-	req.Options = Options(api.DefaultOptions())
+	var req llm.CompletionRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 		http.Error(w, "Bad request", http.StatusBadRequest)
 		return
 	}

+	if req.Options == nil {
+		opts := api.DefaultOptions()
+		req.Options = &opts
+	}
+
 	// Set the headers to indicate streaming
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Transfer-Encoding", "chunked")
@@ -620,26 +564,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}

-	var samplingParams llama.SamplingParams
-	samplingParams.TopK = req.TopK
-	samplingParams.TopP = req.TopP
-	samplingParams.MinP = req.MinP
-	samplingParams.TypicalP = req.TypicalP
-	samplingParams.Temp = req.Temperature
-	samplingParams.RepeatLastN = req.RepeatLastN
-	samplingParams.PenaltyRepeat = req.RepeatPenalty
-	samplingParams.PenaltyFreq = req.FrequencyPenalty
-	samplingParams.PenaltyPresent = req.PresencePenalty
-	samplingParams.Mirostat = req.Mirostat
-	samplingParams.MirostatTau = req.MirostatTau
-	samplingParams.MirostatEta = req.MirostatEta
-	samplingParams.Seed = uint32(req.Seed)
-	samplingParams.Grammar = req.Grammar
+	// Extract options from the CompletionRequest
+	samplingParams := llama.SamplingParams{
+		TopK:           req.Options.TopK,
+		TopP:           req.Options.TopP,
+		MinP:           req.Options.MinP,
+		TypicalP:       req.Options.TypicalP,
+		Temp:           req.Options.Temperature,
+		RepeatLastN:    req.Options.RepeatLastN,
+		PenaltyRepeat:  req.Options.RepeatPenalty,
+		PenaltyFreq:    req.Options.FrequencyPenalty,
+		PenaltyPresent: req.Options.PresencePenalty,
+		Mirostat:       req.Options.Mirostat,
+		MirostatTau:    req.Options.MirostatTau,
+		MirostatEta:    req.Options.MirostatEta,
+		Seed:           uint32(req.Options.Seed),
+		Grammar:        req.Grammar,
+	}

 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
-		numPredict:     req.NumPredict,
-		stop:           req.Stop,
-		numKeep:        req.NumKeep,
+		numPredict:     req.Options.NumPredict,
+		stop:           req.Options.Stop,
+		numKeep:        req.Options.NumKeep,
 		samplingParams: &samplingParams,
 		embedding:      false,
 	})
@@ -662,7 +608,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -691,7 +637,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
+				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 					Content: content,
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -702,15 +648,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 				flusher.Flush()
 			} else {
 				// Send the final response
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
-					Stop:         true,
-					StoppedLimit: seq.doneReason == "limit",
-					Timings: Timings{
-						PromptN:     seq.numPromptInputs,
-						PromptMS:    float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
-						PredictedN:  seq.numDecoded,
-						PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
-					},
+				doneReason := "stop"
+				if seq.doneReason == "limit" {
+					doneReason = "length"
+				}
+				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
+					Done:               true,
+					DoneReason:         doneReason,
+					PromptEvalCount:    seq.numPromptInputs,
+					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
+					EvalCount:          seq.numDecoded,
+					EvalDuration:       time.Since(seq.startGenerationTime),
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
 				}
@@ -721,17 +669,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	}
 }

-type EmbeddingRequest struct {
-	Content     string `json:"content"`
-	CachePrompt bool   `json:"cache_prompt"`
-}
-
-type EmbeddingResponse struct {
-	Embedding []float32 `json:"embedding"`
-}
-
 func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
-	var req EmbeddingRequest
+	var req llm.EmbeddingRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 		http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
 		return
@@ -761,7 +700,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -782,41 +721,17 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {

 	embedding := <-seq.embedding

-	if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
+	if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
 		Embedding: embedding,
 	}); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 	}
 }

-type HealthResponse struct {
-	Status   string  `json:"status"`
-	Progress float32 `json:"progress"`
-}
-
-type ServerStatus int
-
-const (
-	ServerStatusReady ServerStatus = iota
-	ServerStatusLoadingModel
-	ServerStatusError
-)
-
-func (s ServerStatus) ToString() string {
-	switch s {
-	case ServerStatusReady:
-		return "ok"
-	case ServerStatusLoadingModel:
-		return "loading model"
-	default:
-		return "server error"
-	}
-}
-
 func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
-	if err := json.NewEncoder(w).Encode(&HealthResponse{
-		Status:   s.status.ToString(),
+	if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
+		Status:   s.status,
 		Progress: s.progress,
 	}); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -879,7 +794,7 @@ func (s *Server) loadModel(
 		panic(err)
 	}

-	s.status = ServerStatusReady
+	s.status = llm.ServerStatusReady
 	s.ready.Done()
 }

@@ -937,7 +852,7 @@ func Execute(args []string) error {
 		parallel: *parallel,
 		seqs:     make([]*Sequence, *parallel),
 		seqsSem:  semaphore.NewWeighted(int64(*parallel)),
-		status:   ServerStatusLoadingModel,
+		status:   llm.ServerStatusLoadingModel,
 	}

 	var tensorSplitFloats []float32
@@ -89,7 +89,7 @@ type InputCacheSlot struct {
 	lastUsed time.Time
 }

-func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
 	var slot *InputCacheSlot
 	var numPast int32
 	var err error
@@ -107,10 +107,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
 		return nil, nil, err
 	}

-	if !cachePrompt {
-		numPast = 0
-	}
-
 	slot.InUse = true
 	slot.lastUsed = time.Now()
@@ -297,3 +297,131 @@ func TestShiftDiscard(t *testing.T) {
        })
    }
}

func TestLoadCacheSlot(t *testing.T) {
    tests := []struct {
        name           string
        cache          InputCache
        prompt         []input.Input
        wantErr        bool
        expectedSlotId int
        expectedPrompt int // expected length of remaining prompt
    }{
        {
            name: "Basic cache hit - single user",
            cache: InputCache{
                multiUserCache: false,
                slots: []InputCacheSlot{
                    {
                        Id:       0,
                        Inputs:   []input.Input{{Token: 1}, {Token: 2}},
                        InUse:    false,
                        lastUsed: time.Now().Add(-time.Second),
                    },
                    {
                        Id:       1,
                        Inputs:   []input.Input{},
                        InUse:    false,
                        lastUsed: time.Now().Add(-2 * time.Second),
                    },
                },
            },
            prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
            wantErr:        false,
            expectedSlotId: 0,
            expectedPrompt: 1, // Only token 3 remains
        },
        {
            name: "Basic cache hit - multi user",
            cache: InputCache{
                multiUserCache: true,
                slots: []InputCacheSlot{
                    {
                        Id:       0,
                        Inputs:   []input.Input{{Token: 1}, {Token: 2}},
                        InUse:    false,
                        lastUsed: time.Now().Add(-time.Second),
                    },
                    {
                        Id:       1,
                        Inputs:   []input.Input{},
                        InUse:    false,
                        lastUsed: time.Now().Add(-2 * time.Second),
                    },
                },
            },
            prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
            wantErr:        false,
            expectedSlotId: 0,
            expectedPrompt: 1, // Only token 3 remains
        },
        {
            name: "Exact match - leave one input",
            cache: InputCache{
                multiUserCache: false,
                slots: []InputCacheSlot{
                    {
                        Id:       0,
                        Inputs:   []input.Input{{Token: 1}, {Token: 2}},
                        InUse:    false,
                        lastUsed: time.Now().Add(-time.Second),
                    },
                },
            },
            prompt:         []input.Input{{Token: 1}, {Token: 2}},
            wantErr:        false,
            expectedSlotId: 0,
            expectedPrompt: 1, // Should leave 1 token for sampling
        },
        {
            name: "No available slots",
            cache: InputCache{
                multiUserCache: false,
                slots: []InputCacheSlot{
                    {
                        Id:       0,
                        Inputs:   []input.Input{{Token: 1}, {Token: 2}},
                        InUse:    true,
                        lastUsed: time.Now().Add(-time.Second),
                    },
                },
            },
            prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
            wantErr:        true,
            expectedSlotId: -1,
            expectedPrompt: -1,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)

            // Check error state
            if (err != nil) != tt.wantErr {
                t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
                return
            }

            if tt.wantErr {
                return // Skip further checks if we expected an error
            }

            // Verify slot ID
            if slot.Id != tt.expectedSlotId {
                t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
            }

            // Verify slot is now marked in use
            if !slot.InUse {
                t.Errorf("LoadCacheSlot() slot not marked InUse")
            }

            // Verify remaining prompt length
            if len(remainingPrompt) != tt.expectedPrompt {
                t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
                    len(remainingPrompt), tt.expectedPrompt)
            }
        })
    }
}
@@ -24,6 +24,7 @@ import (
 	"golang.org/x/sync/semaphore"

 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model/input"
@@ -33,10 +34,14 @@ import (
 	_ "github.com/ollama/ollama/model/models"
 )

+type contextList struct {
+	list []ml.Context
+}
+
 type Sequence struct {
-	// ctx for allocating tensors that last the lifetime of the sequence, such as
+	// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
 	// multimodal embeddings
-	ctx ml.Context
+	ctxs *contextList

 	// batch index
 	iBatch int
@@ -94,13 +99,12 @@ type NewSequenceParams struct {
 	embedding bool
 }

-func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
+func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 	s.ready.Wait()

 	startTime := time.Now()
-	ctx := s.model.Backend().NewContext()

-	inputs, err := s.inputs(ctx, prompt, images)
+	inputs, ctxs, err := s.inputs(prompt, images)
 	if err != nil {
 		return nil, fmt.Errorf("failed to process inputs: %w", err)
 	} else if len(inputs) == 0 {
@@ -111,6 +115,9 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 		params.numKeep = int32(len(inputs))
 	}

+	// TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
+	// otherwise we might truncate or split the batch against the model's wishes
+
 	// Ensure that at least 1 input can be discarded during shift
 	params.numKeep = min(params.numKeep, s.cache.numCtx-1)

@@ -126,7 +133,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 	// TODO(jessegross): Ingest cached history for grammar

 	return &Sequence{
-		ctx:                 ctx,
+		ctxs:                ctxs,
 		inputs:              inputs,
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
@@ -145,7 +152,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) {
 	var inputs []input.Input
 	var parts []string
 	var matches [][]string
@@ -160,12 +167,19 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in
|
||||
parts = []string{prompt}
|
||||
}
|
||||
|
||||
var contexts contextList
|
||||
runtime.AddCleanup(&contexts, func(ctxs []ml.Context) {
|
||||
for _, ctx := range ctxs {
|
||||
ctx.Close()
|
||||
}
|
||||
}, contexts.list)
|
||||
|
||||
postTokenize := false
|
||||
for i, part := range parts {
|
||||
// text - tokenize
|
||||
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for _, t := range tokens {
|
||||
@@ -185,12 +199,14 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in
|
||||
}
|
||||
|
||||
if imageIndex < 0 {
|
||||
return nil, fmt.Errorf("invalid image index: %d", n)
|
||||
return nil, nil, fmt.Errorf("invalid image index: %d", n)
|
||||
}
|
||||
|
||||
ctx := s.model.Backend().NewContext()
|
||||
contexts.list = append(contexts.list, ctx)
|
||||
imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
s.multimodalHash.Reset()
|
||||
@@ -204,13 +220,13 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in
|
||||
|
||||
if visionModel && postTokenize {
|
||||
var err error
|
||||
inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
|
||||
inputs, err = multimodalProcessor.PostTokenize(inputs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return inputs, nil
|
||||
return inputs, &contexts, nil
|
||||
}
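The contextList plus runtime.AddCleanup pattern above ties the per-image ml.Contexts to garbage collection of the sequence's context list instead of closing a single context explicitly. A minimal, self-contained sketch of that Go 1.24 API follows; the types and names here are placeholders, not the runner's code.

	package main

	import (
		"fmt"
		"runtime"
	)

	type closer struct{ id int }

	func (c *closer) Close() { fmt.Println("closed", c.id) }

	// owner plays the role of contextList: once it becomes unreachable,
	// the registered cleanup closes everything it accumulated.
	type owner struct{ list []*closer }

	func main() {
		o := &owner{list: []*closer{{id: 1}, {id: 2}}}

		// The third argument is captured by value at registration time,
		// so register once the list holds the resources to release.
		runtime.AddCleanup(o, func(cs []*closer) {
			for _, c := range cs {
				c.Close()
			}
		}, o.list)

		o = nil
		runtime.GC() // cleanup becomes eligible to run; output before exit is not guaranteed
	}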
|
||||
|
||||
type Server struct {
|
||||
@@ -222,7 +238,7 @@ type Server struct {
|
||||
model model.Model
|
||||
|
||||
// status for external health reporting - loading, ready to serve, etc.
|
||||
status ServerStatus
|
||||
status llm.ServerStatus
|
||||
|
||||
// current progress on loading the model
|
||||
progress float32
|
||||
@@ -305,7 +321,6 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
|
||||
close(seq.responses)
|
||||
close(seq.embedding)
|
||||
seq.cache.InUse = false
|
||||
seq.ctx.Close()
|
||||
s.seqs[seqIndex] = nil
|
||||
s.seqsSem.Release(1)
|
||||
}
|
||||
@@ -351,20 +366,33 @@ func (s *Server) processBatch() error {
|
||||
seq.cache.Inputs = []input.Input{}
|
||||
}
|
||||
|
||||
batchSize := s.batchSize
|
||||
|
||||
for j, inp := range seq.inputs {
|
||||
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
|
||||
if len(seq.pendingInputs) == 0 {
|
||||
// If we are required to put following inputs into a single batch then extend the
|
||||
// batch size. Since we are only extending the size the minimum amount possible, this
|
||||
// will cause a break if we have pending inputs.
|
||||
minBatch := 1 + inp.SameBatch
|
||||
if minBatch > batchSize {
|
||||
batchSize = minBatch
|
||||
}
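// Worked example (hypothetical numbers, not from this change): with
// batchSize = 512 and an input carrying SameBatch = 600, minBatch = 601
// exceeds 512, so the batch grows to 601 for this sequence; if other
// inputs were already pending, the loop breaks instead and this input
// starts a fresh batch.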
|
||||
|
||||
if len(seq.pendingInputs)+minBatch > batchSize {
|
||||
break
|
||||
}
|
||||
|
||||
// If the sum of our working set (already processed tokens, tokens we added to this
|
||||
// batch, required following tokens) exceeds the context size, then trigger a shift
|
||||
// now so we don't have to do one later when we can't break the batch.
|
||||
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
|
||||
if len(seq.pendingInputs) != 0 {
|
||||
break
|
||||
}
|
||||
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if j >= s.batchSize {
|
||||
break
|
||||
}
|
||||
|
||||
options.Inputs = append(options.Inputs, inp.Token)
|
||||
@@ -501,75 +529,18 @@ func (s *Server) processBatch() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO (jmorganca): use structs from the api package to avoid duplication
|
||||
// this way the api acts as a proxy instead of using a different api for the
|
||||
// runner
|
||||
type Options struct {
|
||||
api.Runner
|
||||
|
||||
NumKeep int `json:"n_keep"`
|
||||
Seed int `json:"seed"`
|
||||
NumPredict int `json:"n_predict"`
|
||||
TopK int `json:"top_k"`
|
||||
TopP float32 `json:"top_p"`
|
||||
MinP float32 `json:"min_p"`
|
||||
TypicalP float32 `json:"typical_p"`
|
||||
RepeatLastN int `json:"repeat_last_n"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
RepeatPenalty float32 `json:"repeat_penalty"`
|
||||
PresencePenalty float32 `json:"presence_penalty"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty"`
|
||||
Mirostat int `json:"mirostat"`
|
||||
MirostatTau float32 `json:"mirostat_tau"`
|
||||
MirostatEta float32 `json:"mirostat_eta"`
|
||||
Stop []string `json:"stop"`
|
||||
}
|
||||
|
||||
type ImageData struct {
|
||||
Data []byte `json:"data"`
|
||||
ID int `json:"id"`
|
||||
AspectRatioID int `json:"aspect_ratio_id"`
|
||||
}
|
||||
|
||||
type CompletionRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Images []ImageData `json:"image_data"`
|
||||
Grammar string `json:"grammar"`
|
||||
CachePrompt bool `json:"cache_prompt"`
|
||||
|
||||
Options
|
||||
}
|
||||
|
||||
type Timings struct {
|
||||
PredictedN int `json:"predicted_n"`
|
||||
PredictedMS float64 `json:"predicted_ms"`
|
||||
PromptN int `json:"prompt_n"`
|
||||
PromptMS float64 `json:"prompt_ms"`
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string `json:"content"`
|
||||
Stop bool `json:"stop"`
|
||||
|
||||
Model string `json:"model,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
StoppedLimit bool `json:"stopped_limit,omitempty"`
|
||||
PredictedN int `json:"predicted_n,omitempty"`
|
||||
PredictedMS float64 `json:"predicted_ms,omitempty"`
|
||||
PromptN int `json:"prompt_n,omitempty"`
|
||||
PromptMS float64 `json:"prompt_ms,omitempty"`
|
||||
|
||||
Timings Timings `json:"timings"`
|
||||
}
|
||||
|
||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
var req CompletionRequest
|
||||
req.Options = Options(api.DefaultOptions())
|
||||
var req llm.CompletionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Options == nil {
|
||||
opts := api.DefaultOptions()
|
||||
req.Options = &opts
|
||||
}
|
||||
|
||||
// Set the headers to indicate streaming
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
@@ -590,19 +561,12 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
sampler := sample.NewSampler(
|
||||
req.Temperature,
|
||||
req.TopK,
|
||||
req.TopP,
|
||||
req.MinP,
|
||||
req.Seed,
|
||||
grammar,
|
||||
)
|
||||
sampler := sample.NewSampler(req.Options, grammar)
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
numPredict: req.NumPredict,
|
||||
stop: req.Stop,
|
||||
numKeep: int32(req.NumKeep),
|
||||
numPredict: req.Options.NumPredict,
|
||||
stop: req.Options.Stop,
|
||||
numKeep: int32(req.Options.NumKeep),
|
||||
sampler: sampler,
|
||||
embedding: false,
|
||||
})
|
||||
@@ -625,7 +589,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
found := false
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
@@ -652,7 +616,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
case content, ok := <-seq.responses:
|
||||
if ok {
|
||||
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
@@ -663,15 +627,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
// Send the final response
|
||||
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||
Stop: true,
|
||||
StoppedLimit: seq.doneReason == "limit",
|
||||
Timings: Timings{
|
||||
PromptN: seq.numPromptInputs,
|
||||
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
|
||||
PredictedN: seq.numPredicted,
|
||||
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
|
||||
},
|
||||
doneReason := "stop"
|
||||
if seq.doneReason == "limit" {
|
||||
doneReason = "length"
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: doneReason,
|
||||
PromptEvalCount: seq.numPromptInputs,
|
||||
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
|
||||
EvalCount: seq.numPredicted,
|
||||
EvalDuration: time.Since(seq.startGenerationTime),
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
@@ -682,43 +648,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Content string `json:"content"`
|
||||
CachePrompt bool `json:"cache_prompt"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
}
|
||||
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
Progress float32 `json:"progress"`
|
||||
}
|
||||
|
||||
type ServerStatus int
|
||||
|
||||
const (
|
||||
ServerStatusReady ServerStatus = iota
|
||||
ServerStatusLoadingModel
|
||||
ServerStatusError
|
||||
)
|
||||
|
||||
func (s ServerStatus) ToString() string {
|
||||
switch s {
|
||||
case ServerStatusReady:
|
||||
return "ok"
|
||||
case ServerStatusLoadingModel:
|
||||
return "loading model"
|
||||
default:
|
||||
return "server error"
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(&HealthResponse{
|
||||
Status: s.status.ToString(),
|
||||
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
|
||||
Status: s.status,
|
||||
Progress: s.progress,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
@@ -772,7 +705,7 @@ func (s *Server) loadModel(
|
||||
s.seqs = make([]*Sequence, s.parallel)
|
||||
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
|
||||
|
||||
s.status = ServerStatusReady
|
||||
s.status = llm.ServerStatusReady
|
||||
s.ready.Done()
|
||||
}
|
||||
|
||||
@@ -824,7 +757,7 @@ func Execute(args []string) error {
|
||||
|
||||
server := &Server{
|
||||
batchSize: *batchSize,
|
||||
status: ServerStatusLoadingModel,
|
||||
status: llm.ServerStatusLoadingModel,
|
||||
}
|
||||
|
||||
// TODO(jessegross): Parameters that need to be implemented:
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llama"
|
||||
)
|
||||
|
||||
@@ -26,6 +28,10 @@ type Sampler struct {
|
||||
}
|
||||
|
||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||
if len(logits) == 0 {
|
||||
return -1, errors.New("sample: no logits provided to sample")
|
||||
}
|
||||
|
||||
tokens := make([]token, len(logits))
|
||||
for i := range logits {
|
||||
tokens[i].id = int32(i)
|
||||
@@ -87,19 +93,13 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
||||
// topK also sorts the tokens in descending order of logits
|
||||
tokens = topK(tokens, s.topK)
|
||||
|
||||
tokens = temperature(tokens, s.temperature)
|
||||
tokens = softmax(tokens)
|
||||
// scale and normalize the tokens in place
|
||||
temperature(tokens, s.temperature)
|
||||
softmax(tokens)
|
||||
|
||||
tokens = topP(tokens, s.topP)
|
||||
tokens = minP(tokens, s.minP)
|
||||
|
||||
// TODO: this should fall back to greedy sampling
|
||||
// or topP, topK values etc should be such that
|
||||
// there are always tokens to sample from
|
||||
if len(tokens) == 0 {
|
||||
return token{}, errors.New("no tokens to sample from")
|
||||
}
|
||||
|
||||
var r float32
|
||||
if s.rng != nil {
|
||||
r = s.rng.Float32()
|
||||
@@ -122,43 +122,71 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
||||
return 1
|
||||
})
|
||||
|
||||
if math.IsNaN(float64(sum)) {
|
||||
return token{}, errors.New("sample: logits sum to NaN, check model output")
|
||||
}
|
||||
return tokens[idx], nil
|
||||
}
|
||||
|
||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
|
||||
// SamplerParams contains the validated and normalized parameters for a sampler
|
||||
type SamplerParams struct {
|
||||
Temperature float32 `json:"temperature"`
|
||||
TopK int `json:"top_k"`
|
||||
TopP float32 `json:"top_p"`
|
||||
MinP float32 `json:"min_p"`
|
||||
Seed int `json:"seed"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler to handle validation during JSON unmarshaling
|
||||
func (p *SamplerParams) UnmarshalJSON(data []byte) error {
|
||||
type rawParams SamplerParams
|
||||
if err := json.Unmarshal(data, (*rawParams)(p)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate and normalize after unmarshaling
|
||||
if p.Temperature < 0.0 {
|
||||
p.Temperature = 0.0
|
||||
}
|
||||
|
||||
if p.TopP < 0.0 {
|
||||
p.TopP = 0.0
|
||||
}
|
||||
if p.TopP >= 1.0 {
|
||||
p.TopP = 1.0
|
||||
}
|
||||
|
||||
if p.MinP < 0.0 {
|
||||
p.MinP = 0.0
|
||||
}
|
||||
if p.MinP >= 1.0 {
|
||||
p.MinP = 1.0
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewSampler creates a new sampler with the given options
|
||||
func NewSampler(opts *api.Options, grammar *Grammar) Sampler {
|
||||
var params SamplerParams
|
||||
data, _ := json.Marshal(opts)
|
||||
_ = json.Unmarshal(data, ¶ms)
|
||||
|
||||
var rng *rand.Rand
|
||||
if seed != -1 {
|
||||
if params.Seed != -1 {
|
||||
// PCG requires two parameters: sequence and stream
|
||||
// Use original seed for sequence
|
||||
sequence := uint64(seed)
|
||||
sequence := uint64(params.Seed)
|
||||
// Use golden ratio hash to generate statistically independent seeds
|
||||
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
|
||||
}
|
||||
if temperature < 0.0 {
|
||||
temperature = 0.0
|
||||
}
|
||||
|
||||
if topP < 0.0 {
|
||||
topP = 0.0
|
||||
}
|
||||
if topP >= 1.0 {
|
||||
topP = 1.0
|
||||
}
|
||||
|
||||
if minP < 0.0 {
|
||||
minP = 0.0
|
||||
}
|
||||
if minP >= 1.0 {
|
||||
minP = 1.0
|
||||
}
|
||||
|
||||
return Sampler{
|
||||
rng: rng,
|
||||
topK: topK,
|
||||
topP: topP,
|
||||
minP: minP,
|
||||
temperature: temperature,
|
||||
topK: params.TopK,
|
||||
topP: params.TopP,
|
||||
minP: params.MinP,
|
||||
temperature: params.Temperature,
|
||||
grammar: grammar,
|
||||
}
|
||||
}
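The clamping that used to live inline in NewSampler now happens inside SamplerParams.UnmarshalJSON, so any JSON round-trip from api.Options yields normalized values. A small standalone sketch of the same clamp-while-unmarshaling pattern, using a placeholder struct rather than the sampler's types:

	package main

	import (
		"encoding/json"
		"fmt"
	)

	type params struct {
		TopP float32 `json:"top_p"`
	}

	// UnmarshalJSON clamps TopP into [0, 1] as part of decoding.
	func (p *params) UnmarshalJSON(data []byte) error {
		type raw params // same fields, no methods: avoids recursing into this method
		if err := json.Unmarshal(data, (*raw)(p)); err != nil {
			return err
		}
		if p.TopP < 0 {
			p.TopP = 0
		}
		if p.TopP > 1 {
			p.TopP = 1
		}
		return nil
	}

	func main() {
		var p params
		_ = json.Unmarshal([]byte(`{"top_p": 1.7}`), &p)
		fmt.Println(p.TopP) // prints 1, the clamped value
	}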
|
||||
|
||||
@@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
||||
logits[i] = float32(rand.Float64()*10 - 5)
|
||||
}
|
||||
|
||||
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
|
||||
sampler := NewSampler(createSamplerOptions(0.8, 0, 0, 0, 42), nil)
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
sampler.Sample(logits)
|
||||
@@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
||||
|
||||
for _, tc := range configs {
|
||||
b.Run("Config"+tc.name, func(b *testing.B) {
|
||||
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
|
||||
sampler := NewSampler(createSamplerOptions(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed), nil)
|
||||
sampler.Sample(logits)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
||||
|
||||
// Test with combined transforms separately - topK influences performance greatly
|
||||
b.Run("TransformCombined", func(b *testing.B) {
|
||||
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
|
||||
sampler := NewSampler(createSamplerOptions(0.8, 50, 0.9, 0.05, 42), nil)
|
||||
b.ResetTimer()
|
||||
|
||||
for b.Loop() {
|
||||
@@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
|
||||
logits[i] = float32(rand.Float64()*10 - 5)
|
||||
}
|
||||
|
||||
sampler := NewSampler(0, -1, 0, 0, -1, nil)
|
||||
sampler := NewSampler(createSamplerOptions(0, -1, 0, 0, -1), nil)
|
||||
b.ResetTimer()
|
||||
|
||||
for b.Loop() {
|
||||
|
||||
@@ -1,13 +1,26 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func createSamplerOptions(temperature float32, topK int, topP float32, minP float32, seed int) *api.Options {
|
||||
return &api.Options{
|
||||
Temperature: temperature,
|
||||
TopK: topK,
|
||||
TopP: topP,
|
||||
MinP: minP,
|
||||
Seed: seed,
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeighted(t *testing.T) {
|
||||
logits := []float32{-10, 3, -10, -10}
|
||||
sampler := NewSampler(0, 0, 0, 0, 0, nil)
|
||||
sampler := NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil)
|
||||
got, err := sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -19,7 +32,7 @@ func TestWeighted(t *testing.T) {
|
||||
}
|
||||
|
||||
logits = []float32{-100, -10, 0, 10}
|
||||
sampler = NewSampler(0, 0, 0, 0, 0, nil)
|
||||
sampler = NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -29,12 +42,35 @@ func TestWeighted(t *testing.T) {
|
||||
if want != got {
|
||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||
}
|
||||
|
||||
// Test extremely low topP
|
||||
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
||||
// Use extremely small topP to filter out all tokens
|
||||
sampler = NewSampler(createSamplerOptions(1.0, 0, 1e-10, 0, 0), nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
// Should get the token with the highest logit
|
||||
want = int32(0)
|
||||
if want != got {
|
||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||
}
|
||||
|
||||
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
||||
sampler = NewSampler(createSamplerOptions(1, 0, 0.95, 0.05, 0), nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err == nil {
|
||||
t.Errorf("expected error, got %d", got)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSample(b *testing.B) {
|
||||
samplers := map[string]Sampler{
|
||||
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
||||
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
|
||||
"Greedy": NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil), // Use NewSampler with temp=0 for greedy
|
||||
"Weighted": NewSampler(createSamplerOptions(0.5, 10, 0.9, 0.2, -1), nil),
|
||||
}
|
||||
|
||||
// Generate random logits for benchmarking
|
||||
|
||||
@@ -26,17 +26,16 @@ func (h *tokenHeap) Pop() any {
|
||||
}
|
||||
|
||||
// temperature applies scaling to the logits
|
||||
func temperature(ts []token, temp float32) []token {
|
||||
func temperature(ts []token, temp float32) {
|
||||
// Ensure temperature clipping near 0 to avoid numerical instability
|
||||
temp = max(temp, 1e-7)
|
||||
for i := range ts {
|
||||
ts[i].value = ts[i].value / temp
|
||||
}
|
||||
return ts
|
||||
}
|
||||
|
||||
// softmax applies normalization to the logits
|
||||
func softmax(ts []token) []token {
|
||||
func softmax(ts []token) {
|
||||
// Find max logit for numerical stability
|
||||
maxLogit := float32(math.Inf(-1))
|
||||
for _, t := range ts {
|
||||
@@ -56,8 +55,6 @@ func softmax(ts []token) []token {
|
||||
for i := range ts {
|
||||
ts[i].value /= sum
|
||||
}
|
||||
|
||||
return ts
|
||||
}
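As a worked example of the (unchanged) math: for logits (1, 2, 3), softmax subtracts the max for stability, giving (-2, -1, 0), exponentiates to roughly (0.135, 0.368, 1.0), and divides by the sum 1.503, yielding probabilities of about (0.090, 0.245, 0.665), which sum to 1.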
|
||||
|
||||
// topK limits the number of tokens considered to the k highest logits
|
||||
@@ -99,6 +96,7 @@ func topK(ts []token, k int) []token {
|
||||
}
|
||||
|
||||
// topP limits tokens to those with cumulative probability p
// requires ts to be sorted in descending order of probabilities
func topP(ts []token, p float32) []token {
	if p == 1.0 {
		return ts
@@ -109,37 +107,24 @@ func topP(ts []token, p float32) []token {
	for i, t := range ts {
		sum += t.value
		if sum > float32(p) {
			ts = ts[:i+1]
			return ts
			return ts[:i+1]
		}
	}

	return ts
}

// minP limits tokens to those with cumulative probability p
// minP filters tokens with probabilities >= p * max_prob
// requires ts to be sorted in descending order of probabilities
func minP(ts []token, p float32) []token {
	if p == 1.0 {
		return ts
	}
	maxProb := ts[0].value

	maxProb := float32(math.Inf(-1))
	for _, token := range ts {
		if token.value > maxProb {
			maxProb = token.value
		}
	}
	threshold := maxProb * p

	threshold := maxProb * float32(p)

	// Filter tokens in-place
	validTokens := ts[:0]
	for i, token := range ts {
		if token.value >= threshold {
			validTokens = append(validTokens, ts[i])
	for i, t := range ts {
		if t.value < threshold {
			return ts[:i]
		}
	}

	ts = validTokens
	return ts
}
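Worked examples of the two filters above, using made-up probabilities already sorted in descending order: for (0.5, 0.3, 0.15, 0.05), topP with p = 0.9 accumulates 0.5, 0.8, 0.95 and cuts after the third token, since 0.95 exceeds 0.9; minP with p = 0.2 computes the threshold 0.2 * 0.5 = 0.1 and keeps (0.5, 0.3, 0.15), stopping at 0.05, the first value below the threshold.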
|
||||
|
||||
@@ -34,17 +34,22 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {
|
||||
|
||||
func TestTemperature(t *testing.T) {
|
||||
input := []float32{1.0, 4.0, -2.0, 0.0}
|
||||
got := temperature(toTokens(input), 0.5)
|
||||
tokens := toTokens(input)
|
||||
temperature(tokens, 0.5)
|
||||
want := []float32{2.0, 8.0, -4.0, 0.0}
|
||||
compareLogits(t, "temperature(0.5)", want, got)
|
||||
compareLogits(t, "temperature(0.5)", want, tokens)
|
||||
|
||||
got = temperature(toTokens(input), 1.0)
|
||||
input = []float32{1.0, 4.0, -2.0, 0.0}
|
||||
tokens = toTokens(input)
|
||||
temperature(tokens, 1.0)
|
||||
want = []float32{1.0, 4.0, -2.0, 0.0}
|
||||
compareLogits(t, "temperature(1)", want, got)
|
||||
compareLogits(t, "temperature(1)", want, tokens)
|
||||
|
||||
got = temperature(toTokens(input), 0.0)
|
||||
input = []float32{1.0, 4.0, -2.0, 0.0}
|
||||
tokens = toTokens(input)
|
||||
temperature(tokens, 0.0)
|
||||
want = []float32{1e7, 4e7, -2e7, 0.0}
|
||||
compareLogits(t, "temperature(0)", want, got)
|
||||
compareLogits(t, "temperature(0)", want, tokens)
|
||||
}
|
||||
|
||||
func TestSoftmax(t *testing.T) {
|
||||
@@ -90,16 +95,17 @@ func TestSoftmax(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := softmax(toTokens(tt.input))
|
||||
tokens := toTokens(tt.input)
|
||||
softmax(tokens)
|
||||
|
||||
if tt.expected != nil {
|
||||
compareLogits(t, tt.name, tt.expected, got)
|
||||
compareLogits(t, tt.name, tt.expected, tokens)
|
||||
return
|
||||
}
|
||||
|
||||
// Check probabilities sum to 1
|
||||
var sum float32
|
||||
for _, token := range got {
|
||||
for _, token := range tokens {
|
||||
sum += token.value
|
||||
if token.value < 0 || token.value > 1 {
|
||||
t.Errorf("probability out of range [0,1]: got %f", token.value)
|
||||
@@ -114,38 +120,44 @@ func TestSoftmax(t *testing.T) {
|
||||
|
||||
func TestTopK(t *testing.T) {
|
||||
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
||||
|
||||
// Test k=5
|
||||
got := topK(toTokens(input), 5)
|
||||
if len(got) != 5 {
|
||||
t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
|
||||
tokens := toTokens(input)
|
||||
tokens = topK(tokens, 5)
|
||||
if len(tokens) != 5 {
|
||||
t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens))
|
||||
}
|
||||
// Should keep highest 3 values in descending order
|
||||
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
|
||||
compareLogits(t, "topK(3)", want, got)
|
||||
compareLogits(t, "topK(3)", want, tokens)
|
||||
|
||||
got = topK(toTokens(input), 20)
|
||||
if len(got) != len(input) {
|
||||
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, 20)
|
||||
if len(tokens) != len(input) {
|
||||
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens))
|
||||
}
|
||||
|
||||
// Test k=-1
|
||||
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
||||
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
||||
got = topK(toTokens(input), -1)
|
||||
if len(got) != len(input) {
|
||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, -1)
|
||||
if len(tokens) != len(input) {
|
||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
|
||||
}
|
||||
compareLogits(t, "topK(-1)", want, got)
|
||||
compareLogits(t, "topK(-1)", want, tokens)
|
||||
|
||||
// Test k=0
|
||||
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
||||
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
||||
got = topK(toTokens(input), 0)
|
||||
if len(got) != len(input) {
|
||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, 0)
|
||||
if len(tokens) != len(input) {
|
||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
|
||||
}
|
||||
compareLogits(t, "topK(-1)", want, tokens)
|
||||
|
||||
input = []float32{-1e7, -2e7, -3e7, -4e7}
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, 1)
|
||||
if len(tokens) < 1 {
|
||||
t.Error("topK should keep at least one token")
|
||||
}
|
||||
compareLogits(t, "topK(-1)", want, got)
|
||||
}
|
||||
|
||||
func TestTopP(t *testing.T) {
|
||||
@@ -153,50 +165,134 @@ func TestTopP(t *testing.T) {
|
||||
tokens := toTokens(input)
|
||||
|
||||
// First apply temperature and softmax to get probabilities
|
||||
tokens = softmax(tokens)
|
||||
softmax(tokens)
|
||||
tokens = topK(tokens, 20)
|
||||
|
||||
// Then apply topP
|
||||
got := topP(tokens, 0.95)
|
||||
// Test with very high p value
|
||||
got := topP(tokens, 1.0)
|
||||
|
||||
// Should keep all tokens since p is 1
|
||||
if len(got) != len(input) {
|
||||
t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
|
||||
}
|
||||
|
||||
// Test with normal p value
|
||||
got = topP(tokens, 0.95)
|
||||
|
||||
// Should keep tokens until cumsum > 0.95
|
||||
if len(got) > 3 {
|
||||
t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
|
||||
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
|
||||
// Test edge case - ensure at least one token remains
|
||||
input = []float32{-1e6, -1e6, -1e7}
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
got = topP(tokens, 0.0)
|
||||
if len(got) < 1 {
|
||||
t.Error("topP should keep at least one token")
|
||||
}
|
||||
|
||||
// Test with zero p value
|
||||
got = topP(tokens, 0.0)
|
||||
|
||||
// Should keep only the highest probability token
|
||||
if len(got) != 1 {
|
||||
t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
got = topP(tokens, 1e-10)
|
||||
if len(got) == 0 {
|
||||
t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinP(t *testing.T) {
|
||||
input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
|
||||
input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
|
||||
tokens := toTokens(input)
|
||||
|
||||
// First apply temperature and softmax
|
||||
tokens = softmax(tokens)
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
|
||||
// Then apply minP
|
||||
got := minP(tokens, 0.2)
|
||||
tokens = minP(tokens, 1.0)
|
||||
|
||||
if len(tokens) != 1 {
|
||||
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens))
|
||||
}
|
||||
|
||||
// Test with normal p value
|
||||
tokens = toTokens(input) // Reset tokens
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
tokens = minP(tokens, 0.2)
|
||||
|
||||
// Should keep tokens with prob >= 0.2 * max_prob
|
||||
if len(tokens) > 3 {
|
||||
t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
|
||||
t.Logf("got: %v", tokens)
|
||||
}
|
||||
|
||||
// Test with zero p value
|
||||
tokens = toTokens(input) // Reset tokens
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
tokens = minP(tokens, 0.0)
|
||||
|
||||
// Should keep all tokens, since every probability is >= 0
if len(tokens) != len(input) {
t.Errorf("minP(0.0): should keep all tokens, got %d, want %d", len(tokens), len(input))
|
||||
t.Logf("got: %v", tokens)
|
||||
}
|
||||
|
||||
// Test with single token
|
||||
tokens = toTokens(input[:1])
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
tokens = minP(tokens, 0.1)
|
||||
|
||||
// Should keep only the highest probability token
|
||||
if len(tokens) != 1 {
|
||||
t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
|
||||
t.Logf("got: %v", tokens)
|
||||
}
|
||||
|
||||
input = []float32{1e-10, 1e-10, 1e-10}
|
||||
tokens = toTokens(input)
|
||||
softmax(tokens)
|
||||
tokens = minP(tokens, 1.0)
|
||||
if len(tokens) < 1 {
|
||||
t.Error("minP should keep at least one token even with extreme probabilities")
|
||||
got := minP(tokens, 1.0)
|
||||
|
||||
if len(got) != 1 {
|
||||
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
|
||||
}
|
||||
|
||||
// Test with normal p value
|
||||
got = minP(tokens, 0.2)
|
||||
|
||||
// Should keep tokens with prob >= 0.2 * max_prob
|
||||
if len(got) > 3 {
|
||||
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortLogits(t *testing.T) {
|
||||
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
||||
tokens := toTokens(input)
|
||||
|
||||
tokens = topK(tokens, 20)
|
||||
|
||||
for i := 1; i < len(tokens); i++ {
|
||||
if tokens[i].value > tokens[i-1].value {
|
||||
t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
|
||||
i, tokens[i].value, tokens[i-1].value)
|
||||
}
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
|
||||
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
||||
compareLogits(t, "sortLogits", want, tokens)
|
||||
// Test with zero p value
|
||||
got = minP(tokens, 0.0)
|
||||
|
||||
// Should keep all tokens, since every probability is >= 0
if len(got) != len(tokens) {
t.Errorf("minP(0.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTransforms(b *testing.B) {
|
||||
@@ -231,7 +327,7 @@ func BenchmarkTransforms(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
topK(tokensCopy, 10)
|
||||
tokens = topK(tokensCopy, 10)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -239,7 +335,7 @@ func BenchmarkTransforms(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
topP(tokensCopy, 0.9)
|
||||
tokens = topP(tokensCopy, 0.9)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -247,7 +343,7 @@ func BenchmarkTransforms(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
minP(tokensCopy, 0.2)
|
||||
tokens = minP(tokensCopy, 0.2)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -255,7 +351,7 @@ func BenchmarkTransforms(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
topK(tokensCopy, 200000)
|
||||
tokens = topK(tokensCopy, 200000)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ usage() {
|
||||
exit 1
|
||||
}
|
||||
|
||||
export VERSION=${VERSION:-$(git describe --tags --dirty)}
|
||||
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
|
||||
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'"
|
||||
export CGO_CPPFLAGS='-mmacosx-version-min=11.3'
|
||||
|
||||
|
||||
44 server/internal/cache/blob/cache.go (vendored)
@@ -146,7 +146,7 @@ func debugger(err *error) func(step string) {
|
||||
// be in either of the following forms:
|
||||
//
|
||||
// @<digest>
|
||||
// <name>
|
||||
// <name>@<digest>
|
||||
// <name>
|
||||
//
|
||||
// If a digest is provided, it is returned as is and nothing else happens.
|
||||
@@ -160,8 +160,6 @@ func debugger(err *error) func(step string) {
|
||||
// hashed is passed to a PutBytes call to ensure that the manifest is in the
|
||||
// blob store. This is done to ensure that future calls to [Get] succeed in
|
||||
// these cases.
|
||||
//
|
||||
// TODO(bmizerany): Move Links/Resolve/etc. out of this package.
|
||||
func (c *DiskCache) Resolve(name string) (Digest, error) {
|
||||
name, digest := splitNameDigest(name)
|
||||
if digest != "" {
|
||||
@@ -279,18 +277,6 @@ func (c *DiskCache) Get(d Digest) (Entry, error) {
|
||||
// It returns an error if either the name or digest is invalid, or if link
|
||||
// creation encounters any issues.
|
||||
func (c *DiskCache) Link(name string, d Digest) error {
|
||||
// TODO(bmizerany): Move link handling from cache to registry.
|
||||
//
|
||||
// We originally placed links in the cache due to its storage
|
||||
// knowledge. However, the registry likely offers better context for
|
||||
// naming concerns, and our API design shouldn't be tightly coupled to
|
||||
// our on-disk format.
|
||||
//
|
||||
// Links work effectively when independent from physical location -
|
||||
// they can reference content with matching SHA regardless of storage
|
||||
// location. In an upcoming change, we plan to shift this
|
||||
// responsibility to the registry where it better aligns with the
|
||||
// system's conceptual model.
|
||||
manifest, err := c.manifestPath(name)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -341,7 +327,9 @@ func (c *DiskCache) GetFile(d Digest) string {
|
||||
return absJoin(c.dir, "blobs", filename)
|
||||
}
|
||||
|
||||
// Links returns a sequence of links in the cache in lexical order.
|
||||
// Links returns a sequence of link names. The sequence is in lexical order.
|
||||
// Names are converted from their relative path form to their name form but are
|
||||
// not guaranteed to be valid. Callers should validate the names before using.
|
||||
func (c *DiskCache) Links() iter.Seq2[string, error] {
|
||||
return func(yield func(string, error) bool) {
|
||||
for path, err := range c.links() {
|
||||
@@ -414,11 +402,13 @@ func (c *DiskCache) links() iter.Seq2[string, error] {
|
||||
}
|
||||
|
||||
type checkWriter struct {
|
||||
d Digest
|
||||
size int64
|
||||
n int64
|
||||
h hash.Hash
|
||||
d Digest
|
||||
f *os.File
|
||||
h hash.Hash
|
||||
|
||||
w io.Writer // underlying writer; set by creator
|
||||
n int64
|
||||
err error
|
||||
|
||||
testHookBeforeFinalWrite func(*os.File)
|
||||
@@ -435,6 +425,10 @@ func (w *checkWriter) seterr(err error) error {
|
||||
// underlying writer is guaranteed to be the last byte of p as verified by the
|
||||
// hash.
|
||||
func (w *checkWriter) Write(p []byte) (int, error) {
|
||||
if w.err != nil {
|
||||
return 0, w.err
|
||||
}
|
||||
|
||||
_, err := w.h.Write(p)
|
||||
if err != nil {
|
||||
return 0, w.seterr(err)
|
||||
@@ -453,7 +447,7 @@ func (w *checkWriter) Write(p []byte) (int, error) {
|
||||
if nextSize > w.size {
|
||||
return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size))
|
||||
}
|
||||
n, err := w.f.Write(p)
|
||||
n, err := w.w.Write(p)
|
||||
w.n += int64(n)
|
||||
return n, w.seterr(err)
|
||||
}
|
||||
@@ -497,6 +491,8 @@ func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size
|
||||
size: size,
|
||||
h: sha256.New(),
|
||||
f: f,
|
||||
w: f,
|
||||
|
||||
testHookBeforeFinalWrite: c.testHookBeforeFinalWrite,
|
||||
}
|
||||
n, err := io.Copy(cw, file)
|
||||
@@ -532,11 +528,6 @@ func splitNameDigest(s string) (name, digest string) {
|
||||
var errInvalidName = errors.New("invalid name")
|
||||
|
||||
func nameToPath(name string) (_ string, err error) {
|
||||
if strings.Contains(name, "@") {
|
||||
// TODO(bmizerany): HACK: Fix names.Parse to validate.
|
||||
// TODO(bmizerany): merge with default parts (maybe names.Merge(a, b))
|
||||
return "", errInvalidName
|
||||
}
|
||||
n := names.Parse(name)
|
||||
if !n.IsFullyQualified() {
|
||||
return "", errInvalidName
|
||||
@@ -547,8 +538,7 @@ func nameToPath(name string) (_ string, err error) {
|
||||
func absJoin(pp ...string) string {
|
||||
abs, err := filepath.Abs(filepath.Join(pp...))
|
||||
if err != nil {
|
||||
// Likely a bug or a bad OS problem. Just panic.
|
||||
panic(err)
|
||||
panic(err) // this should never happen
|
||||
}
|
||||
return abs
|
||||
}
|
||||
|
||||
73 server/internal/cache/blob/chunked.go (vendored, new file)
@@ -0,0 +1,73 @@
package blob

import (
	"crypto/sha256"
	"errors"
	"io"
	"os"
)

// Chunk represents a range of bytes in a blob.
type Chunk struct {
	Start int64
	End   int64
}

// Size returns end minus start plus one.
func (c Chunk) Size() int64 {
	return c.End - c.Start + 1
}

// Chunker writes to a blob in chunks.
// Its zero value is invalid. Use [DiskCache.Chunked] to create a new Chunker.
type Chunker struct {
	digest Digest
	size   int64
	f      *os.File // nil means pre-validated
}

// Chunked returns a new Chunker, ready for use storing a blob of the given
// size in chunks.
//
// Use [Chunker.Put] to write data to the blob at specific offsets.
func (c *DiskCache) Chunked(d Digest, size int64) (*Chunker, error) {
	name := c.GetFile(d)
	info, err := os.Stat(name)
	if err == nil && info.Size() == size {
		return &Chunker{}, nil
	}
	f, err := os.OpenFile(name, os.O_CREATE|os.O_WRONLY, 0o666)
	if err != nil {
		return nil, err
	}
	return &Chunker{digest: d, size: size, f: f}, nil
}

// Put copies chunk.Size() bytes from r to the blob at the given offset,
// merging the data with the existing blob. It returns an error if any. As a
// special case, if r has less than chunk.Size() bytes, Put returns
// io.ErrUnexpectedEOF.
func (c *Chunker) Put(chunk Chunk, d Digest, r io.Reader) error {
	if c.f == nil {
		return nil
	}

	cw := &checkWriter{
		d:    d,
		size: chunk.Size(),
		h:    sha256.New(),
		f:    c.f,
		w:    io.NewOffsetWriter(c.f, chunk.Start),
	}

	_, err := io.CopyN(cw, r, chunk.Size())
	if err != nil && errors.Is(err, io.EOF) {
		return io.ErrUnexpectedEOF
	}
	return err
}

// Close closes the underlying file.
func (c *Chunker) Close() error {
	return c.f.Close()
}
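A hypothetical usage sketch of the new Chunker API above. The cache directory and data are illustrative, blob.Open is assumed to be the DiskCache constructor, and the package is internal to the ollama module, so treat this as orientation rather than a drop-in program.

	package main

	import (
		"bytes"
		"log"

		// internal to the ollama module; shown for orientation only
		"github.com/ollama/ollama/server/internal/cache/blob"
	)

	func main() {
		data := []byte("example blob contents")

		// blob.Open as the DiskCache constructor is an assumption of this sketch.
		c, err := blob.Open("/tmp/ollama-cache")
		if err != nil {
			log.Fatal(err)
		}

		d := blob.DigestFromBytes(data)
		chunker, err := c.Chunked(d, int64(len(data)))
		if err != nil {
			log.Fatal(err)
		}
		defer chunker.Close()

		// One chunk covering the whole blob; End is inclusive, so
		// Size() == End-Start+1 == len(data), and the chunk digest
		// equals the blob digest.
		chunk := blob.Chunk{Start: 0, End: int64(len(data)) - 1}
		if err := chunker.Put(chunk, d, bytes.NewReader(data)); err != nil {
			log.Fatal(err)
		}
	}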
|
||||
4 server/internal/cache/blob/digest.go (vendored)
@@ -63,6 +63,10 @@ func (d Digest) Short() string {
|
||||
return fmt.Sprintf("%x", d.sum[:4])
|
||||
}
|
||||
|
||||
func (d Digest) Sum() [32]byte {
|
||||
return d.sum
|
||||
}
|
||||
|
||||
func (d Digest) Compare(other Digest) int {
|
||||
return slices.Compare(d.sum[:], other.sum[:])
|
||||
}
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
package chunks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Chunk struct {
|
||||
Start, End int64
|
||||
}
|
||||
|
||||
func New(start, end int64) Chunk {
|
||||
return Chunk{start, end}
|
||||
}
|
||||
|
||||
// ParseRange parses a string in the form "unit=range" where unit is a string
|
||||
// and range is a string in the form "start-end". It returns the unit and the
|
||||
// range as a Chunk.
|
||||
func ParseRange(s string) (unit string, _ Chunk, _ error) {
|
||||
unit, r, _ := strings.Cut(s, "=")
|
||||
if r == "" {
|
||||
return unit, Chunk{}, nil
|
||||
}
|
||||
c, err := Parse(r)
|
||||
if err != nil {
|
||||
return "", Chunk{}, err
|
||||
}
|
||||
return unit, c, err
|
||||
}
|
||||
|
||||
// Parse parses a string in the form "start-end" and returns the Chunk.
|
||||
func Parse(s string) (Chunk, error) {
|
||||
startStr, endStr, _ := strings.Cut(s, "-")
|
||||
start, err := strconv.ParseInt(startStr, 10, 64)
|
||||
if err != nil {
|
||||
return Chunk{}, fmt.Errorf("invalid start: %v", err)
|
||||
}
|
||||
end, err := strconv.ParseInt(endStr, 10, 64)
|
||||
if err != nil {
|
||||
return Chunk{}, fmt.Errorf("invalid end: %v", err)
|
||||
}
|
||||
if start > end {
|
||||
return Chunk{}, fmt.Errorf("invalid range %d-%d: start > end", start, end)
|
||||
}
|
||||
return Chunk{start, end}, nil
|
||||
}
|
||||
|
||||
// Of returns a sequence of contiguous Chunks of size chunkSize that cover
|
||||
// the range [0, size), in order.
|
||||
func Of(size, chunkSize int64) iter.Seq[Chunk] {
|
||||
return func(yield func(Chunk) bool) {
|
||||
for start := int64(0); start < size; start += chunkSize {
|
||||
end := min(start+chunkSize-1, size-1)
|
||||
if !yield(Chunk{start, end}) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count returns the number of Chunks of size chunkSize needed to cover the
|
||||
// range [0, size).
|
||||
func Count(size, chunkSize int64) int64 {
|
||||
return (size + chunkSize - 1) / chunkSize
|
||||
}
|
||||
|
||||
// Size returns end minus start plus one.
|
||||
func (c Chunk) Size() int64 {
|
||||
return c.End - c.Start + 1
|
||||
}
|
||||
|
||||
// String returns the string representation of the Chunk in the form
|
||||
// "{start}-{end}".
|
||||
func (c Chunk) String() string {
|
||||
return fmt.Sprintf("%d-%d", c.Start, c.End)
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
package chunks
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOf(t *testing.T) {
|
||||
cases := []struct {
|
||||
total int64
|
||||
chunkSize int64
|
||||
want []Chunk
|
||||
}{
|
||||
{0, 1, nil},
|
||||
{1, 1, []Chunk{{0, 0}}},
|
||||
{1, 2, []Chunk{{0, 0}}},
|
||||
{2, 1, []Chunk{{0, 0}, {1, 1}}},
|
||||
{10, 9, []Chunk{{0, 8}, {9, 9}}},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
got := slices.Collect(Of(tt.total, tt.chunkSize))
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("[%d/%d]: got %v; want %v", tt.total, tt.chunkSize, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSize(t *testing.T) {
|
||||
cases := []struct {
|
||||
c Chunk
|
||||
want int64
|
||||
}{
|
||||
{Chunk{0, 0}, 1},
|
||||
{Chunk{0, 1}, 2},
|
||||
{Chunk{3, 4}, 2},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
got := tt.c.Size()
|
||||
if got != tt.want {
|
||||
t.Errorf("%v: got %d; want %d", tt.c, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCount(t *testing.T) {
|
||||
cases := []struct {
|
||||
total int64
|
||||
chunkSize int64
|
||||
want int64
|
||||
}{
|
||||
{0, 1, 0},
|
||||
{1, 1, 1},
|
||||
{1, 2, 1},
|
||||
{2, 1, 2},
|
||||
{10, 9, 2},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
got := Count(tt.total, tt.chunkSize)
|
||||
if got != tt.want {
|
||||
t.Errorf("[%d/%d]: got %d; want %d", tt.total, tt.chunkSize, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -19,11 +19,13 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -35,10 +37,7 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/chunks"
|
||||
"github.com/ollama/ollama/server/internal/internal/backoff"
|
||||
"github.com/ollama/ollama/server/internal/internal/names"
|
||||
"github.com/ollama/ollama/server/internal/internal/syncs"
|
||||
|
||||
_ "embed"
|
||||
)
|
||||
@@ -66,12 +65,7 @@ var (
|
||||
const (
|
||||
// DefaultChunkingThreshold is the threshold at which a layer should be
|
||||
// split up into chunks when downloading.
|
||||
DefaultChunkingThreshold = 128 << 20
|
||||
|
||||
// DefaultMaxChunkSize is the default maximum size of a chunk to
|
||||
// download. It is configured based on benchmarks and aims to strike a
|
||||
// balance between download speed and memory usage.
|
||||
DefaultMaxChunkSize = 8 << 20
|
||||
DefaultChunkingThreshold = 64 << 20
|
||||
)
|
||||
|
||||
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
|
||||
@@ -211,20 +205,13 @@ type Registry struct {
|
||||
// pushing or pulling models. If zero, the number of streams is
|
||||
// determined by [runtime.GOMAXPROCS].
|
||||
//
|
||||
// Clients that want "unlimited" streams should set this to a large
|
||||
// number.
|
||||
// A negative value means no limit.
|
||||
MaxStreams int
|
||||
|
||||
// ChunkingThreshold is the maximum size of a layer to download in a single
|
||||
// request. If zero, [DefaultChunkingThreshold] is used.
|
||||
ChunkingThreshold int64
|
||||
|
||||
// MaxChunkSize is the maximum size of a chunk to download. If zero,
|
||||
// the default is [DefaultMaxChunkSize].
|
||||
//
|
||||
// It is only used when a layer is larger than [MaxChunkingThreshold].
|
||||
MaxChunkSize int64
|
||||
|
||||
// Mask, if set, is the name used to convert non-fully qualified names
|
||||
// to fully qualified names. If empty, [DefaultMask] is used.
|
||||
Mask string
|
||||
@@ -266,6 +253,7 @@ func DefaultRegistry() (*Registry, error) {
|
||||
}
|
||||
|
||||
var rc Registry
|
||||
rc.UserAgent = UserAgent()
|
||||
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -281,25 +269,24 @@ func DefaultRegistry() (*Registry, error) {
|
||||
return &rc, nil
|
||||
}
|
||||
|
||||
func (r *Registry) maxStreams() int {
|
||||
n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
|
||||
func UserAgent() string {
|
||||
buildinfo, _ := debug.ReadBuildInfo()
|
||||
return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
|
||||
buildinfo.Main.Version,
|
||||
runtime.GOARCH,
|
||||
runtime.GOOS,
|
||||
runtime.Version(),
|
||||
)
|
||||
}
|
||||
|
||||
// Large downloads require a writer stream, so ensure we have at least
|
||||
// two streams to avoid a deadlock.
|
||||
return max(n, 2)
|
||||
func (r *Registry) maxStreams() int {
|
||||
return cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
|
||||
}
|
||||
|
||||
func (r *Registry) maxChunkingThreshold() int64 {
|
||||
return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
|
||||
}
|
||||
|
||||
// chunkSizeFor returns the chunk size for a layer of the given size. If the
|
||||
// size is less than or equal to the max chunking threshold, the size is
|
||||
// returned; otherwise, the max chunk size is returned.
|
||||
func (r *Registry) maxChunkSize() int64 {
|
||||
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
|
||||
}
|
||||
|
||||
type PushParams struct {
|
||||
// From is an optional destination name for the model. If empty, the
|
||||
// destination name is the same as the source name.
|
||||
@@ -426,6 +413,21 @@ func canRetry(err error) bool {
|
||||
return re.Status >= 500
|
||||
}
|
||||
|
||||
// trackingReader is an io.Reader that adds the number of bytes read to a
// shared atomic counter so that callers can report download progress.
|
||||
type trackingReader struct {
|
||||
r io.Reader
|
||||
n *atomic.Int64
|
||||
}
|
||||
|
||||
func (r *trackingReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.r.Read(p)
|
||||
r.n.Add(int64(n))
|
||||
return
|
||||
}
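A minimal standalone sketch of the byte-counting reader pattern used above; the type is renamed and the input is made up, so this is not the registry's code.

	package main

	import (
		"fmt"
		"io"
		"strings"
		"sync/atomic"
	)

	type countingReader struct {
		r io.Reader
		n *atomic.Int64
	}

	func (c *countingReader) Read(p []byte) (int, error) {
		n, err := c.r.Read(p)
		c.n.Add(int64(n))
		return n, err
	}

	func main() {
		var progress atomic.Int64
		r := &countingReader{r: strings.NewReader("hello world"), n: &progress}
		_, _ = io.Copy(io.Discard, r)
		fmt.Println(progress.Load()) // 11 bytes read
	}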
|
||||
|
||||
// Pull pulls the model with the given name from the remote registry into the
|
||||
// cache.
|
||||
//
|
||||
@@ -434,15 +436,15 @@ func canRetry(err error) bool {
|
||||
// typically slower than splitting the model up across layers, and is mostly
|
||||
// utilized for layers of type equal to "application/vnd.ollama.image".
|
||||
func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
scheme, n, _, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m, err := r.Resolve(ctx, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(bmizerany): decide if this should be considered valid. Maybe
|
||||
// server-side we special case '{}' to have some special meaning? Maybe
|
||||
// "archiving" a tag (which is how we reason about it in the registry
|
||||
// already, just with a different twist).
|
||||
if len(m.Layers) == 0 {
|
||||
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
|
||||
}
|
||||
@@ -452,142 +454,88 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
exists := func(l *Layer) bool {
|
||||
info, err := c.Get(l.Digest)
|
||||
return err == nil && info.Size == l.Size
|
||||
}
|
||||
|
||||
t := traceFromContext(ctx)
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(r.maxStreams())
|
||||
|
||||
// TODO(bmizerany): work to remove the need to do this
|
||||
layers := m.Layers
|
||||
if m.Config != nil && m.Config.Digest.IsValid() {
|
||||
layers = append(layers, m.Config)
|
||||
}
|
||||
|
||||
// Send initial layer trace events to allow clients to have an
|
||||
// understanding of work to be done before work starts.
|
||||
t := traceFromContext(ctx)
|
||||
for _, l := range layers {
|
||||
if exists(l) {
|
||||
t.update(l, 0, nil)
|
||||
}
|
||||
|
||||
var g errgroup.Group
|
||||
g.SetLimit(r.maxStreams())
|
||||
for _, l := range layers {
|
||||
info, err := c.Get(l.Digest)
|
||||
if err == nil && info.Size == l.Size {
|
||||
t.update(l, l.Size, ErrCached)
|
||||
continue
|
||||
}
|
||||
|
||||
blobURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), l.Digest)
|
||||
req, err := r.newRequest(ctx, "GET", blobURL, nil)
|
||||
chunked, err := c.Chunked(l.Digest, l.Size)
|
||||
if err != nil {
|
||||
t.update(l, 0, err)
|
||||
continue
|
||||
}
|
||||
|
||||
t.update(l, 0, nil)
|
||||
|
||||
if l.Size <= r.maxChunkingThreshold() {
|
||||
g.Go(func() error {
|
||||
// TODO(bmizerany): retry/backoff like below in
|
||||
// the chunking case
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
err = c.Put(l.Digest, res.Body, l.Size)
|
||||
if err == nil {
|
||||
t.update(l, l.Size, nil)
|
||||
}
|
||||
return err
|
||||
})
|
||||
} else {
|
||||
q := syncs.NewRelayReader()
|
||||
|
||||
g.Go(func() (err error) {
|
||||
defer func() { q.CloseWithError(err) }()
|
||||
return c.Put(l.Digest, q, l.Size)
|
||||
})
|
||||
// TODO(bmizerany): fix this unbounded use of defer
|
||||
defer chunked.Close()
|
||||
|
||||
var progress atomic.Int64
|
||||
|
||||
// We want to avoid extra round trips per chunk due to
|
||||
// redirects from the registry to the blob store, so
|
||||
// fire an initial request to get the final URL and
|
||||
// then use that URL for the chunk requests.
|
||||
req.Header.Set("Range", "bytes=0-0")
|
||||
res, err := sendRequest(r.client(), req)
|
||||
for cs, err := range r.chunksums(ctx, name, l) {
|
||||
if err != nil {
|
||||
// Bad chunksums response, update tracing
|
||||
// clients and then bail.
|
||||
t.update(l, progress.Load(), err)
|
||||
return err
|
||||
}
|
||||
res.Body.Close()
|
||||
req = res.Request.WithContext(req.Context())
|
||||
|
||||
wp := writerPool{size: r.maxChunkSize()}
|
||||
|
||||
for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
|
||||
if ctx.Err() != nil {
|
||||
break
|
||||
}
|
||||
|
||||
ticket := q.Take()
|
||||
g.Go(func() (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
q.CloseWithError(err)
|
||||
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
|
||||
}
|
||||
ticket.Close()
|
||||
t.update(l, progress.Load(), err)
|
||||
}()
|
||||
|
||||
for _, err := range backoff.Loop(ctx, 3*time.Second) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err := func() error {
|
||||
req := req.Clone(req.Context())
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
tw := wp.get()
|
||||
tw.Reset(ticket)
|
||||
defer wp.put(tw)
|
||||
// Count bytes towards progress, as they
|
||||
// arrive, so that our bytes piggyback other
|
||||
// chunk updates on completion.
|
||||
//
|
||||
// This tactic is enough to show "smooth"
|
||||
// progress given the current CLI client. In
|
||||
// the near future, the server should report
|
||||
// download rate since it knows better than a
|
||||
// client that is measuring rate based on
|
||||
// wall-clock time-since-last-update.
|
||||
body := &trackingReader{r: res.Body, n: &progress}
|
||||
|
||||
_, err = io.CopyN(tw, res.Body, chunk.Size())
|
||||
if err != nil {
|
||||
return maybeUnexpectedEOF(err)
|
||||
}
|
||||
if err := tw.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
total := progress.Add(chunk.Size())
|
||||
if total >= l.Size {
|
||||
q.Close()
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
if !canRetry(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return chunked.Put(cs.Chunk, cs.Digest, body)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// store the manifest blob
|
||||
md := blob.DigestFromBytes(m.Data)
|
||||
if err := blob.PutBytes(c, md, m.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// commit the manifest with a link
|
||||
return c.Link(m.Name, md)
|
||||
}
|
||||
|
||||
@@ -615,8 +563,6 @@ type Manifest struct {
Config *Layer `json:"config"`
}

var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")

// Layer returns the layer with the given
// digest, or nil if not found.
func (m *Manifest) Layer(d blob.Digest) *Layer {
@@ -643,10 +589,9 @@ func (m Manifest) MarshalJSON() ([]byte, error) {
// last phase of the commit which expects it, but does nothing
// with it. This will be fixed in a future release of
// ollama.com.
Config *Layer `json:"config"`
Config Layer `json:"config"`
}{
M: M(m),
Config: &Layer{Digest: emptyDigest},
}
return json.Marshal(v)
}
@@ -736,6 +681,132 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error)
return m, nil
}

type chunksum struct {
URL string
Chunk blob.Chunk
Digest blob.Digest
}

// chunksums returns a sequence of chunksums for the given layer. If the layer is under the
// chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
// is over the chunking threshold, the chunksums are read from the chunksums endpoint.
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
return func(yield func(chunksum, error) bool) {
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
yield(chunksum{}, err)
return
}

if l.Size < r.maxChunkingThreshold() {
// any layer under the threshold should be downloaded
// in one go.
cs := chunksum{
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
),
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
Digest: l.Digest,
}
yield(cs, nil)
return
}

// A chunksums response is a sequence of chunksums in a
// simple, easy to parse line-oriented format.
//
// Example:
//
// >> GET /v2/<namespace>/<model>/chunksums/<digest>
//
// << HTTP/1.1 200 OK
// << Content-Location: <blobURL>
// <<
// << <digest> <start>-<end>
// << ...
//
// The blobURL is the URL to download the chunks from.

chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
)

req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
if err != nil {
yield(chunksum{}, err)
return
}
res, err := sendRequest(r.client(), req)
if err != nil {
yield(chunksum{}, err)
return
}
defer res.Body.Close()
if res.StatusCode != 200 {
err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
yield(chunksum{}, err)
return
}
blobURL := res.Header.Get("Content-Location")

var size int64
s := bufio.NewScanner(res.Body)
s.Split(bufio.ScanWords)
for {
if !s.Scan() {
if s.Err() != nil {
yield(chunksum{}, s.Err())
} else if size != l.Size {
yield(chunksum{}, fmt.Errorf("size mismatch: layer size %d != sum of chunks %d", l.Size, size))
}
return
}
d, err := blob.ParseDigest(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
return
}

if !s.Scan() {
err := s.Err()
if err == nil {
err = fmt.Errorf("missing chunk range for digest %s", d)
}
yield(chunksum{}, err)
return
}
chunk, err := parseChunk(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
return
}

size += chunk.Size()
if size > l.Size {
yield(chunksum{}, fmt.Errorf("chunk size %d exceeds layer size %d", size, l.Size))
return
}

cs := chunksum{
URL: blobURL,
Chunk: chunk,
Digest: d,
}
if !yield(cs, nil) {
return
}
}
}
}

func (r *Registry) client() *http.Client {
if r.HTTPClient != nil {
return r.HTTPClient
@@ -898,13 +969,6 @@ func checkData(url string) string {
return fmt.Sprintf("GET,%s,%s", url, zeroSum)
}

func maybeUnexpectedEOF(err error) error {
if errors.Is(err, io.EOF) {
return io.ErrUnexpectedEOF
}
return err
}

type publicError struct {
wrapped error
message string
@@ -991,27 +1055,22 @@ func splitExtended(s string) (scheme, name, digest string) {
return scheme, s, digest
}

type writerPool struct {
size int64 // set by the caller

mu sync.Mutex
ws []*bufio.Writer
}

func (p *writerPool) get() *bufio.Writer {
p.mu.Lock()
defer p.mu.Unlock()
if len(p.ws) == 0 {
return bufio.NewWriterSize(nil, int(p.size))
// parseChunk parses a string in the form "start-end" and returns the Chunk.
func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
startPart, endPart, found := strings.Cut(string(s), "-")
if !found {
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s)
}
w := p.ws[len(p.ws)-1]
p.ws = p.ws[:len(p.ws)-1]
return w
}

func (p *writerPool) put(w *bufio.Writer) {
p.mu.Lock()
defer p.mu.Unlock()
w.Reset(nil)
p.ws = append(p.ws, w)
start, err := strconv.ParseInt(startPart, 10, 64)
if err != nil {
return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err)
}
end, err := strconv.ParseInt(endPart, 10, 64)
if err != nil {
return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err)
}
if start > end {
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s)
}
return blob.Chunk{Start: start, End: end}, nil
}
@@ -17,11 +17,11 @@ import (
"reflect"
"slices"
"strings"
"sync"
"testing"
"time"

"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/testutil"
)

@@ -71,7 +71,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
// communication is attempted.
//
// To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
t.Helper()

c, err := blob.Open(t.TempDir())
@@ -89,7 +89,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
r := &Registry{
Cache: c,
HTTPClient: &http.Client{
Transport: recordRoundTripper(h),
Transport: recordRoundTripper(upstreamRegistry),
},
}

@@ -428,7 +428,7 @@ func TestRegistryPullCached(t *testing.T) {
err := rc.Pull(ctx, "single")
testutil.Check(t, err)

want := []int64{6}
want := []int64{0, 6}
if !errors.Is(errors.Join(errs...), ErrCached) {
t.Errorf("errs = %v; want %v", errs, ErrCached)
}
@@ -531,54 +531,6 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
}
}

func TestRegistryPullChunking(t *testing.T) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
if r.URL.Host != "blob.store" {
// The production registry redirects to the blob store.
http.Redirect(w, r, "http://blob.store"+r.URL.Path, http.StatusFound)
return
}
if strings.Contains(r.URL.Path, "/blobs/") {
rng := r.Header.Get("Range")
if rng == "" {
http.Error(w, "missing range", http.StatusBadRequest)
return
}
_, c, err := chunks.ParseRange(r.Header.Get("Range"))
if err != nil {
panic(err)
}
io.WriteString(w, "remote"[c.Start:c.End+1])
return
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, blob.DigestFromBytes("remote"))
})

// Force chunking by setting the threshold to less than the size of the
// layer.
rc.ChunkingThreshold = 3
rc.MaxChunkSize = 3

var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(d *Layer, n int64, err error) {
if err != nil {
t.Errorf("update %v %d %v", d, n, err)
}
reads = append(reads, n)
},
})

err := rc.Pull(ctx, "remote")
testutil.Check(t, err)

want := []int64{0, 3, 6}
if !slices.Equal(reads, want) {
t.Errorf("reads = %v; want %v", reads, want)
}
}

func TestRegistryResolveByDigest(t *testing.T) {
check := testutil.Checker(t)
@@ -816,3 +768,74 @@ func TestUnlink(t *testing.T) {
}
})
}

func TestPullChunksums(t *testing.T) {
check := testutil.Checker(t)

content := "hello"
var chunksums string
contentDigest := func() blob.Digest {
return blob.DigestFromBytes(content)
}
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/manifests/latest"):
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
w.Header().Set("Content-Location", loc)
io.WriteString(w, chunksums)
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
default:
t.Errorf("unexpected request: %v", r)
http.NotFound(w, r)
}
})

rc.MaxStreams = 1 // prevent concurrent chunk downloads
rc.ChunkingThreshold = 1 // for all blobs to be chunked

var mu sync.Mutex
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Logf("Update: %v %d %v", l, n, err)
mu.Lock()
reads = append(reads, n)
mu.Unlock()
},
})

chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
blob.DigestFromBytes("hel"),
blob.DigestFromBytes("lo"),
)
err := rc.Pull(ctx, "test")
check(err)
if !slices.Equal(reads, []int64{0, 3, 5}) {
t.Errorf("reads = %v; want %v", reads, []int64{0, 3, 5})
}

mw, err := rc.Resolve(t.Context(), "test")
check(err)
mg, err := rc.ResolveLocal("test")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
}
for i := range mg.Layers {
_, err = c.Get(mg.Layers[i].Digest)
if err != nil {
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
}
}

// missing chunks
content = "llama"
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
err = rc.Pull(ctx, "missingchunks")
if err == nil {
t.Error("expected error because of missing chunks")
}
}
@@ -1,11 +0,0 @@
package main

import (
"fmt"
"os"
)

func main() {
fmt.Println("Run as 'go test -bench=.' to run the benchmarks")
os.Exit(1)
}
@@ -1,107 +0,0 @@
package main

import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"sync/atomic"
"testing"
"time"

"github.com/ollama/ollama/server/internal/chunks"
"golang.org/x/sync/errgroup"
)

func BenchmarkDownload(b *testing.B) {
run := func(fileSize, chunkSize int64) {
name := fmt.Sprintf("size=%d/chunksize=%d", fileSize, chunkSize)
b.Run(name, func(b *testing.B) { benchmarkDownload(b, fileSize, chunkSize) })
}

run(100<<20, 8<<20)
run(100<<20, 16<<20)
run(100<<20, 32<<20)
run(100<<20, 64<<20)
run(100<<20, 128<<20) // 1 chunk
}

func run(ctx context.Context, c *http.Client, chunk chunks.Chunk) error {
const blobURL = "https://ollama.com/v2/x/x/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d"
req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
res, err := c.Do(req)
if err != nil {
return err
}
defer res.Body.Close()

_, err = io.CopyN(io.Discard, res.Body, chunk.Size()) // will io.EOF on short read
return err
}

var sleepTime atomic.Int64

func benchmarkDownload(b *testing.B, fileSize, chunkSize int64) {
client := &http.Client{
Transport: func() http.RoundTripper {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DisableKeepAlives = true
return tr
}(),
}
defer client.CloseIdleConnections()

// warm up the client
run(context.Background(), client, chunks.New(0, 1<<20))

b.SetBytes(fileSize)
b.ReportAllocs()

// Give our CDN a min to breathe between benchmarks.
time.Sleep(time.Duration(sleepTime.Swap(3)))

for b.Loop() {
g, ctx := errgroup.WithContext(b.Context())
g.SetLimit(runtime.GOMAXPROCS(0))
for chunk := range chunks.Of(fileSize, chunkSize) {
g.Go(func() error { return run(ctx, client, chunk) })
}
if err := g.Wait(); err != nil {
b.Fatal(err)
}
}
}

func BenchmarkWrite(b *testing.B) {
b.Run("chunksize=1MiB", func(b *testing.B) { benchmarkWrite(b, 1<<20) })
}

func benchmarkWrite(b *testing.B, chunkSize int) {
b.ReportAllocs()

dir := b.TempDir()
f, err := os.Create(filepath.Join(dir, "write-single"))
if err != nil {
b.Fatal(err)
}
defer f.Close()

data := make([]byte, chunkSize)
b.SetBytes(int64(chunkSize))
r := bytes.NewReader(data)
for b.Loop() {
r.Reset(data)
_, err := io.Copy(f, r)
if err != nil {
b.Fatal(err)
}
}
}
@@ -1,6 +1,5 @@
// Package registry provides an http.Handler for handling local Ollama API
// requests for performing tasks related to the ollama.com model registry and
// the local disk cache.
// Package registry implements an http.Handler for handling local Ollama API
// model management requests. See [Local] for details.
package registry

import (
@@ -10,6 +9,7 @@ import (
"fmt"
"io"
"log/slog"
"maps"
"net/http"
"sync"
"time"
@@ -18,16 +18,11 @@ import (
"github.com/ollama/ollama/server/internal/client/ollama"
)

// Local is an http.Handler for handling local Ollama API requests for
// performing tasks related to the ollama.com model registry combined with the
// local disk cache.
// Local implements an http.Handler for handling local Ollama API model
// management requests, such as pushing, pulling, and deleting models.
//
// It is not a concern of Local, or this package, to handle model creation, which
// precedes any registry operations for models it produces.
//
// NOTE: The package built for dealing with model creation should use
// [DefaultCache] to access the blob store and not attempt to read or write
// directly to the blob disk cache.
// It can be arranged for all unknown requests to be passed through to a
// fallback handler, if one is provided.
type Local struct {
Client *ollama.Registry // required
Logger *slog.Logger // required
@@ -63,6 +58,7 @@ func (e serverError) Error() string {
var (
errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
errNotFound = &serverError{404, "not_found", "not found"}
errModelNotFound = &serverError{404, "not_found", "model not found"}
errInternalError = &serverError{500, "internal_error", "internal server error"}
)

@@ -175,8 +171,16 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
}

type params struct {
DeprecatedName string `json:"name"` // Use [params.model]
Model string `json:"model"` // Use [params.model]
// DeprecatedName is the name of the model to push, pull, or delete,
// but is deprecated. New clients should use [Model] instead.
//
// Use [model()] to get the model name for both old and new API requests.
DeprecatedName string `json:"name"`

// Model is the name of the model to push, pull, or delete.
//
// Use [model()] to get the model name for both old and new API requests.
Model string `json:"model"`

// AllowNonTLS is a flag that indicates a client using HTTP
// is doing so, deliberately.
@@ -189,9 +193,18 @@ type params struct {
// confusing flags such as this.
AllowNonTLS bool `json:"insecure"`

// ProgressStream is a flag that indicates the client is expecting a stream of
// progress updates.
ProgressStream bool `json:"stream"`
// Stream, if true, will make the server send progress updates in a
// stream of JSON objects. If false, the server will send a single
// JSON object with the final status as "success", or an error object
// if an error occurred.
//
// Unfortunately, this API was designed to be a bit awkward. Stream is
// defined to default to true if not present, so we need a way to check
// if the client decisively set it to false. So, we use a pointer to a
// bool. Gross.
//
// Use [stream()] to get the correct value for this field.
Stream *bool `json:"stream"`
}

// model returns the model name for both old and new API requests.
@@ -199,6 +212,13 @@ func (p params) model() string {
return cmp.Or(p.Model, p.DeprecatedName)
}

func (p params) stream() bool {
if p.Stream == nil {
return true
}
return *p.Stream
}

func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
if r.Method != "DELETE" {
return errMethodNotAllowed
@@ -212,16 +232,16 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
return err
}
if !ok {
return &serverError{404, "not_found", "model not found"}
}
if s.Prune == nil {
return nil
return errModelNotFound
}
if s.Prune != nil {
return s.Prune()
}
return nil
}

type progressUpdateJSON struct {
Status string `json:"status"`
Status string `json:"status,omitempty,omitzero"`
Digest blob.Digest `json:"digest,omitempty,omitzero"`
Total int64 `json:"total,omitempty,omitzero"`
Completed int64 `json:"completed,omitempty,omitzero"`
@@ -237,6 +257,17 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
return err
}

enc := json.NewEncoder(w)
if !p.stream() {
if err := s.Client.Pull(r.Context(), p.model()); err != nil {
if errors.Is(err, ollama.ErrModelNotFound) {
return errModelNotFound
}
return err
}
return enc.Encode(progressUpdateJSON{Status: "success"})
}

maybeFlush := func() {
fl, _ := w.(http.Flusher)
if fl != nil {
@@ -246,69 +277,67 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
defer maybeFlush()

var mu sync.Mutex
enc := json.NewEncoder(w)
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
progress := make(map[*ollama.Layer]int64)

ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
progressCopy := make(map[*ollama.Layer]int64, len(progress))
pushUpdate := func() {
defer maybeFlush()

// TODO(bmizerany): This scales poorly with more layers due to
// needing to flush out them all in one big update. We _could_
// just flush on the changed ones, or just track the whole
// download. Needs more thought. This is fine for now.
mu.Lock()
defer mu.Unlock()

// TODO(bmizerany): coalesce these updates; writing per
// update is expensive
maps.Copy(progressCopy, progress)
mu.Unlock()
for l, n := range progress {
enc.Encode(progressUpdateJSON{
Digest: l.Digest,
Status: "pulling",
Total: l.Size,
Completed: n,
})
}
}

t := time.NewTicker(time.Hour) // "unstarted" timer
start := sync.OnceFunc(func() {
pushUpdate()
t.Reset(100 * time.Millisecond)
})
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
if n > 0 {
start() // flush initial state
}
mu.Lock()
progress[l] = n
mu.Unlock()
},
})

done := make(chan error, 1)
go func() {
// TODO(bmizerany): continue to support non-streaming responses
done <- s.Client.Pull(ctx, p.model())
}()

func() {
t := time.NewTicker(100 * time.Millisecond)
defer t.Stop()
for {
select {
case <-t.C:
mu.Lock()
maybeFlush()
mu.Unlock()
pushUpdate()
case err := <-done:
pushUpdate()
if err != nil {
var status string
if errors.Is(err, ollama.ErrModelNotFound) {
status = fmt.Sprintf("error: model %q not found", p.model())
enc.Encode(progressUpdateJSON{Status: status})
} else {
status = fmt.Sprintf("error: %v", err)
}
enc.Encode(progressUpdateJSON{Status: status})
}
return
}

// These final updates are not strictly necessary, because they have
// already happened at this point. Our pull handler code used to do
// these steps after, not during, the pull, and they were slow, so we
// wanted to provide feedback to users what was happening. For now, we
// keep them to not jar users who are used to seeing them. We can phase
// them out with a new and nicer UX later. One without progress bars
// and digests that no one cares about.
enc.Encode(progressUpdateJSON{Status: "verifying layers"})
enc.Encode(progressUpdateJSON{Status: "writing manifest"})
enc.Encode(progressUpdateJSON{Status: "success"})
return
}
}
}()

return nil
}
}
}

func decodeUserJSON[T any](r io.Reader) (T, error) {
@@ -4,7 +4,6 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/fs"
"net"
@@ -160,7 +159,6 @@ var registryFS = sync.OnceValue(func() fs.FS {
// to \n when parsing the txtar on Windows.
data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
a := txtar.Parse(data)
fmt.Printf("%q\n", a.Comment)
fsys, err := txtar.FS(a)
if err != nil {
panic(err)
@@ -179,7 +177,7 @@ func TestServerPull(t *testing.T) {
w.WriteHeader(404)
io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
default:
t.Logf("serving file: %s", r.URL.Path)
t.Logf("serving blob: %s", r.URL.Path)
modelsHandler.ServeHTTP(w, r)
}
})
@@ -188,7 +186,7 @@ func TestServerPull(t *testing.T) {
t.Helper()

if got.Code != 200 {
t.Fatalf("Code = %d; want 200", got.Code)
t.Errorf("Code = %d; want 200", got.Code)
}
gotlines := got.Body.String()
t.Logf("got:\n%s", gotlines)
@@ -197,35 +195,29 @@ func TestServerPull(t *testing.T) {
want, unwanted := strings.CutPrefix(want, "!")
want = strings.TrimSpace(want)
if !unwanted && !strings.Contains(gotlines, want) {
t.Fatalf("! missing %q in body", want)
t.Errorf("! missing %q in body", want)
}
if unwanted && strings.Contains(gotlines, want) {
t.Fatalf("! unexpected %q in body", want)
t.Errorf("! unexpected %q in body", want)
}
}
}

got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
`)

got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
checkResponse(got, `
{"status":"pulling manifest"}
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
{"status":"verifying layers"}
{"status":"writing manifest"}
{"status":"success"}
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
`)

got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: model \"unknown\" not found"}
`)
@@ -240,19 +232,39 @@ func TestServerPull(t *testing.T) {

got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: invalid or missing name: \"\""}

!verifying
!writing
!success
`)

// Non-streaming pulls
got = s.send(t, "POST", "/api/pull", `{"model": "://", "stream": false}`)
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
got = s.send(t, "POST", "/api/pull", `{"model": "smol", "stream": false}`)
checkResponse(got, `
{"status":"success"}
!digest
!total
!completed
`)
got = s.send(t, "POST", "/api/pull", `{"model": "unknown", "stream": false}`)
checkErrorResponse(t, got, 404, "not_found", "model not found")
}

func TestServerUnknownPath(t *testing.T) {
s := newTestServer(t, nil)
got := s.send(t, "DELETE", "/api/unknown", `{}`)
checkErrorResponse(t, got, 404, "not_found", "not found")

var fellback bool
s.Fallback = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fellback = true
})
got = s.send(t, "DELETE", "/api/unknown", `{}`)
if !fellback {
t.Fatal("expected Fallback to be called")
}
if got.Code != 200 {
t.Fatalf("Code = %d; want 200", got.Code)
}
}

func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
@@ -82,7 +82,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
for _, layer := range layers {
if s := layer.GGML.KV().ChatTemplate(); s != "" {
if t, err := template.Named(s); err != nil {
slog.Debug("template detection", "error", err)
slog.Debug("template detection", "error", err, "template", s)
} else {
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
if err != nil {

@@ -26,7 +26,6 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
var system []api.Message

isMllama := checkMllamaModelFamily(m)
isGemma3 := checkGemma3ModelFamily(m)

var imageNumTokens int
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@@ -41,7 +40,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n; i >= 0; i-- {
if (isMllama || isGemma3) && len(msgs[i].Images) > 1 {
if isMllama && len(msgs[i].Images) > 1 {
return "", nil, errTooManyImages
}

@@ -158,12 +157,3 @@ func checkMllamaModelFamily(m *Model) bool {
}
return false
}

func checkGemma3ModelFamily(m *Model) bool {
for _, arch := range m.Config.ModelFamilies {
if arch == "gemma3" {
return true
}
}
return false
}
template/gemma3-instruct.gotmpl
@@ -0,0 +1,13 @@
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 }}
{{- if eq .Role "user" }}<start_of_turn>user
{{- if and (eq $i 1) $.System }}
{{ $.System }}
{{ end }}
{{ .Content }}<end_of_turn>
{{ else if eq .Role "assistant" }}<start_of_turn>model
{{ .Content }}<end_of_turn>
{{ end }}
{{- if $last }}<start_of_turn>model
{{ end }}
{{- end }}
template/gemma3-instruct.json
@@ -0,0 +1,6 @@
{
"stop": [
"<end_of_turn>"
],
"temperature": 0.1
}
@@ -87,6 +87,10 @@
"template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
"name": "gemma-instruct"
},
{
"template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n",
"name": "gemma3-instruct"
},
{
"template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"name": "llama3-instruct"
template/testdata/gemma3-instruct.gotmpl/system-user-assistant-user
@@ -0,0 +1,10 @@
<start_of_turn>user
You are a helpful assistant.

Hello, how are you?<end_of_turn>
<start_of_turn>model
I'm doing great. How can I help you today?<end_of_turn>
<start_of_turn>user
I'd like to show off how chat templating works!<end_of_turn>
<start_of_turn>model

template/testdata/gemma3-instruct.gotmpl/user
@@ -0,0 +1,4 @@
<start_of_turn>user
Hello, how are you?<end_of_turn>
<start_of_turn>model

template/testdata/gemma3-instruct.gotmpl/user-assistant-user
@@ -0,0 +1,8 @@
<start_of_turn>user
Hello, how are you?<end_of_turn>
<start_of_turn>model
I'm doing great. How can I help you today?<end_of_turn>
<start_of_turn>user
I'd like to show off how chat templating works!<end_of_turn>
<start_of_turn>model