Compare commits
27 Commits
pdevine/lo
...
pdevine/bf
Author | SHA1 | Date | |
---|---|---|---|
![]() |
c75b428249 | ||
![]() |
021dcf089d | ||
![]() |
bf24498b1e | ||
![]() |
95e271d98f | ||
![]() |
364629b8d6 | ||
![]() |
108fe02165 | ||
![]() |
4561fff36e | ||
![]() |
50b5962042 | ||
![]() |
e27e4a3c1b | ||
![]() |
088514bbd4 | ||
![]() |
2c8b484643 | ||
![]() |
8294676150 | ||
![]() |
ef378ad673 | ||
![]() |
2d2247e59e | ||
![]() |
7bf793a600 | ||
![]() |
282bfaaa95 | ||
![]() |
9679f40146 | ||
![]() |
3892c3a703 | ||
![]() |
4e320b8b90 | ||
![]() |
eb2b22b042 | ||
![]() |
4ea4d2b189 | ||
![]() |
8d76fa23ef | ||
![]() |
74b44fdf8f | ||
![]() |
65b88c544f | ||
![]() |
a422ba39c9 | ||
![]() |
d2ec22371e | ||
![]() |
033cec232a |
@@ -56,7 +56,7 @@
|
||||
"name": "ROCm 6",
|
||||
"inherits": [ "ROCm" ],
|
||||
"cacheVariables": {
|
||||
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
|
||||
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
|
||||
}
|
||||
}
|
||||
],
|
||||
|
@@ -392,6 +392,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
|
||||
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
|
||||
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
|
||||
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
|
||||
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
|
||||
|
||||
### Cloud
|
||||
|
||||
|
129
cmd/cmd_test.go
129
cmd/cmd_test.go
@@ -757,3 +757,132 @@ func TestCreateHandler(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCreateRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
from string
|
||||
opts runOptions
|
||||
expected *api.CreateRequest
|
||||
}{
|
||||
{
|
||||
"basic test",
|
||||
"newmodel",
|
||||
runOptions{
|
||||
Model: "mymodel",
|
||||
ParentModel: "",
|
||||
Prompt: "You are a fun AI agent",
|
||||
Messages: []api.Message{},
|
||||
WordWrap: true,
|
||||
},
|
||||
&api.CreateRequest{
|
||||
From: "mymodel",
|
||||
Model: "newmodel",
|
||||
},
|
||||
},
|
||||
{
|
||||
"parent model test",
|
||||
"newmodel",
|
||||
runOptions{
|
||||
Model: "mymodel",
|
||||
ParentModel: "parentmodel",
|
||||
Messages: []api.Message{},
|
||||
WordWrap: true,
|
||||
},
|
||||
&api.CreateRequest{
|
||||
From: "parentmodel",
|
||||
Model: "newmodel",
|
||||
},
|
||||
},
|
||||
{
|
||||
"parent model as filepath test",
|
||||
"newmodel",
|
||||
runOptions{
|
||||
Model: "mymodel",
|
||||
ParentModel: "/some/file/like/etc/passwd",
|
||||
Messages: []api.Message{},
|
||||
WordWrap: true,
|
||||
},
|
||||
&api.CreateRequest{
|
||||
From: "mymodel",
|
||||
Model: "newmodel",
|
||||
},
|
||||
},
|
||||
{
|
||||
"parent model as windows filepath test",
|
||||
"newmodel",
|
||||
runOptions{
|
||||
Model: "mymodel",
|
||||
ParentModel: "D:\\some\\file\\like\\etc\\passwd",
|
||||
Messages: []api.Message{},
|
||||
WordWrap: true,
|
||||
},
|
||||
&api.CreateRequest{
|
||||
From: "mymodel",
|
||||
Model: "newmodel",
|
||||
},
|
||||
},
|
||||
{
|
||||
"options test",
|
||||
"newmodel",
|
||||
runOptions{
|
||||
Model: "mymodel",
|
||||
ParentModel: "parentmodel",
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
},
|
||||
},
|
||||
&api.CreateRequest{
|
||||
From: "parentmodel",
|
||||
Model: "newmodel",
|
||||
Parameters: map[string]any{
|
||||
"temperature": 1.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"messages test",
|
||||
"newmodel",
|
||||
runOptions{
|
||||
Model: "mymodel",
|
||||
ParentModel: "parentmodel",
|
||||
System: "You are a fun AI agent",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "hello there!",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "hello to you!",
|
||||
},
|
||||
},
|
||||
WordWrap: true,
|
||||
},
|
||||
&api.CreateRequest{
|
||||
From: "parentmodel",
|
||||
Model: "newmodel",
|
||||
System: "You are a fun AI agent",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "hello there!",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "hello to you!",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actual := NewCreateRequest(tt.from, tt.opts)
|
||||
if !cmp.Equal(actual, tt.expected) {
|
||||
t.Errorf("expected output %#v, got %#v", tt.expected, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type MultilineState int
|
||||
@@ -459,9 +460,16 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
}
|
||||
|
||||
func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
|
||||
parentModel := opts.ParentModel
|
||||
|
||||
modelName := model.ParseName(parentModel)
|
||||
if !modelName.IsValid() {
|
||||
parentModel = ""
|
||||
}
|
||||
|
||||
req := &api.CreateRequest{
|
||||
Name: name,
|
||||
From: cmp.Or(opts.ParentModel, opts.Model),
|
||||
Model: name,
|
||||
From: cmp.Or(parentModel, opts.Model),
|
||||
}
|
||||
|
||||
if opts.System != "" {
|
||||
|
@@ -11,9 +11,10 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/x448/float16"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/ollama/ollama/types/bfloat16"
|
||||
)
|
||||
|
||||
type safetensorMetadata struct {
|
||||
|
@@ -187,6 +187,13 @@ cloudflared tunnel --url http://localhost:11434 --http-host-header="localhost:11
|
||||
|
||||
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
|
||||
|
||||
For browser extensions, you'll need to explicitly allow the extension's origin pattern. Set `OLLAMA_ORIGINS` to include `chrome-extension://*`, `moz-extension://*`, and `safari-web-extension://*` if you wish to allow all browser extensions access, or specific extensions as needed:
|
||||
|
||||
```
|
||||
# Allow all Chrome, Firefox, and Safari extensions
|
||||
OLLAMA_ORIGINS=chrome-extension://*,moz-extension://*,safari-web-extension://* ollama serve
|
||||
```
|
||||
|
||||
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
|
||||
|
||||
## Where are models stored?
|
||||
|
@@ -583,39 +583,52 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
}
|
||||
|
||||
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
|
||||
if llm.KV().Uint("vision.block_count") == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for name, layer := range llm.Tensors().GroupLayers() {
|
||||
if name == "v" || strings.HasPrefix(name, "v.") {
|
||||
for _, tensor := range layer {
|
||||
weights += tensor.Size()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
imageSize := uint64(llm.KV().Uint("vision.image_size"))
|
||||
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
|
||||
if patchSize == 0 {
|
||||
slog.Warn("unknown patch size for vision model")
|
||||
return
|
||||
}
|
||||
|
||||
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
|
||||
|
||||
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
|
||||
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
|
||||
numPatches++
|
||||
}
|
||||
|
||||
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
|
||||
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
|
||||
|
||||
switch llm.KV().Architecture() {
|
||||
case "mllama":
|
||||
for _, layer := range llm.Tensors().GroupLayers()["v"] {
|
||||
weights += layer.Size()
|
||||
}
|
||||
|
||||
kv := func(n string) uint64 {
|
||||
if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
|
||||
return uint64(v)
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
imageSize := kv("image_size")
|
||||
|
||||
maxNumTiles := kv("max_num_tiles")
|
||||
embeddingLength := kv("embedding_length")
|
||||
headCount := kv("attention.head_count")
|
||||
|
||||
numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
|
||||
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
|
||||
numPatches++
|
||||
}
|
||||
|
||||
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
|
||||
|
||||
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
|
||||
|
||||
graphSize = 4 * (8 +
|
||||
imageSize*imageSize*kv("num_channels")*maxNumTiles +
|
||||
imageSize*imageSize*numChannels*maxNumTiles +
|
||||
embeddingLength*numPatches*maxNumTiles +
|
||||
9*embeddingLength*numPaddedPatches*maxNumTiles +
|
||||
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
|
||||
case "gemma3":
|
||||
graphSize = 4 * (imageSize*imageSize*numChannels +
|
||||
embeddingLength*patchSize +
|
||||
numPatches*numPatches*headCount)
|
||||
}
|
||||
|
||||
return weights, graphSize
|
||||
}
|
||||
|
||||
|
1
go.mod
1
go.mod
@@ -16,7 +16,6 @@ require (
|
||||
|
||||
require (
|
||||
github.com/agnivade/levenshtein v1.1.1
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/dlclark/regexp2 v1.11.4
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
||||
github.com/google/go-cmp v0.6.0
|
||||
|
2
go.sum
2
go.sum
@@ -35,8 +35,6 @@ github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARu
|
||||
github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 h1:cBzrdJPAFBsgCrDPnZxlp1dF2+k4r1kVpD7+1S1PVjY=
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLcxEuYUlAd/EXyjc/v55nd3+47YAgWbSXVxPrNI=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
|
@@ -66,6 +66,35 @@ func TestIntegrationMllama(t *testing.T) {
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||
}
|
||||
|
||||
func TestIntegrationSplitBatch(t *testing.T) {
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
req := api.GenerateRequest{
|
||||
Model: "gemma3:4b",
|
||||
// Fill up a chunk of the batch so the image will partially spill over into the next one
|
||||
System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
|
||||
Prompt: "what does the text in this image say?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
Images: []api.ImageData{
|
||||
image,
|
||||
},
|
||||
}
|
||||
|
||||
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||
resp := "the ollam"
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
// llava models on CPU can be quite slow to start,
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
||||
}
|
||||
|
||||
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
|
||||
AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA
|
||||
AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6
|
||||
|
19
llama/llama.cpp/src/llama-arch.cpp
vendored
19
llama/llama.cpp/src/llama-arch.cpp
vendored
@@ -37,6 +37,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_MINICPM3, "minicpm3" },
|
||||
{ LLM_ARCH_GEMMA, "gemma" },
|
||||
{ LLM_ARCH_GEMMA2, "gemma2" },
|
||||
{ LLM_ARCH_GEMMA3, "gemma3" },
|
||||
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
||||
{ LLM_ARCH_MAMBA, "mamba" },
|
||||
{ LLM_ARCH_XVERSE, "xverse" },
|
||||
@@ -804,6 +805,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GEMMA3,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_STARCODER2,
|
||||
{
|
||||
|
1
llama/llama.cpp/src/llama-arch.h
vendored
1
llama/llama.cpp/src/llama-arch.h
vendored
@@ -41,6 +41,7 @@ enum llm_arch {
|
||||
LLM_ARCH_MINICPM3,
|
||||
LLM_ARCH_GEMMA,
|
||||
LLM_ARCH_GEMMA2,
|
||||
LLM_ARCH_GEMMA3,
|
||||
LLM_ARCH_STARCODER2,
|
||||
LLM_ARCH_MAMBA,
|
||||
LLM_ARCH_XVERSE,
|
||||
|
7
llama/llama.cpp/src/llama-model.cpp
vendored
7
llama/llama.cpp/src/llama-model.cpp
vendored
@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA3:
|
||||
{
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
@@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA3:
|
||||
{
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
@@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_PHIMOE:
|
||||
case LLM_ARCH_GEMMA:
|
||||
case LLM_ARCH_GEMMA2:
|
||||
case LLM_ARCH_GEMMA3:
|
||||
case LLM_ARCH_STARCODER2:
|
||||
case LLM_ARCH_OPENELM:
|
||||
case LLM_ARCH_GPTNEOX:
|
||||
|
9
llama/llama.cpp/src/llama-quant.cpp
vendored
9
llama/llama.cpp/src/llama-quant.cpp
vendored
@@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
// This used to be a regex, but <regex> has an extreme cost to compile times.
|
||||
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
|
||||
|
||||
// don't quantize vision stuff
|
||||
quantize &= name.find("v.blk.") == std::string::npos;
|
||||
|
||||
quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos;
|
||||
quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos;
|
||||
quantize &= name.find("v.patch_embedding.weight") == std::string::npos;
|
||||
quantize &= name.find("v.position_embedding.weight") == std::string::npos;
|
||||
quantize &= name.find("v.post_layernorm.weight") == std::string::npos;
|
||||
|
||||
// quantize only 2D and 3D tensors (experts)
|
||||
quantize &= (ggml_n_dims(tensor) >= 2);
|
||||
|
||||
|
113
llama/patches/0021-gemma3-quantization.patch
Normal file
113
llama/patches/0021-gemma3-quantization.patch
Normal file
@@ -0,0 +1,113 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Patrick Devine <patrick@infrahq.com>
|
||||
Date: Fri, 14 Mar 2025 16:33:23 -0700
|
||||
Subject: [PATCH] gemma3 quantization
|
||||
|
||||
---
|
||||
src/llama-arch.cpp | 19 +++++++++++++++++++
|
||||
src/llama-arch.h | 1 +
|
||||
src/llama-model.cpp | 7 +++++++
|
||||
src/llama-quant.cpp | 9 +++++++++
|
||||
4 files changed, 36 insertions(+)
|
||||
|
||||
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
|
||||
index b6f20286..b443fcd3 100644
|
||||
--- a/src/llama-arch.cpp
|
||||
+++ b/src/llama-arch.cpp
|
||||
@@ -37,6 +37,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_MINICPM3, "minicpm3" },
|
||||
{ LLM_ARCH_GEMMA, "gemma" },
|
||||
{ LLM_ARCH_GEMMA2, "gemma2" },
|
||||
+ { LLM_ARCH_GEMMA3, "gemma3" },
|
||||
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
||||
{ LLM_ARCH_MAMBA, "mamba" },
|
||||
{ LLM_ARCH_XVERSE, "xverse" },
|
||||
@@ -804,6 +805,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
},
|
||||
},
|
||||
+ {
|
||||
+ LLM_ARCH_GEMMA3,
|
||||
+ {
|
||||
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
+ },
|
||||
+ },
|
||||
{
|
||||
LLM_ARCH_STARCODER2,
|
||||
{
|
||||
diff --git a/src/llama-arch.h b/src/llama-arch.h
|
||||
index ec742224..aad92a5d 100644
|
||||
--- a/src/llama-arch.h
|
||||
+++ b/src/llama-arch.h
|
||||
@@ -41,6 +41,7 @@ enum llm_arch {
|
||||
LLM_ARCH_MINICPM3,
|
||||
LLM_ARCH_GEMMA,
|
||||
LLM_ARCH_GEMMA2,
|
||||
+ LLM_ARCH_GEMMA3,
|
||||
LLM_ARCH_STARCODER2,
|
||||
LLM_ARCH_MAMBA,
|
||||
LLM_ARCH_XVERSE,
|
||||
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
|
||||
index ab1a07d1..70183041 100644
|
||||
--- a/src/llama-model.cpp
|
||||
+++ b/src/llama-model.cpp
|
||||
@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
+ case LLM_ARCH_GEMMA3:
|
||||
+ {
|
||||
+ } break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
@@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
}
|
||||
} break;
|
||||
+ case LLM_ARCH_GEMMA3:
|
||||
+ {
|
||||
+ } break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
@@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_PHIMOE:
|
||||
case LLM_ARCH_GEMMA:
|
||||
case LLM_ARCH_GEMMA2:
|
||||
+ case LLM_ARCH_GEMMA3:
|
||||
case LLM_ARCH_STARCODER2:
|
||||
case LLM_ARCH_OPENELM:
|
||||
case LLM_ARCH_GPTNEOX:
|
||||
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
|
||||
index 6eb1da08..d2f3a510 100644
|
||||
--- a/src/llama-quant.cpp
|
||||
+++ b/src/llama-quant.cpp
|
||||
@@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
// This used to be a regex, but <regex> has an extreme cost to compile times.
|
||||
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
|
||||
|
||||
+ // don't quantize vision stuff
|
||||
+ quantize &= name.find("v.blk.") == std::string::npos;
|
||||
+
|
||||
+ quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos;
|
||||
+ quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos;
|
||||
+ quantize &= name.find("v.patch_embedding.weight") == std::string::npos;
|
||||
+ quantize &= name.find("v.position_embedding.weight") == std::string::npos;
|
||||
+ quantize &= name.find("v.post_layernorm.weight") == std::string::npos;
|
||||
+
|
||||
// quantize only 2D and 3D tensors (experts)
|
||||
quantize &= (ggml_n_dims(tensor) >= 2);
|
||||
|
@@ -218,8 +218,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
|
||||
layerSize = blk.Size()
|
||||
layerSize += kv / f.KV().BlockCount()
|
||||
memoryWeights += blk.Size()
|
||||
}
|
||||
memoryWeights += layerSize
|
||||
|
||||
if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
|
||||
// Stop allocating on GPU(s) once we hit the users target NumGPU
|
||||
@@ -376,7 +376,7 @@ func (m MemoryEstimate) LogValue() slog.Value {
|
||||
// memory of the weights
|
||||
"total", format.HumanBytes2(m.memoryWeights),
|
||||
// memory of repeating layers
|
||||
"repeating", format.HumanBytes2(m.memoryWeights-m.memoryLayerOutput),
|
||||
"repeating", format.HumanBytes2(m.memoryWeights),
|
||||
// memory of non-repeating layers
|
||||
"nonrepeating", format.HumanBytes2(m.memoryLayerOutput),
|
||||
),
|
||||
|
136
llm/server.go
136
llm/server.go
@@ -402,7 +402,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
|
||||
}
|
||||
|
||||
slog.Info("starting llama server", "cmd", s.cmd.String())
|
||||
slog.Info("starting llama server", "cmd", s.cmd)
|
||||
if envconfig.Debug() {
|
||||
filteredEnv := []string{}
|
||||
for _, ev := range s.cmd.Env {
|
||||
@@ -470,7 +470,7 @@ const ( // iota is reset to 0
|
||||
ServerStatusError
|
||||
)
|
||||
|
||||
func (s ServerStatus) ToString() string {
|
||||
func (s ServerStatus) String() string {
|
||||
switch s {
|
||||
case ServerStatusReady:
|
||||
return "llm server ready"
|
||||
@@ -485,12 +485,9 @@ func (s ServerStatus) ToString() string {
|
||||
}
|
||||
}
|
||||
|
||||
type ServerStatusResp struct {
|
||||
Status string `json:"status"`
|
||||
SlotsIdle int `json:"slots_idle"`
|
||||
SlotsProcessing int `json:"slots_processing"`
|
||||
Error string `json:"error"`
|
||||
Progress float32 `json:"progress"`
|
||||
type ServerStatusResponse struct {
|
||||
Status ServerStatus `json:"status"`
|
||||
Progress float32 `json:"progress"`
|
||||
}
|
||||
|
||||
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||
@@ -502,7 +499,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||
}
|
||||
if s.cmd.ProcessState.ExitCode() == -1 {
|
||||
// Most likely a signal killed it, log some more details to try to help troubleshoot
|
||||
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
|
||||
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState)
|
||||
}
|
||||
return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
|
||||
}
|
||||
@@ -527,21 +524,19 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||
return ServerStatusError, fmt.Errorf("read health request: %w", err)
|
||||
}
|
||||
|
||||
var status ServerStatusResp
|
||||
if err := json.Unmarshal(body, &status); err != nil {
|
||||
var ssr ServerStatusResponse
|
||||
if err := json.Unmarshal(body, &ssr); err != nil {
|
||||
return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
|
||||
}
|
||||
|
||||
switch status.Status {
|
||||
case "ok":
|
||||
return ServerStatusReady, nil
|
||||
case "no slot available":
|
||||
return ServerStatusNoSlotsAvailable, nil
|
||||
case "loading model":
|
||||
s.loadProgress = status.Progress
|
||||
return ServerStatusLoadingModel, nil
|
||||
switch ssr.Status {
|
||||
case ServerStatusLoadingModel:
|
||||
s.loadProgress = ssr.Progress
|
||||
return ssr.Status, nil
|
||||
case ServerStatusReady, ServerStatusNoSlotsAvailable:
|
||||
return ssr.Status, nil
|
||||
default:
|
||||
return ServerStatusError, fmt.Errorf("server error: %+v", status)
|
||||
return ssr.Status, fmt.Errorf("server error: %+v", ssr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -616,7 +611,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
||||
status, _ := s.getServerStatus(ctx)
|
||||
if lastStatus != status && status != ServerStatusReady {
|
||||
// Only log on status changes
|
||||
slog.Info("waiting for server to become available", "status", status.ToString())
|
||||
slog.Info("waiting for server to become available", "status", status)
|
||||
}
|
||||
switch status {
|
||||
case ServerStatusReady:
|
||||
@@ -630,7 +625,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
||||
slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
|
||||
stallTimer = time.Now().Add(stallDuration)
|
||||
} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
|
||||
slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
|
||||
slog.Debug("model load completed, waiting for server to become available", "status", status)
|
||||
stallTimer = time.Now().Add(stallDuration)
|
||||
fullyLoaded = true
|
||||
}
|
||||
@@ -671,63 +666,26 @@ type ImageData struct {
|
||||
AspectRatioID int `json:"aspect_ratio_id"`
|
||||
}
|
||||
|
||||
type completion struct {
|
||||
Content string `json:"content"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Stop bool `json:"stop"`
|
||||
StoppedLimit bool `json:"stopped_limit"`
|
||||
|
||||
Timings struct {
|
||||
PredictedN int `json:"predicted_n"`
|
||||
PredictedMS float64 `json:"predicted_ms"`
|
||||
PromptN int `json:"prompt_n"`
|
||||
PromptMS float64 `json:"prompt_ms"`
|
||||
}
|
||||
}
|
||||
|
||||
type CompletionRequest struct {
|
||||
Prompt string
|
||||
Format json.RawMessage
|
||||
Images []ImageData
|
||||
Options *api.Options
|
||||
|
||||
Grammar string // set before sending the request to the subprocess
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string
|
||||
DoneReason string
|
||||
Done bool
|
||||
PromptEvalCount int
|
||||
PromptEvalDuration time.Duration
|
||||
EvalCount int
|
||||
EvalDuration time.Duration
|
||||
Content string `json:"content"`
|
||||
DoneReason string `json:"done_reason"`
|
||||
Done bool `json:"done"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
EvalDuration time.Duration `json:"eval_duration"`
|
||||
}
|
||||
|
||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||
request := map[string]any{
|
||||
"prompt": req.Prompt,
|
||||
"stream": true,
|
||||
"n_predict": req.Options.NumPredict,
|
||||
"n_keep": req.Options.NumKeep,
|
||||
"main_gpu": req.Options.MainGPU,
|
||||
"temperature": req.Options.Temperature,
|
||||
"top_k": req.Options.TopK,
|
||||
"top_p": req.Options.TopP,
|
||||
"min_p": req.Options.MinP,
|
||||
"typical_p": req.Options.TypicalP,
|
||||
"repeat_last_n": req.Options.RepeatLastN,
|
||||
"repeat_penalty": req.Options.RepeatPenalty,
|
||||
"presence_penalty": req.Options.PresencePenalty,
|
||||
"frequency_penalty": req.Options.FrequencyPenalty,
|
||||
"mirostat": req.Options.Mirostat,
|
||||
"mirostat_tau": req.Options.MirostatTau,
|
||||
"mirostat_eta": req.Options.MirostatEta,
|
||||
"seed": req.Options.Seed,
|
||||
"stop": req.Options.Stop,
|
||||
"image_data": req.Images,
|
||||
"cache_prompt": true,
|
||||
}
|
||||
|
||||
if len(req.Format) > 0 {
|
||||
switch string(req.Format) {
|
||||
case `null`, `""`:
|
||||
@@ -735,7 +693,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
// these as "not set".
|
||||
break
|
||||
case `"json"`:
|
||||
request["grammar"] = grammarJSON
|
||||
req.Grammar = grammarJSON
|
||||
default:
|
||||
if req.Format[0] != '{' {
|
||||
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
|
||||
@@ -746,10 +704,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
if g == nil {
|
||||
return fmt.Errorf("invalid JSON schema in format")
|
||||
}
|
||||
request["grammar"] = string(g)
|
||||
req.Grammar = string(g)
|
||||
}
|
||||
}
|
||||
|
||||
if req.Options == nil {
|
||||
opts := api.DefaultOptions()
|
||||
req.Options = &opts
|
||||
}
|
||||
|
||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting completion request due to client closing the connection")
|
||||
@@ -770,7 +733,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
if err != nil {
|
||||
return err
|
||||
} else if status != ServerStatusReady {
|
||||
return fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||
return fmt.Errorf("unexpected server status: %s", status)
|
||||
}
|
||||
|
||||
// Handling JSON marshaling with special characters unescaped.
|
||||
@@ -778,7 +741,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
enc := json.NewEncoder(buffer)
|
||||
enc.SetEscapeHTML(false)
|
||||
|
||||
if err := enc.Encode(request); err != nil {
|
||||
if err := enc.Encode(req); err != nil {
|
||||
return fmt.Errorf("failed to marshal data: %v", err)
|
||||
}
|
||||
|
||||
@@ -829,7 +792,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
evt = line
|
||||
}
|
||||
|
||||
var c completion
|
||||
var c CompletionResponse
|
||||
if err := json.Unmarshal(evt, &c); err != nil {
|
||||
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
|
||||
}
|
||||
@@ -853,20 +816,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
})
|
||||
}
|
||||
|
||||
if c.Stop {
|
||||
doneReason := "stop"
|
||||
if c.StoppedLimit {
|
||||
doneReason = "length"
|
||||
}
|
||||
|
||||
fn(CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: doneReason,
|
||||
PromptEvalCount: c.Timings.PromptN,
|
||||
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
|
||||
EvalCount: c.Timings.PredictedN,
|
||||
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
|
||||
})
|
||||
if c.Done {
|
||||
fn(c)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -914,7 +865,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if status != ServerStatusReady {
|
||||
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||
return nil, fmt.Errorf("unexpected server status: %s", status)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(EmbeddingRequest{Content: input})
|
||||
@@ -1059,12 +1010,3 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func parseDurationMs(ms float64) time.Duration {
|
||||
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return dur
|
||||
}
|
||||
|
@@ -312,17 +312,19 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
return fmt.Errorf("unassigned tensor: %s", t.Name)
|
||||
}
|
||||
|
||||
bts := make([]byte, t.Size())
|
||||
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
|
||||
if err != nil {
|
||||
return err
|
||||
bts := C.malloc(C.size_t(t.Size()))
|
||||
if bts == nil {
|
||||
return errors.New("failed to allocate tensor buffer")
|
||||
}
|
||||
defer C.free(bts)
|
||||
|
||||
buf := unsafe.Slice((*byte)(bts), t.Size())
|
||||
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf)
|
||||
if err != nil || n != len(buf) {
|
||||
return errors.New("read failed")
|
||||
}
|
||||
|
||||
if n != len(bts) {
|
||||
return errors.New("short read")
|
||||
}
|
||||
|
||||
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
|
||||
C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size()))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -371,7 +373,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
|
||||
C.int(len(schedBackends)),
|
||||
C.size_t(maxGraphNodes),
|
||||
true,
|
||||
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
|
||||
),
|
||||
input: deviceBufferTypes[input.d],
|
||||
output: deviceBufferTypes[output.d],
|
||||
|
@@ -15,6 +15,12 @@ type Input struct {
|
||||
// stored in Multimodal, used for caching and comparing
|
||||
// equality.
|
||||
MultimodalHash uint64
|
||||
|
||||
// SameBatch forces the following number of tokens to be processed
|
||||
// in a single batch, breaking and extending batches as needed.
|
||||
// Useful for things like images that must be processed in one
|
||||
// shot.
|
||||
SameBatch int
|
||||
}
|
||||
|
||||
// MultimodalIndex is a multimodal element (such as an image)
|
||||
|
@@ -60,7 +60,7 @@ type MultimodalProcessor interface {
|
||||
// This function is also responsible for updating MultimodalHash for any Multimodal
|
||||
// that is modified to ensure that there is a unique hash value that accurately
|
||||
// represents the contents.
|
||||
PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
|
||||
PostTokenize([]input.Input) ([]input.Input, error)
|
||||
}
|
||||
|
||||
// Base implements the common fields and methods for all models
|
||||
|
@@ -2,10 +2,9 @@ package gemma3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"hash/fnv"
|
||||
"image"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
@@ -112,36 +111,23 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
||||
return visionOutputs, nil
|
||||
}
|
||||
|
||||
type imageToken struct {
|
||||
embedding ml.Tensor
|
||||
index int
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var result []input.Input
|
||||
fnvHash := fnv.New64a()
|
||||
|
||||
for _, inp := range inputs {
|
||||
if inp.Multimodal == nil {
|
||||
result = append(result, inp)
|
||||
} else {
|
||||
imageInputs := []input.Input{
|
||||
{Token: 108}, // "\n\n"
|
||||
{Token: 255999}, // "<start_of_image>""
|
||||
}
|
||||
result = append(result, imageInputs...)
|
||||
|
||||
// add image embeddings
|
||||
inputMultimodal := inp.Multimodal.(ml.Tensor)
|
||||
|
||||
for i := range inputMultimodal.Dim(1) {
|
||||
fnvHash.Reset()
|
||||
binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
|
||||
fnvHash.Write([]byte{byte(i)})
|
||||
result = append(result,
|
||||
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
|
||||
input.Input{Token: 255999}, // "<start_of_image>""
|
||||
input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
|
||||
)
|
||||
|
||||
imageToken := imageToken{embedding: inputMultimodal, index: i}
|
||||
result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
|
||||
}
|
||||
// add image token placeholders
|
||||
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
|
||||
|
||||
result = append(result,
|
||||
input.Input{Token: 256000}, // <end_of_image>
|
||||
|
@@ -171,53 +171,20 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
|
||||
var embedding ml.Tensor
|
||||
var src, dst, length int
|
||||
var except []int
|
||||
|
||||
for _, image := range multimodal {
|
||||
imageToken := image.Multimodal.(imageToken)
|
||||
imageSrc := imageToken.index
|
||||
imageDst := image.Index
|
||||
|
||||
if embedding == nil {
|
||||
embedding = imageToken.embedding
|
||||
src = imageSrc
|
||||
dst = imageDst
|
||||
length = 1
|
||||
} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
|
||||
src = imageSrc
|
||||
dst = imageDst
|
||||
length++
|
||||
} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
|
||||
length++
|
||||
} else {
|
||||
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
|
||||
|
||||
embedding = imageToken.embedding
|
||||
src = imageSrc
|
||||
dst = imageDst
|
||||
length = 1
|
||||
}
|
||||
|
||||
except = append(except, imageDst)
|
||||
}
|
||||
|
||||
if embedding != nil {
|
||||
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
|
||||
}
|
||||
|
||||
return except
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
|
||||
|
||||
except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
|
||||
// set image embeddings
|
||||
var except []int
|
||||
for _, image := range opts.Multimodal {
|
||||
visionOutputs := image.Multimodal.(ml.Tensor)
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
||||
|
||||
for i := range visionOutputs.Dim(1) {
|
||||
except = append(except, image.Index+i)
|
||||
}
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
// gemma alternates between the sliding window (local) and causal (global)
|
||||
|
@@ -106,17 +106,17 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
||||
return m.Projector.Forward(ctx, crossAttentionStates), nil
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var images []input.Input
|
||||
fnvHash := fnv.New64a()
|
||||
|
||||
for i := range inputs {
|
||||
if inputs[i].Multimodal == nil {
|
||||
if len(images) > 0 {
|
||||
inputs[i].Multimodal = images[0].Multimodal
|
||||
inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)}
|
||||
inputs[i].MultimodalHash = images[0].MultimodalHash
|
||||
for j := 1; j < len(images); j++ {
|
||||
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
|
||||
inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[0].Multimodal.(ml.Tensor))
|
||||
fnvHash.Reset()
|
||||
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
|
||||
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
|
||||
@@ -138,7 +138,10 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
|
||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
||||
var crossAttentionStates ml.Tensor
|
||||
if len(opts.Multimodal) > 0 {
|
||||
crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
|
||||
images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
|
||||
if len(images) > 0 {
|
||||
crossAttentionStates = images[len(images)-1]
|
||||
}
|
||||
}
|
||||
|
||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
||||
|
@@ -24,6 +24,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/runner/common"
|
||||
)
|
||||
|
||||
@@ -99,7 +100,7 @@ type NewSequenceParams struct {
|
||||
embedding bool
|
||||
}
|
||||
|
||||
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
s.ready.Wait()
|
||||
|
||||
startTime := time.Now()
|
||||
@@ -163,7 +164,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
||||
// inputs processes the prompt and images into a list of inputs
|
||||
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
||||
// generating image embeddings for each image
|
||||
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
|
||||
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) {
|
||||
var inputs []input
|
||||
var parts []string
|
||||
var matches [][]string
|
||||
@@ -229,7 +230,7 @@ type Server struct {
|
||||
image *ImageContext
|
||||
|
||||
// status for external health reporting - loading, ready to serve, etc.
|
||||
status ServerStatus
|
||||
status llm.ServerStatus
|
||||
|
||||
// current progress on loading the model
|
||||
progress float32
|
||||
@@ -541,75 +542,18 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO (jmorganca): use structs from the api package to avoid duplication
|
||||
// this way the api acts as a proxy instead of using a different api for the
|
||||
// runner
|
||||
type Options struct {
|
||||
api.Runner
|
||||
|
||||
NumKeep int `json:"n_keep"`
|
||||
Seed int `json:"seed"`
|
||||
NumPredict int `json:"n_predict"`
|
||||
TopK int `json:"top_k"`
|
||||
TopP float32 `json:"top_p"`
|
||||
MinP float32 `json:"min_p"`
|
||||
TypicalP float32 `json:"typical_p"`
|
||||
RepeatLastN int `json:"repeat_last_n"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
RepeatPenalty float32 `json:"repeat_penalty"`
|
||||
PresencePenalty float32 `json:"presence_penalty"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty"`
|
||||
Mirostat int `json:"mirostat"`
|
||||
MirostatTau float32 `json:"mirostat_tau"`
|
||||
MirostatEta float32 `json:"mirostat_eta"`
|
||||
Stop []string `json:"stop"`
|
||||
}
|
||||
|
||||
type ImageData struct {
|
||||
Data []byte `json:"data"`
|
||||
ID int `json:"id"`
|
||||
AspectRatioID int `json:"aspect_ratio_id"`
|
||||
}
|
||||
|
||||
type CompletionRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Images []ImageData `json:"image_data"`
|
||||
Grammar string `json:"grammar"`
|
||||
CachePrompt bool `json:"cache_prompt"`
|
||||
|
||||
Options
|
||||
}
|
||||
|
||||
type Timings struct {
|
||||
PredictedN int `json:"predicted_n"`
|
||||
PredictedMS float64 `json:"predicted_ms"`
|
||||
PromptN int `json:"prompt_n"`
|
||||
PromptMS float64 `json:"prompt_ms"`
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string `json:"content"`
|
||||
Stop bool `json:"stop"`
|
||||
|
||||
Model string `json:"model,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
StoppedLimit bool `json:"stopped_limit,omitempty"`
|
||||
PredictedN int `json:"predicted_n,omitempty"`
|
||||
PredictedMS float64 `json:"predicted_ms,omitempty"`
|
||||
PromptN int `json:"prompt_n,omitempty"`
|
||||
PromptMS float64 `json:"prompt_ms,omitempty"`
|
||||
|
||||
Timings Timings `json:"timings"`
|
||||
}
|
||||
|
||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
var req CompletionRequest
|
||||
req.Options = Options(api.DefaultOptions())
|
||||
var req llm.CompletionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Options == nil {
|
||||
opts := api.DefaultOptions()
|
||||
req.Options = &opts
|
||||
}
|
||||
|
||||
// Set the headers to indicate streaming
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
@@ -620,26 +564,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
var samplingParams llama.SamplingParams
|
||||
samplingParams.TopK = req.TopK
|
||||
samplingParams.TopP = req.TopP
|
||||
samplingParams.MinP = req.MinP
|
||||
samplingParams.TypicalP = req.TypicalP
|
||||
samplingParams.Temp = req.Temperature
|
||||
samplingParams.RepeatLastN = req.RepeatLastN
|
||||
samplingParams.PenaltyRepeat = req.RepeatPenalty
|
||||
samplingParams.PenaltyFreq = req.FrequencyPenalty
|
||||
samplingParams.PenaltyPresent = req.PresencePenalty
|
||||
samplingParams.Mirostat = req.Mirostat
|
||||
samplingParams.MirostatTau = req.MirostatTau
|
||||
samplingParams.MirostatEta = req.MirostatEta
|
||||
samplingParams.Seed = uint32(req.Seed)
|
||||
samplingParams.Grammar = req.Grammar
|
||||
// Extract options from the CompletionRequest
|
||||
samplingParams := llama.SamplingParams{
|
||||
TopK: req.Options.TopK,
|
||||
TopP: req.Options.TopP,
|
||||
MinP: req.Options.MinP,
|
||||
TypicalP: req.Options.TypicalP,
|
||||
Temp: req.Options.Temperature,
|
||||
RepeatLastN: req.Options.RepeatLastN,
|
||||
PenaltyRepeat: req.Options.RepeatPenalty,
|
||||
PenaltyFreq: req.Options.FrequencyPenalty,
|
||||
PenaltyPresent: req.Options.PresencePenalty,
|
||||
Mirostat: req.Options.Mirostat,
|
||||
MirostatTau: req.Options.MirostatTau,
|
||||
MirostatEta: req.Options.MirostatEta,
|
||||
Seed: uint32(req.Options.Seed),
|
||||
Grammar: req.Grammar,
|
||||
}
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
numPredict: req.NumPredict,
|
||||
stop: req.Stop,
|
||||
numKeep: req.NumKeep,
|
||||
numPredict: req.Options.NumPredict,
|
||||
stop: req.Options.Stop,
|
||||
numKeep: req.Options.NumKeep,
|
||||
samplingParams: &samplingParams,
|
||||
embedding: false,
|
||||
})
|
||||
@@ -662,7 +608,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
found := false
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
@@ -691,7 +637,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
case content, ok := <-seq.responses:
|
||||
if ok {
|
||||
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
@@ -702,15 +648,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
// Send the final response
|
||||
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||
Stop: true,
|
||||
StoppedLimit: seq.doneReason == "limit",
|
||||
Timings: Timings{
|
||||
PromptN: seq.numPromptInputs,
|
||||
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
|
||||
PredictedN: seq.numDecoded,
|
||||
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
|
||||
},
|
||||
doneReason := "stop"
|
||||
if seq.doneReason == "limit" {
|
||||
doneReason = "length"
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: doneReason,
|
||||
PromptEvalCount: seq.numPromptInputs,
|
||||
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
|
||||
EvalCount: seq.numDecoded,
|
||||
EvalDuration: time.Since(seq.startGenerationTime),
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
@@ -721,17 +669,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Content string `json:"content"`
|
||||
CachePrompt bool `json:"cache_prompt"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
}
|
||||
|
||||
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
var req EmbeddingRequest
|
||||
var req llm.EmbeddingRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
|
||||
return
|
||||
@@ -761,7 +700,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
found := false
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
@@ -782,41 +721,17 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
embedding := <-seq.embedding
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
|
||||
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
|
||||
Embedding: embedding,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
Progress float32 `json:"progress"`
|
||||
}
|
||||
|
||||
type ServerStatus int
|
||||
|
||||
const (
|
||||
ServerStatusReady ServerStatus = iota
|
||||
ServerStatusLoadingModel
|
||||
ServerStatusError
|
||||
)
|
||||
|
||||
func (s ServerStatus) ToString() string {
|
||||
switch s {
|
||||
case ServerStatusReady:
|
||||
return "ok"
|
||||
case ServerStatusLoadingModel:
|
||||
return "loading model"
|
||||
default:
|
||||
return "server error"
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(&HealthResponse{
|
||||
Status: s.status.ToString(),
|
||||
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
|
||||
Status: s.status,
|
||||
Progress: s.progress,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
@@ -879,7 +794,7 @@ func (s *Server) loadModel(
|
||||
panic(err)
|
||||
}
|
||||
|
||||
s.status = ServerStatusReady
|
||||
s.status = llm.ServerStatusReady
|
||||
s.ready.Done()
|
||||
}
|
||||
|
||||
@@ -937,7 +852,7 @@ func Execute(args []string) error {
|
||||
parallel: *parallel,
|
||||
seqs: make([]*Sequence, *parallel),
|
||||
seqsSem: semaphore.NewWeighted(int64(*parallel)),
|
||||
status: ServerStatusLoadingModel,
|
||||
status: llm.ServerStatusLoadingModel,
|
||||
}
|
||||
|
||||
var tensorSplitFloats []float32
|
||||
|
@@ -89,7 +89,7 @@ type InputCacheSlot struct {
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
|
||||
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
|
||||
var slot *InputCacheSlot
|
||||
var numPast int32
|
||||
var err error
|
||||
@@ -107,10 +107,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if !cachePrompt {
|
||||
numPast = 0
|
||||
}
|
||||
|
||||
slot.InUse = true
|
||||
slot.lastUsed = time.Now()
|
||||
|
||||
|
@@ -297,3 +297,131 @@ func TestShiftDiscard(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCacheSlot(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cache InputCache
|
||||
prompt []input.Input
|
||||
wantErr bool
|
||||
expectedSlotId int
|
||||
expectedPrompt int // expected length of remaining prompt
|
||||
}{
|
||||
{
|
||||
name: "Basic cache hit - single user",
|
||||
cache: InputCache{
|
||||
multiUserCache: false,
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
wantErr: false,
|
||||
expectedSlotId: 0,
|
||||
expectedPrompt: 1, // Only token 3 remains
|
||||
},
|
||||
{
|
||||
name: "Basic cache hit - multi user",
|
||||
cache: InputCache{
|
||||
multiUserCache: true,
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
wantErr: false,
|
||||
expectedSlotId: 0,
|
||||
expectedPrompt: 1, // Only token 3 remains
|
||||
},
|
||||
{
|
||||
name: "Exact match - leave one input",
|
||||
cache: InputCache{
|
||||
multiUserCache: false,
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
||||
wantErr: false,
|
||||
expectedSlotId: 0,
|
||||
expectedPrompt: 1, // Should leave 1 token for sampling
|
||||
},
|
||||
{
|
||||
name: "No available slots",
|
||||
cache: InputCache{
|
||||
multiUserCache: false,
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: true,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
wantErr: true,
|
||||
expectedSlotId: -1,
|
||||
expectedPrompt: -1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
|
||||
|
||||
// Check error state
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
return // Skip further checks if we expected an error
|
||||
}
|
||||
|
||||
// Verify slot ID
|
||||
if slot.Id != tt.expectedSlotId {
|
||||
t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
|
||||
}
|
||||
|
||||
// Verify slot is now marked in use
|
||||
if !slot.InUse {
|
||||
t.Errorf("LoadCacheSlot() slot not marked InUse")
|
||||
}
|
||||
|
||||
// Verify remaining prompt length
|
||||
if len(remainingPrompt) != tt.expectedPrompt {
|
||||
t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
|
||||
len(remainingPrompt), tt.expectedPrompt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -24,6 +24,7 @@ import (
|
||||
"golang.org/x/sync/semaphore"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
@@ -33,10 +34,14 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models"
|
||||
)
|
||||
|
||||
type contextList struct {
|
||||
list []ml.Context
|
||||
}
|
||||
|
||||
type Sequence struct {
|
||||
// ctx for allocating tensors that last the lifetime of the sequence, such as
|
||||
// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
|
||||
// multimodal embeddings
|
||||
ctx ml.Context
|
||||
ctxs *contextList
|
||||
|
||||
// batch index
|
||||
iBatch int
|
||||
@@ -94,13 +99,12 @@ type NewSequenceParams struct {
|
||||
embedding bool
|
||||
}
|
||||
|
||||
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
s.ready.Wait()
|
||||
|
||||
startTime := time.Now()
|
||||
ctx := s.model.Backend().NewContext()
|
||||
|
||||
inputs, err := s.inputs(ctx, prompt, images)
|
||||
inputs, ctxs, err := s.inputs(prompt, images)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process inputs: %w", err)
|
||||
} else if len(inputs) == 0 {
|
||||
@@ -111,6 +115,9 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
||||
params.numKeep = int32(len(inputs))
|
||||
}
|
||||
|
||||
// TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
|
||||
// otherwise we might truncate or split the batch against the model's wishes
|
||||
|
||||
// Ensure that at least 1 input can be discarded during shift
|
||||
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||
|
||||
@@ -126,7 +133,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
||||
// TODO(jessegross): Ingest cached history for grammar
|
||||
|
||||
return &Sequence{
|
||||
ctx: ctx,
|
||||
ctxs: ctxs,
|
||||
inputs: inputs,
|
||||
numPromptInputs: len(inputs),
|
||||
startProcessingTime: startTime,
|
||||
@@ -145,7 +152,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
||||
// inputs processes the prompt and images into a list of inputs
|
||||
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
||||
// decoding images
|
||||
func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
|
||||
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) {
|
||||
var inputs []input.Input
|
||||
var parts []string
|
||||
var matches [][]string
|
||||
@@ -160,12 +167,19 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in
|
||||
parts = []string{prompt}
|
||||
}
|
||||
|
||||
var contexts contextList
|
||||
runtime.AddCleanup(&contexts, func(ctxs []ml.Context) {
|
||||
for _, ctx := range ctxs {
|
||||
ctx.Close()
|
||||
}
|
||||
}, contexts.list)
|
||||
|
||||
postTokenize := false
|
||||
for i, part := range parts {
|
||||
// text - tokenize
|
||||
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for _, t := range tokens {
|
||||
@@ -185,12 +199,14 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in
|
||||
}
|
||||
|
||||
if imageIndex < 0 {
|
||||
return nil, fmt.Errorf("invalid image index: %d", n)
|
||||
return nil, nil, fmt.Errorf("invalid image index: %d", n)
|
||||
}
|
||||
|
||||
ctx := s.model.Backend().NewContext()
|
||||
contexts.list = append(contexts.list, ctx)
|
||||
imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
s.multimodalHash.Reset()
|
||||
@@ -204,13 +220,13 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in
|
||||
|
||||
if visionModel && postTokenize {
|
||||
var err error
|
||||
inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
|
||||
inputs, err = multimodalProcessor.PostTokenize(inputs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return inputs, nil
|
||||
return inputs, &contexts, nil
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
@@ -222,7 +238,7 @@ type Server struct {
|
||||
model model.Model
|
||||
|
||||
// status for external health reporting - loading, ready to serve, etc.
|
||||
status ServerStatus
|
||||
status llm.ServerStatus
|
||||
|
||||
// current progress on loading the model
|
||||
progress float32
|
||||
@@ -305,7 +321,6 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
|
||||
close(seq.responses)
|
||||
close(seq.embedding)
|
||||
seq.cache.InUse = false
|
||||
seq.ctx.Close()
|
||||
s.seqs[seqIndex] = nil
|
||||
s.seqsSem.Release(1)
|
||||
}
|
||||
@@ -351,22 +366,35 @@ func (s *Server) processBatch() error {
|
||||
seq.cache.Inputs = []input.Input{}
|
||||
}
|
||||
|
||||
batchSize := s.batchSize
|
||||
|
||||
for j, inp := range seq.inputs {
|
||||
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
|
||||
if len(seq.pendingInputs) == 0 {
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
// If we are required to put following inputs into a single batch then extend the
|
||||
// batch size. Since we are only extending the size the minimum amount possible, this
|
||||
// will cause a break if we have pending inputs.
|
||||
minBatch := 1 + inp.SameBatch
|
||||
if minBatch > batchSize {
|
||||
batchSize = minBatch
|
||||
}
|
||||
|
||||
if j >= s.batchSize {
|
||||
if len(seq.pendingInputs)+minBatch > batchSize {
|
||||
break
|
||||
}
|
||||
|
||||
// If the sum of our working set (already processed tokens, tokens we added to this
|
||||
// batch, required following tokens) exceeds the context size, then trigger a shift
|
||||
// now so we don't have to do one later when we can't break the batch.
|
||||
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
|
||||
if len(seq.pendingInputs) != 0 {
|
||||
break
|
||||
}
|
||||
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
options.Inputs = append(options.Inputs, inp.Token)
|
||||
if inp.Multimodal != nil {
|
||||
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
|
||||
@@ -501,75 +529,18 @@ func (s *Server) processBatch() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO (jmorganca): use structs from the api package to avoid duplication
|
||||
// this way the api acts as a proxy instead of using a different api for the
|
||||
// runner
|
||||
type Options struct {
|
||||
api.Runner
|
||||
|
||||
NumKeep int `json:"n_keep"`
|
||||
Seed int `json:"seed"`
|
||||
NumPredict int `json:"n_predict"`
|
||||
TopK int `json:"top_k"`
|
||||
TopP float32 `json:"top_p"`
|
||||
MinP float32 `json:"min_p"`
|
||||
TypicalP float32 `json:"typical_p"`
|
||||
RepeatLastN int `json:"repeat_last_n"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
RepeatPenalty float32 `json:"repeat_penalty"`
|
||||
PresencePenalty float32 `json:"presence_penalty"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty"`
|
||||
Mirostat int `json:"mirostat"`
|
||||
MirostatTau float32 `json:"mirostat_tau"`
|
||||
MirostatEta float32 `json:"mirostat_eta"`
|
||||
Stop []string `json:"stop"`
|
||||
}
|
||||
|
||||
type ImageData struct {
|
||||
Data []byte `json:"data"`
|
||||
ID int `json:"id"`
|
||||
AspectRatioID int `json:"aspect_ratio_id"`
|
||||
}
|
||||
|
||||
type CompletionRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Images []ImageData `json:"image_data"`
|
||||
Grammar string `json:"grammar"`
|
||||
CachePrompt bool `json:"cache_prompt"`
|
||||
|
||||
Options
|
||||
}
|
||||
|
||||
type Timings struct {
|
||||
PredictedN int `json:"predicted_n"`
|
||||
PredictedMS float64 `json:"predicted_ms"`
|
||||
PromptN int `json:"prompt_n"`
|
||||
PromptMS float64 `json:"prompt_ms"`
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string `json:"content"`
|
||||
Stop bool `json:"stop"`
|
||||
|
||||
Model string `json:"model,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
StoppedLimit bool `json:"stopped_limit,omitempty"`
|
||||
PredictedN int `json:"predicted_n,omitempty"`
|
||||
PredictedMS float64 `json:"predicted_ms,omitempty"`
|
||||
PromptN int `json:"prompt_n,omitempty"`
|
||||
PromptMS float64 `json:"prompt_ms,omitempty"`
|
||||
|
||||
Timings Timings `json:"timings"`
|
||||
}
|
||||
|
||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
var req CompletionRequest
|
||||
req.Options = Options(api.DefaultOptions())
|
||||
var req llm.CompletionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Options == nil {
|
||||
opts := api.DefaultOptions()
|
||||
req.Options = &opts
|
||||
}
|
||||
|
||||
// Set the headers to indicate streaming
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
@@ -591,18 +562,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
sampler := sample.NewSampler(
|
||||
req.Temperature,
|
||||
req.TopK,
|
||||
req.TopP,
|
||||
req.MinP,
|
||||
req.Seed,
|
||||
req.Options.Temperature,
|
||||
req.Options.TopK,
|
||||
req.Options.TopP,
|
||||
req.Options.MinP,
|
||||
req.Options.Seed,
|
||||
grammar,
|
||||
)
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
numPredict: req.NumPredict,
|
||||
stop: req.Stop,
|
||||
numKeep: int32(req.NumKeep),
|
||||
numPredict: req.Options.NumPredict,
|
||||
stop: req.Options.Stop,
|
||||
numKeep: int32(req.Options.NumKeep),
|
||||
sampler: sampler,
|
||||
embedding: false,
|
||||
})
|
||||
@@ -625,7 +596,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
found := false
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
@@ -652,7 +623,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
case content, ok := <-seq.responses:
|
||||
if ok {
|
||||
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
@@ -663,15 +634,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
// Send the final response
|
||||
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||
Stop: true,
|
||||
StoppedLimit: seq.doneReason == "limit",
|
||||
Timings: Timings{
|
||||
PromptN: seq.numPromptInputs,
|
||||
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
|
||||
PredictedN: seq.numPredicted,
|
||||
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
|
||||
},
|
||||
doneReason := "stop"
|
||||
if seq.doneReason == "limit" {
|
||||
doneReason = "length"
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: doneReason,
|
||||
PromptEvalCount: seq.numPromptInputs,
|
||||
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
|
||||
EvalCount: seq.numPredicted,
|
||||
EvalDuration: time.Since(seq.startGenerationTime),
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
@@ -682,43 +655,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Content string `json:"content"`
|
||||
CachePrompt bool `json:"cache_prompt"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
}
|
||||
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
Progress float32 `json:"progress"`
|
||||
}
|
||||
|
||||
type ServerStatus int
|
||||
|
||||
const (
|
||||
ServerStatusReady ServerStatus = iota
|
||||
ServerStatusLoadingModel
|
||||
ServerStatusError
|
||||
)
|
||||
|
||||
func (s ServerStatus) ToString() string {
|
||||
switch s {
|
||||
case ServerStatusReady:
|
||||
return "ok"
|
||||
case ServerStatusLoadingModel:
|
||||
return "loading model"
|
||||
default:
|
||||
return "server error"
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(&HealthResponse{
|
||||
Status: s.status.ToString(),
|
||||
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
|
||||
Status: s.status,
|
||||
Progress: s.progress,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
@@ -772,7 +712,7 @@ func (s *Server) loadModel(
|
||||
s.seqs = make([]*Sequence, s.parallel)
|
||||
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
|
||||
|
||||
s.status = ServerStatusReady
|
||||
s.status = llm.ServerStatusReady
|
||||
s.ready.Done()
|
||||
}
|
||||
|
||||
@@ -824,7 +764,7 @@ func Execute(args []string) error {
|
||||
|
||||
server := &Server{
|
||||
batchSize: *batchSize,
|
||||
status: ServerStatusLoadingModel,
|
||||
status: llm.ServerStatusLoadingModel,
|
||||
}
|
||||
|
||||
// TODO(jessegross): Parameters that need to be implemented:
|
||||
|
@@ -87,8 +87,9 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
||||
// topK also sorts the tokens in descending order of logits
|
||||
tokens = topK(tokens, s.topK)
|
||||
|
||||
tokens = temperature(tokens, s.temperature)
|
||||
tokens = softmax(tokens)
|
||||
// scale and normalize the tokens in place
|
||||
temperature(tokens, s.temperature)
|
||||
softmax(tokens)
|
||||
|
||||
tokens = topP(tokens, s.topP)
|
||||
tokens = minP(tokens, s.minP)
|
||||
|
@@ -26,17 +26,16 @@ func (h *tokenHeap) Pop() any {
|
||||
}
|
||||
|
||||
// temperature applies scaling to the logits
|
||||
func temperature(ts []token, temp float32) []token {
|
||||
func temperature(ts []token, temp float32) {
|
||||
// Ensure temperature clipping near 0 to avoid numerical instability
|
||||
temp = max(temp, 1e-7)
|
||||
for i := range ts {
|
||||
ts[i].value = ts[i].value / temp
|
||||
}
|
||||
return ts
|
||||
}
|
||||
|
||||
// softmax applies normalization to the logits
|
||||
func softmax(ts []token) []token {
|
||||
func softmax(ts []token) {
|
||||
// Find max logit for numerical stability
|
||||
maxLogit := float32(math.Inf(-1))
|
||||
for _, t := range ts {
|
||||
@@ -56,8 +55,6 @@ func softmax(ts []token) []token {
|
||||
for i := range ts {
|
||||
ts[i].value /= sum
|
||||
}
|
||||
|
||||
return ts
|
||||
}
|
||||
|
||||
// topK limits the number of tokens considered to the k highest logits
|
||||
@@ -99,6 +96,7 @@ func topK(ts []token, k int) []token {
|
||||
}
|
||||
|
||||
// topP limits tokens to those with cumulative probability p
|
||||
// requires ts to be sorted in descending order of probabilities
|
||||
func topP(ts []token, p float32) []token {
|
||||
if p == 1.0 {
|
||||
return ts
|
||||
@@ -109,37 +107,24 @@ func topP(ts []token, p float32) []token {
|
||||
for i, t := range ts {
|
||||
sum += t.value
|
||||
if sum > float32(p) {
|
||||
ts = ts[:i+1]
|
||||
return ts
|
||||
return ts[:i+1]
|
||||
}
|
||||
}
|
||||
|
||||
return ts
|
||||
}
|
||||
|
||||
// minP limits tokens to those with cumulative probability p
|
||||
// minP filters tokens with probabilities >= p * max_prob
|
||||
// requires ts to be sorted in descending order of probabilities
|
||||
func minP(ts []token, p float32) []token {
|
||||
if p == 1.0 {
|
||||
return ts
|
||||
}
|
||||
maxProb := ts[0].value
|
||||
|
||||
maxProb := float32(math.Inf(-1))
|
||||
for _, token := range ts {
|
||||
if token.value > maxProb {
|
||||
maxProb = token.value
|
||||
threshold := maxProb * p
|
||||
|
||||
for i, t := range ts {
|
||||
if t.value < threshold {
|
||||
return ts[:i]
|
||||
}
|
||||
}
|
||||
|
||||
threshold := maxProb * float32(p)
|
||||
|
||||
// Filter tokens in-place
|
||||
validTokens := ts[:0]
|
||||
for i, token := range ts {
|
||||
if token.value >= threshold {
|
||||
validTokens = append(validTokens, ts[i])
|
||||
}
|
||||
}
|
||||
|
||||
ts = validTokens
|
||||
return ts
|
||||
}
|
||||
|
@@ -34,17 +34,22 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {
|
||||
|
||||
func TestTemperature(t *testing.T) {
|
||||
input := []float32{1.0, 4.0, -2.0, 0.0}
|
||||
got := temperature(toTokens(input), 0.5)
|
||||
tokens := toTokens(input)
|
||||
temperature(tokens, 0.5)
|
||||
want := []float32{2.0, 8.0, -4.0, 0.0}
|
||||
compareLogits(t, "temperature(0.5)", want, got)
|
||||
compareLogits(t, "temperature(0.5)", want, tokens)
|
||||
|
||||
got = temperature(toTokens(input), 1.0)
|
||||
input = []float32{1.0, 4.0, -2.0, 0.0}
|
||||
tokens = toTokens(input)
|
||||
temperature(tokens, 1.0)
|
||||
want = []float32{1.0, 4.0, -2.0, 0.0}
|
||||
compareLogits(t, "temperature(1)", want, got)
|
||||
compareLogits(t, "temperature(1)", want, tokens)
|
||||
|
||||
got = temperature(toTokens(input), 0.0)
|
||||
input = []float32{1.0, 4.0, -2.0, 0.0}
|
||||
tokens = toTokens(input)
|
||||
temperature(tokens, 0.0)
|
||||
want = []float32{1e7, 4e7, -2e7, 0.0}
|
||||
compareLogits(t, "temperature(0)", want, got)
|
||||
compareLogits(t, "temperature(0)", want, tokens)
|
||||
}
|
||||
|
||||
func TestSoftmax(t *testing.T) {
|
||||
@@ -90,16 +95,17 @@ func TestSoftmax(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := softmax(toTokens(tt.input))
|
||||
tokens := toTokens(tt.input)
|
||||
softmax(tokens)
|
||||
|
||||
if tt.expected != nil {
|
||||
compareLogits(t, tt.name, tt.expected, got)
|
||||
compareLogits(t, tt.name, tt.expected, tokens)
|
||||
return
|
||||
}
|
||||
|
||||
// Check probabilities sum to 1
|
||||
var sum float32
|
||||
for _, token := range got {
|
||||
for _, token := range tokens {
|
||||
sum += token.value
|
||||
if token.value < 0 || token.value > 1 {
|
||||
t.Errorf("probability out of range [0,1]: got %f", token.value)
|
||||
@@ -114,38 +120,44 @@ func TestSoftmax(t *testing.T) {
|
||||
|
||||
func TestTopK(t *testing.T) {
|
||||
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
||||
|
||||
// Test k=5
|
||||
got := topK(toTokens(input), 5)
|
||||
if len(got) != 5 {
|
||||
t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
|
||||
tokens := toTokens(input)
|
||||
tokens = topK(tokens, 5)
|
||||
if len(tokens) != 5 {
|
||||
t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens))
|
||||
}
|
||||
// Should keep highest 3 values in descending order
|
||||
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
|
||||
compareLogits(t, "topK(3)", want, got)
|
||||
compareLogits(t, "topK(3)", want, tokens)
|
||||
|
||||
got = topK(toTokens(input), 20)
|
||||
if len(got) != len(input) {
|
||||
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, 20)
|
||||
if len(tokens) != len(input) {
|
||||
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens))
|
||||
}
|
||||
|
||||
// Test k=-1
|
||||
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
||||
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
||||
got = topK(toTokens(input), -1)
|
||||
if len(got) != len(input) {
|
||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, -1)
|
||||
if len(tokens) != len(input) {
|
||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
|
||||
}
|
||||
compareLogits(t, "topK(-1)", want, got)
|
||||
compareLogits(t, "topK(-1)", want, tokens)
|
||||
|
||||
// Test k=0
|
||||
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
||||
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
||||
got = topK(toTokens(input), 0)
|
||||
if len(got) != len(input) {
|
||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, 0)
|
||||
if len(tokens) != len(input) {
|
||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
|
||||
}
|
||||
compareLogits(t, "topK(-1)", want, tokens)
|
||||
|
||||
input = []float32{-1e7, -2e7, -3e7, -4e7}
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, 1)
|
||||
if len(tokens) < 1 {
|
||||
t.Error("topK should keep at least one token")
|
||||
}
|
||||
compareLogits(t, "topK(-1)", want, got)
|
||||
}
|
||||
|
||||
func TestTopP(t *testing.T) {
|
||||
@@ -153,16 +165,25 @@ func TestTopP(t *testing.T) {
|
||||
tokens := toTokens(input)
|
||||
|
||||
// First apply temperature and softmax to get probabilities
|
||||
tokens = softmax(tokens)
|
||||
softmax(tokens)
|
||||
tokens = topK(tokens, 20)
|
||||
|
||||
// Then apply topP
|
||||
got := topP(tokens, 0.95)
|
||||
tokens = topP(tokens, 0.95)
|
||||
|
||||
// Should keep tokens until cumsum > 0.95
|
||||
if len(got) > 3 {
|
||||
t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
|
||||
t.Logf("got: %v", got)
|
||||
if len(tokens) > 3 {
|
||||
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
|
||||
t.Logf("got: %v", tokens)
|
||||
}
|
||||
|
||||
// Test edge case - ensure at least one token remains
|
||||
input = []float32{-1e6, -1e6, -1e6} // One dominant token
|
||||
tokens = toTokens(input)
|
||||
softmax(tokens)
|
||||
tokens = topP(tokens, 0.0) // Very small p
|
||||
if len(tokens) < 1 {
|
||||
t.Error("topP should keep at least one token")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,14 +192,45 @@ func TestMinP(t *testing.T) {
|
||||
tokens := toTokens(input)
|
||||
|
||||
// First apply temperature and softmax
|
||||
tokens = softmax(tokens)
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
|
||||
// Then apply minP
|
||||
got := minP(tokens, 0.2)
|
||||
tokens = minP(tokens, 1.0)
|
||||
|
||||
if len(tokens) != 1 {
|
||||
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens))
|
||||
}
|
||||
|
||||
// Test with normal p value
|
||||
tokens = toTokens(input) // Reset tokens
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
tokens = minP(tokens, 0.2)
|
||||
|
||||
// Should keep tokens with prob >= 0.2 * max_prob
|
||||
if len(got) > 3 {
|
||||
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
|
||||
if len(tokens) > 3 {
|
||||
t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
|
||||
t.Logf("got: %v", tokens)
|
||||
}
|
||||
|
||||
// Test with zero p value
|
||||
tokens = toTokens(input) // Reset tokens
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
tokens = minP(tokens, 0.0)
|
||||
|
||||
// Should keep only the highest probability token
|
||||
if len(tokens) != len(input) {
|
||||
t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens))
|
||||
t.Logf("got: %v", tokens)
|
||||
}
|
||||
|
||||
input = []float32{1e-10, 1e-10, 1e-10}
|
||||
tokens = toTokens(input)
|
||||
softmax(tokens)
|
||||
tokens = minP(tokens, 1.0)
|
||||
if len(tokens) < 1 {
|
||||
t.Error("minP should keep at least one token even with extreme probabilities")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -231,7 +283,7 @@ func BenchmarkTransforms(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
topK(tokensCopy, 10)
|
||||
tokens = topK(tokensCopy, 10)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -239,7 +291,7 @@ func BenchmarkTransforms(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
topP(tokensCopy, 0.9)
|
||||
tokens = topP(tokensCopy, 0.9)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -247,7 +299,7 @@ func BenchmarkTransforms(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
minP(tokensCopy, 0.2)
|
||||
tokens = minP(tokensCopy, 0.2)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -255,7 +307,7 @@ func BenchmarkTransforms(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
topK(tokensCopy, 200000)
|
||||
tokens = topK(tokensCopy, 200000)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@@ -8,7 +8,7 @@ usage() {
|
||||
exit 1
|
||||
}
|
||||
|
||||
export VERSION=${VERSION:-$(git describe --tags --dirty)}
|
||||
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
|
||||
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'"
|
||||
export CGO_CPPFLAGS='-mmacosx-version-min=11.3'
|
||||
|
||||
|
54
server/internal/cache/blob/cache.go
vendored
54
server/internal/cache/blob/cache.go
vendored
@@ -146,7 +146,7 @@ func debugger(err *error) func(step string) {
|
||||
// be in either of the following forms:
|
||||
//
|
||||
// @<digest>
|
||||
// <name>
|
||||
// <name>@<digest>
|
||||
// <name>
|
||||
//
|
||||
// If a digest is provided, it is returned as is and nothing else happens.
|
||||
@@ -160,8 +160,6 @@ func debugger(err *error) func(step string) {
|
||||
// hashed is passed to a PutBytes call to ensure that the manifest is in the
|
||||
// blob store. This is done to ensure that future calls to [Get] succeed in
|
||||
// these cases.
|
||||
//
|
||||
// TODO(bmizerany): Move Links/Resolve/etc. out of this package.
|
||||
func (c *DiskCache) Resolve(name string) (Digest, error) {
|
||||
name, digest := splitNameDigest(name)
|
||||
if digest != "" {
|
||||
@@ -279,18 +277,6 @@ func (c *DiskCache) Get(d Digest) (Entry, error) {
|
||||
// It returns an error if either the name or digest is invalid, or if link
|
||||
// creation encounters any issues.
|
||||
func (c *DiskCache) Link(name string, d Digest) error {
|
||||
// TODO(bmizerany): Move link handling from cache to registry.
|
||||
//
|
||||
// We originally placed links in the cache due to its storage
|
||||
// knowledge. However, the registry likely offers better context for
|
||||
// naming concerns, and our API design shouldn't be tightly coupled to
|
||||
// our on-disk format.
|
||||
//
|
||||
// Links work effectively when independent from physical location -
|
||||
// they can reference content with matching SHA regardless of storage
|
||||
// location. In an upcoming change, we plan to shift this
|
||||
// responsibility to the registry where it better aligns with the
|
||||
// system's conceptual model.
|
||||
manifest, err := c.manifestPath(name)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -341,7 +327,9 @@ func (c *DiskCache) GetFile(d Digest) string {
|
||||
return absJoin(c.dir, "blobs", filename)
|
||||
}
|
||||
|
||||
// Links returns a sequence of links in the cache in lexical order.
|
||||
// Links returns a sequence of link names. The sequence is in lexical order.
|
||||
// Names are converted from their relative path form to their name form but are
|
||||
// not guaranteed to be valid. Callers should validate the names before using.
|
||||
func (c *DiskCache) Links() iter.Seq2[string, error] {
|
||||
return func(yield func(string, error) bool) {
|
||||
for path, err := range c.links() {
|
||||
@@ -414,12 +402,14 @@ func (c *DiskCache) links() iter.Seq2[string, error] {
|
||||
}
|
||||
|
||||
type checkWriter struct {
|
||||
d Digest
|
||||
size int64
|
||||
n int64
|
||||
h hash.Hash
|
||||
d Digest
|
||||
f *os.File
|
||||
err error
|
||||
h hash.Hash
|
||||
|
||||
w io.Writer // underlying writer; set by creator
|
||||
n int64
|
||||
err error
|
||||
|
||||
testHookBeforeFinalWrite func(*os.File)
|
||||
}
|
||||
@@ -435,6 +425,10 @@ func (w *checkWriter) seterr(err error) error {
|
||||
// underlying writer is guaranteed to be the last byte of p as verified by the
|
||||
// hash.
|
||||
func (w *checkWriter) Write(p []byte) (int, error) {
|
||||
if w.err != nil {
|
||||
return 0, w.err
|
||||
}
|
||||
|
||||
_, err := w.h.Write(p)
|
||||
if err != nil {
|
||||
return 0, w.seterr(err)
|
||||
@@ -453,7 +447,7 @@ func (w *checkWriter) Write(p []byte) (int, error) {
|
||||
if nextSize > w.size {
|
||||
return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size))
|
||||
}
|
||||
n, err := w.f.Write(p)
|
||||
n, err := w.w.Write(p)
|
||||
w.n += int64(n)
|
||||
return n, w.seterr(err)
|
||||
}
|
||||
@@ -493,10 +487,12 @@ func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size
|
||||
|
||||
// Copy file to f, but also into h to double-check hash.
|
||||
cw := &checkWriter{
|
||||
d: out,
|
||||
size: size,
|
||||
h: sha256.New(),
|
||||
f: f,
|
||||
d: out,
|
||||
size: size,
|
||||
h: sha256.New(),
|
||||
f: f,
|
||||
w: f,
|
||||
|
||||
testHookBeforeFinalWrite: c.testHookBeforeFinalWrite,
|
||||
}
|
||||
n, err := io.Copy(cw, file)
|
||||
@@ -532,11 +528,6 @@ func splitNameDigest(s string) (name, digest string) {
|
||||
var errInvalidName = errors.New("invalid name")
|
||||
|
||||
func nameToPath(name string) (_ string, err error) {
|
||||
if strings.Contains(name, "@") {
|
||||
// TODO(bmizerany): HACK: Fix names.Parse to validate.
|
||||
// TODO(bmizerany): merge with default parts (maybe names.Merge(a, b))
|
||||
return "", errInvalidName
|
||||
}
|
||||
n := names.Parse(name)
|
||||
if !n.IsFullyQualified() {
|
||||
return "", errInvalidName
|
||||
@@ -547,8 +538,7 @@ func nameToPath(name string) (_ string, err error) {
|
||||
func absJoin(pp ...string) string {
|
||||
abs, err := filepath.Abs(filepath.Join(pp...))
|
||||
if err != nil {
|
||||
// Likely a bug bug or a bad OS problem. Just panic.
|
||||
panic(err)
|
||||
panic(err) // this should never happen
|
||||
}
|
||||
return abs
|
||||
}
|
||||
|
73
server/internal/cache/blob/chunked.go
vendored
Normal file
73
server/internal/cache/blob/chunked.go
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
package blob
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Chunk represents a range of bytes in a blob.
|
||||
type Chunk struct {
|
||||
Start int64
|
||||
End int64
|
||||
}
|
||||
|
||||
// Size returns end minus start plus one.
|
||||
func (c Chunk) Size() int64 {
|
||||
return c.End - c.Start + 1
|
||||
}
|
||||
|
||||
// Chunker writes to a blob in chunks.
|
||||
// Its zero value is invalid. Use [DiskCache.Chunked] to create a new Chunker.
|
||||
type Chunker struct {
|
||||
digest Digest
|
||||
size int64
|
||||
f *os.File // nil means pre-validated
|
||||
}
|
||||
|
||||
// Chunked returns a new Chunker, ready for use storing a blob of the given
|
||||
// size in chunks.
|
||||
//
|
||||
// Use [Chunker.Put] to write data to the blob at specific offsets.
|
||||
func (c *DiskCache) Chunked(d Digest, size int64) (*Chunker, error) {
|
||||
name := c.GetFile(d)
|
||||
info, err := os.Stat(name)
|
||||
if err == nil && info.Size() == size {
|
||||
return &Chunker{}, nil
|
||||
}
|
||||
f, err := os.OpenFile(name, os.O_CREATE|os.O_WRONLY, 0o666)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Chunker{digest: d, size: size, f: f}, nil
|
||||
}
|
||||
|
||||
// Put copies chunk.Size() bytes from r to the blob at the given offset,
|
||||
// merging the data with the existing blob. It returns an error if any. As a
|
||||
// special case, if r has less than chunk.Size() bytes, Put returns
|
||||
// io.ErrUnexpectedEOF.
|
||||
func (c *Chunker) Put(chunk Chunk, d Digest, r io.Reader) error {
|
||||
if c.f == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cw := &checkWriter{
|
||||
d: d,
|
||||
size: chunk.Size(),
|
||||
h: sha256.New(),
|
||||
f: c.f,
|
||||
w: io.NewOffsetWriter(c.f, chunk.Start),
|
||||
}
|
||||
|
||||
_, err := io.CopyN(cw, r, chunk.Size())
|
||||
if err != nil && errors.Is(err, io.EOF) {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the underlying file.
|
||||
func (c *Chunker) Close() error {
|
||||
return c.f.Close()
|
||||
}
|
4
server/internal/cache/blob/digest.go
vendored
4
server/internal/cache/blob/digest.go
vendored
@@ -63,6 +63,10 @@ func (d Digest) Short() string {
|
||||
return fmt.Sprintf("%x", d.sum[:4])
|
||||
}
|
||||
|
||||
func (d Digest) Sum() [32]byte {
|
||||
return d.sum
|
||||
}
|
||||
|
||||
func (d Digest) Compare(other Digest) int {
|
||||
return slices.Compare(d.sum[:], other.sum[:])
|
||||
}
|
||||
|
@@ -1,78 +0,0 @@
|
||||
package chunks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Chunk struct {
|
||||
Start, End int64
|
||||
}
|
||||
|
||||
func New(start, end int64) Chunk {
|
||||
return Chunk{start, end}
|
||||
}
|
||||
|
||||
// ParseRange parses a string in the form "unit=range" where unit is a string
|
||||
// and range is a string in the form "start-end". It returns the unit and the
|
||||
// range as a Chunk.
|
||||
func ParseRange(s string) (unit string, _ Chunk, _ error) {
|
||||
unit, r, _ := strings.Cut(s, "=")
|
||||
if r == "" {
|
||||
return unit, Chunk{}, nil
|
||||
}
|
||||
c, err := Parse(r)
|
||||
if err != nil {
|
||||
return "", Chunk{}, err
|
||||
}
|
||||
return unit, c, err
|
||||
}
|
||||
|
||||
// Parse parses a string in the form "start-end" and returns the Chunk.
|
||||
func Parse(s string) (Chunk, error) {
|
||||
startStr, endStr, _ := strings.Cut(s, "-")
|
||||
start, err := strconv.ParseInt(startStr, 10, 64)
|
||||
if err != nil {
|
||||
return Chunk{}, fmt.Errorf("invalid start: %v", err)
|
||||
}
|
||||
end, err := strconv.ParseInt(endStr, 10, 64)
|
||||
if err != nil {
|
||||
return Chunk{}, fmt.Errorf("invalid end: %v", err)
|
||||
}
|
||||
if start > end {
|
||||
return Chunk{}, fmt.Errorf("invalid range %d-%d: start > end", start, end)
|
||||
}
|
||||
return Chunk{start, end}, nil
|
||||
}
|
||||
|
||||
// Of returns a sequence of contiguous Chunks of size chunkSize that cover
|
||||
// the range [0, size), in order.
|
||||
func Of(size, chunkSize int64) iter.Seq[Chunk] {
|
||||
return func(yield func(Chunk) bool) {
|
||||
for start := int64(0); start < size; start += chunkSize {
|
||||
end := min(start+chunkSize-1, size-1)
|
||||
if !yield(Chunk{start, end}) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count returns the number of Chunks of size chunkSize needed to cover the
|
||||
// range [0, size).
|
||||
func Count(size, chunkSize int64) int64 {
|
||||
return (size + chunkSize - 1) / chunkSize
|
||||
}
|
||||
|
||||
// Size returns end minus start plus one.
|
||||
func (c Chunk) Size() int64 {
|
||||
return c.End - c.Start + 1
|
||||
}
|
||||
|
||||
// String returns the string representation of the Chunk in the form
|
||||
// "{start}-{end}".
|
||||
func (c Chunk) String() string {
|
||||
return fmt.Sprintf("%d-%d", c.Start, c.End)
|
||||
}
|
@@ -1,65 +0,0 @@
|
||||
package chunks
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOf(t *testing.T) {
|
||||
cases := []struct {
|
||||
total int64
|
||||
chunkSize int64
|
||||
want []Chunk
|
||||
}{
|
||||
{0, 1, nil},
|
||||
{1, 1, []Chunk{{0, 0}}},
|
||||
{1, 2, []Chunk{{0, 0}}},
|
||||
{2, 1, []Chunk{{0, 0}, {1, 1}}},
|
||||
{10, 9, []Chunk{{0, 8}, {9, 9}}},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
got := slices.Collect(Of(tt.total, tt.chunkSize))
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("[%d/%d]: got %v; want %v", tt.total, tt.chunkSize, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSize(t *testing.T) {
|
||||
cases := []struct {
|
||||
c Chunk
|
||||
want int64
|
||||
}{
|
||||
{Chunk{0, 0}, 1},
|
||||
{Chunk{0, 1}, 2},
|
||||
{Chunk{3, 4}, 2},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
got := tt.c.Size()
|
||||
if got != tt.want {
|
||||
t.Errorf("%v: got %d; want %d", tt.c, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCount(t *testing.T) {
|
||||
cases := []struct {
|
||||
total int64
|
||||
chunkSize int64
|
||||
want int64
|
||||
}{
|
||||
{0, 1, 0},
|
||||
{1, 1, 1},
|
||||
{1, 2, 1},
|
||||
{2, 1, 2},
|
||||
{10, 9, 2},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
got := Count(tt.total, tt.chunkSize)
|
||||
if got != tt.want {
|
||||
t.Errorf("[%d/%d]: got %d; want %d", tt.total, tt.chunkSize, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
@@ -19,11 +19,13 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -35,10 +37,8 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/chunks"
|
||||
"github.com/ollama/ollama/server/internal/internal/backoff"
|
||||
"github.com/ollama/ollama/server/internal/internal/names"
|
||||
"github.com/ollama/ollama/server/internal/internal/syncs"
|
||||
|
||||
_ "embed"
|
||||
)
|
||||
@@ -66,12 +66,7 @@ var (
|
||||
const (
|
||||
// DefaultChunkingThreshold is the threshold at which a layer should be
|
||||
// split up into chunks when downloading.
|
||||
DefaultChunkingThreshold = 128 << 20
|
||||
|
||||
// DefaultMaxChunkSize is the default maximum size of a chunk to
|
||||
// download. It is configured based on benchmarks and aims to strike a
|
||||
// balance between download speed and memory usage.
|
||||
DefaultMaxChunkSize = 8 << 20
|
||||
DefaultChunkingThreshold = 64 << 20
|
||||
)
|
||||
|
||||
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
|
||||
@@ -211,8 +206,7 @@ type Registry struct {
|
||||
// pushing or pulling models. If zero, the number of streams is
|
||||
// determined by [runtime.GOMAXPROCS].
|
||||
//
|
||||
// Clients that want "unlimited" streams should set this to a large
|
||||
// number.
|
||||
// A negative value means no limit.
|
||||
MaxStreams int
|
||||
|
||||
// ChunkingThreshold is the maximum size of a layer to download in a single
|
||||
@@ -266,6 +260,7 @@ func DefaultRegistry() (*Registry, error) {
|
||||
}
|
||||
|
||||
var rc Registry
|
||||
rc.UserAgent = UserAgent()
|
||||
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -281,25 +276,24 @@ func DefaultRegistry() (*Registry, error) {
|
||||
return &rc, nil
|
||||
}
|
||||
|
||||
func (r *Registry) maxStreams() int {
|
||||
n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
|
||||
func UserAgent() string {
|
||||
buildinfo, _ := debug.ReadBuildInfo()
|
||||
return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
|
||||
buildinfo.Main.Version,
|
||||
runtime.GOARCH,
|
||||
runtime.GOOS,
|
||||
runtime.Version(),
|
||||
)
|
||||
}
|
||||
|
||||
// Large downloads require a writter stream, so ensure we have at least
|
||||
// two streams to avoid a deadlock.
|
||||
return max(n, 2)
|
||||
func (r *Registry) maxStreams() int {
|
||||
return cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
|
||||
}
|
||||
|
||||
func (r *Registry) maxChunkingThreshold() int64 {
|
||||
return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
|
||||
}
|
||||
|
||||
// chunkSizeFor returns the chunk size for a layer of the given size. If the
|
||||
// size is less than or equal to the max chunking threshold, the size is
|
||||
// returned; otherwise, the max chunk size is returned.
|
||||
func (r *Registry) maxChunkSize() int64 {
|
||||
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
|
||||
}
|
||||
|
||||
type PushParams struct {
|
||||
// From is an optional destination name for the model. If empty, the
|
||||
// destination name is the same as the source name.
|
||||
@@ -426,6 +420,21 @@ func canRetry(err error) bool {
|
||||
return re.Status >= 500
|
||||
}
|
||||
|
||||
// trackingReader is an io.Reader that tracks the number of bytes read and
|
||||
// calls the update function with the layer, the number of bytes read.
|
||||
//
|
||||
// It always calls update with a nil error.
|
||||
type trackingReader struct {
|
||||
r io.Reader
|
||||
n *atomic.Int64
|
||||
}
|
||||
|
||||
func (r *trackingReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.r.Read(p)
|
||||
r.n.Add(int64(n))
|
||||
return
|
||||
}
|
||||
|
||||
// Pull pulls the model with the given name from the remote registry into the
|
||||
// cache.
|
||||
//
|
||||
@@ -434,11 +443,6 @@ func canRetry(err error) bool {
|
||||
// typically slower than splitting the model up across layers, and is mostly
|
||||
// utilized for layers of type equal to "application/vnd.ollama.image".
|
||||
func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
scheme, n, _, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m, err := r.Resolve(ctx, name)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -457,126 +461,95 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
return err == nil && info.Size == l.Size
|
||||
}
|
||||
|
||||
t := traceFromContext(ctx)
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(r.maxStreams())
|
||||
|
||||
layers := m.Layers
|
||||
if m.Config != nil && m.Config.Digest.IsValid() {
|
||||
layers = append(layers, m.Config)
|
||||
}
|
||||
|
||||
for _, l := range layers {
|
||||
// Send initial layer trace events to allow clients to have an
|
||||
// understanding of work to be done before work starts.
|
||||
t := traceFromContext(ctx)
|
||||
skip := make([]bool, len(layers))
|
||||
for i, l := range layers {
|
||||
t.update(l, 0, nil)
|
||||
if exists(l) {
|
||||
skip[i] = true
|
||||
t.update(l, l.Size, ErrCached)
|
||||
}
|
||||
}
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(r.maxStreams())
|
||||
for i, l := range layers {
|
||||
if skip[i] {
|
||||
continue
|
||||
}
|
||||
|
||||
blobURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), l.Digest)
|
||||
req, err := r.newRequest(ctx, "GET", blobURL, nil)
|
||||
chunked, err := c.Chunked(l.Digest, l.Size)
|
||||
if err != nil {
|
||||
t.update(l, 0, err)
|
||||
continue
|
||||
}
|
||||
defer chunked.Close()
|
||||
|
||||
t.update(l, 0, nil)
|
||||
|
||||
if l.Size <= r.maxChunkingThreshold() {
|
||||
g.Go(func() error {
|
||||
// TODO(bmizerany): retry/backoff like below in
|
||||
// the chunking case
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
err = c.Put(l.Digest, res.Body, l.Size)
|
||||
if err == nil {
|
||||
t.update(l, l.Size, nil)
|
||||
}
|
||||
return err
|
||||
})
|
||||
} else {
|
||||
q := syncs.NewRelayReader()
|
||||
var progress atomic.Int64
|
||||
for cs, err := range r.chunksums(ctx, name, l) {
|
||||
if err != nil {
|
||||
t.update(l, progress.Load(), err)
|
||||
break
|
||||
}
|
||||
|
||||
g.Go(func() (err error) {
|
||||
defer func() { q.CloseWithError(err) }()
|
||||
return c.Put(l.Digest, q, l.Size)
|
||||
})
|
||||
defer func() { t.update(l, progress.Load(), err) }()
|
||||
|
||||
var progress atomic.Int64
|
||||
|
||||
// We want to avoid extra round trips per chunk due to
|
||||
// redirects from the registry to the blob store, so
|
||||
// fire an initial request to get the final URL and
|
||||
// then use that URL for the chunk requests.
|
||||
req.Header.Set("Range", "bytes=0-0")
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.Body.Close()
|
||||
req = res.Request.WithContext(req.Context())
|
||||
|
||||
wp := writerPool{size: r.maxChunkSize()}
|
||||
|
||||
for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
|
||||
if ctx.Err() != nil {
|
||||
break
|
||||
}
|
||||
|
||||
ticket := q.Take()
|
||||
g.Go(func() (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
q.CloseWithError(err)
|
||||
}
|
||||
ticket.Close()
|
||||
t.update(l, progress.Load(), err)
|
||||
}()
|
||||
|
||||
for _, err := range backoff.Loop(ctx, 3*time.Second) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err := func() error {
|
||||
req := req.Clone(req.Context())
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
tw := wp.get()
|
||||
tw.Reset(ticket)
|
||||
defer wp.put(tw)
|
||||
|
||||
_, err = io.CopyN(tw, res.Body, chunk.Size())
|
||||
if err != nil {
|
||||
return maybeUnexpectedEOF(err)
|
||||
}
|
||||
if err := tw.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
total := progress.Add(chunk.Size())
|
||||
if total >= l.Size {
|
||||
q.Close()
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
if !canRetry(err) {
|
||||
return err
|
||||
}
|
||||
for _, err := range backoff.Loop(ctx, 3*time.Second) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
err := func() error {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
// Count bytes towards
|
||||
// progress, as they arrive, so
|
||||
// that our bytes piggyback
|
||||
// other chunk updates on
|
||||
// completion.
|
||||
//
|
||||
// This tactic is enough to
|
||||
// show "smooth" progress given
|
||||
// the current CLI client. In
|
||||
// the near future, the server
|
||||
// should report download rate
|
||||
// since it knows better than
|
||||
// a client that is measuring
|
||||
// rate based on wall-clock
|
||||
// time-since-last-update.
|
||||
body := &trackingReader{r: res.Body, n: &progress}
|
||||
|
||||
err = chunked.Put(cs.Chunk, cs.Digest, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
if !canRetry(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -615,8 +588,6 @@ type Manifest struct {
|
||||
Config *Layer `json:"config"`
|
||||
}
|
||||
|
||||
var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
|
||||
|
||||
// Layer returns the layer with the given
|
||||
// digest, or nil if not found.
|
||||
func (m *Manifest) Layer(d blob.Digest) *Layer {
|
||||
@@ -643,10 +614,9 @@ func (m Manifest) MarshalJSON() ([]byte, error) {
|
||||
// last phase of the commit which expects it, but does nothing
|
||||
// with it. This will be fixed in a future release of
|
||||
// ollama.com.
|
||||
Config *Layer `json:"config"`
|
||||
Config Layer `json:"config"`
|
||||
}{
|
||||
M: M(m),
|
||||
Config: &Layer{Digest: emptyDigest},
|
||||
M: M(m),
|
||||
}
|
||||
return json.Marshal(v)
|
||||
}
|
||||
@@ -736,6 +706,123 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
type chunksum struct {
|
||||
URL string
|
||||
Chunk blob.Chunk
|
||||
Digest blob.Digest
|
||||
}
|
||||
|
||||
// chunksums returns a sequence of chunksums for the given layer. If the layer is under the
|
||||
// chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
|
||||
// is over the chunking threshold, the chunksums are read from the chunksums endpoint.
|
||||
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
|
||||
return func(yield func(chunksum, error) bool) {
|
||||
scheme, n, _, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
}
|
||||
|
||||
if l.Size < r.maxChunkingThreshold() {
|
||||
// any layer under the threshold should be downloaded
|
||||
// in one go.
|
||||
cs := chunksum{
|
||||
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
|
||||
scheme,
|
||||
n.Host(),
|
||||
n.Namespace(),
|
||||
n.Model(),
|
||||
l.Digest,
|
||||
),
|
||||
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
|
||||
Digest: l.Digest,
|
||||
}
|
||||
yield(cs, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// A chunksums response is a sequence of chunksums in a
|
||||
// simple, easy to parse line-oriented format.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// >> GET /v2/<namespace>/<model>/chunksums/<digest>
|
||||
//
|
||||
// << HTTP/1.1 200 OK
|
||||
// << Content-Location: <blobURL>
|
||||
// <<
|
||||
// << <digest> <start>-<end>
|
||||
// << ...
|
||||
//
|
||||
// The blobURL is the URL to download the chunks from.
|
||||
|
||||
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
|
||||
scheme,
|
||||
n.Host(),
|
||||
n.Namespace(),
|
||||
n.Model(),
|
||||
l.Digest,
|
||||
)
|
||||
|
||||
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
|
||||
if err != nil {
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
}
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != 200 {
|
||||
err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
}
|
||||
blobURL := res.Header.Get("Content-Location")
|
||||
|
||||
s := bufio.NewScanner(res.Body)
|
||||
s.Split(bufio.ScanWords)
|
||||
for {
|
||||
if !s.Scan() {
|
||||
if s.Err() != nil {
|
||||
yield(chunksum{}, s.Err())
|
||||
}
|
||||
return
|
||||
}
|
||||
d, err := blob.ParseDigest(s.Bytes())
|
||||
if err != nil {
|
||||
yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
|
||||
return
|
||||
}
|
||||
|
||||
if !s.Scan() {
|
||||
err := s.Err()
|
||||
if err == nil {
|
||||
err = fmt.Errorf("missing chunk range for digest %s", d)
|
||||
}
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
}
|
||||
chunk, err := parseChunk(s.Bytes())
|
||||
if err != nil {
|
||||
yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
|
||||
return
|
||||
}
|
||||
|
||||
cs := chunksum{
|
||||
URL: blobURL,
|
||||
Chunk: chunk,
|
||||
Digest: d,
|
||||
}
|
||||
if !yield(cs, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Registry) client() *http.Client {
|
||||
if r.HTTPClient != nil {
|
||||
return r.HTTPClient
|
||||
@@ -898,13 +985,6 @@ func checkData(url string) string {
|
||||
return fmt.Sprintf("GET,%s,%s", url, zeroSum)
|
||||
}
|
||||
|
||||
func maybeUnexpectedEOF(err error) error {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type publicError struct {
|
||||
wrapped error
|
||||
message string
|
||||
@@ -991,27 +1071,22 @@ func splitExtended(s string) (scheme, name, digest string) {
|
||||
return scheme, s, digest
|
||||
}
|
||||
|
||||
type writerPool struct {
|
||||
size int64 // set by the caller
|
||||
|
||||
mu sync.Mutex
|
||||
ws []*bufio.Writer
|
||||
}
|
||||
|
||||
func (p *writerPool) get() *bufio.Writer {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if len(p.ws) == 0 {
|
||||
return bufio.NewWriterSize(nil, int(p.size))
|
||||
// parseChunk parses a string in the form "start-end" and returns the Chunk.
|
||||
func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
|
||||
startPart, endPart, found := strings.Cut(string(s), "-")
|
||||
if !found {
|
||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s)
|
||||
}
|
||||
w := p.ws[len(p.ws)-1]
|
||||
p.ws = p.ws[:len(p.ws)-1]
|
||||
return w
|
||||
}
|
||||
|
||||
func (p *writerPool) put(w *bufio.Writer) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
w.Reset(nil)
|
||||
p.ws = append(p.ws, w)
|
||||
start, err := strconv.ParseInt(startPart, 10, 64)
|
||||
if err != nil {
|
||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err)
|
||||
}
|
||||
end, err := strconv.ParseInt(endPart, 10, 64)
|
||||
if err != nil {
|
||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err)
|
||||
}
|
||||
if start > end {
|
||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s)
|
||||
}
|
||||
return blob.Chunk{Start: start, End: end}, nil
|
||||
}
|
||||
|
@@ -21,7 +21,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/chunks"
|
||||
"github.com/ollama/ollama/server/internal/testutil"
|
||||
)
|
||||
|
||||
@@ -428,7 +427,7 @@ func TestRegistryPullCached(t *testing.T) {
|
||||
err := rc.Pull(ctx, "single")
|
||||
testutil.Check(t, err)
|
||||
|
||||
want := []int64{6}
|
||||
want := []int64{0, 6}
|
||||
if !errors.Is(errors.Join(errs...), ErrCached) {
|
||||
t.Errorf("errs = %v; want %v", errs, ErrCached)
|
||||
}
|
||||
@@ -531,54 +530,6 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullChunking(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
|
||||
if r.URL.Host != "blob.store" {
|
||||
// The production registry redirects to the blob store.
|
||||
http.Redirect(w, r, "http://blob.store"+r.URL.Path, http.StatusFound)
|
||||
return
|
||||
}
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
rng := r.Header.Get("Range")
|
||||
if rng == "" {
|
||||
http.Error(w, "missing range", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
_, c, err := chunks.ParseRange(r.Header.Get("Range"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
io.WriteString(w, "remote"[c.Start:c.End+1])
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, blob.DigestFromBytes("remote"))
|
||||
})
|
||||
|
||||
// Force chunking by setting the threshold to less than the size of the
|
||||
// layer.
|
||||
rc.ChunkingThreshold = 3
|
||||
rc.MaxChunkSize = 3
|
||||
|
||||
var reads []int64
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(d *Layer, n int64, err error) {
|
||||
if err != nil {
|
||||
t.Errorf("update %v %d %v", d, n, err)
|
||||
}
|
||||
reads = append(reads, n)
|
||||
},
|
||||
})
|
||||
|
||||
err := rc.Pull(ctx, "remote")
|
||||
testutil.Check(t, err)
|
||||
|
||||
want := []int64{0, 3, 6}
|
||||
if !slices.Equal(reads, want) {
|
||||
t.Errorf("reads = %v; want %v", reads, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryResolveByDigest(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
|
@@ -1,11 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
fmt.Println("Run as 'go test -bench=.' to run the benchmarks")
|
||||
os.Exit(1)
|
||||
}
|
@@ -1,107 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/chunks"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func BenchmarkDownload(b *testing.B) {
|
||||
run := func(fileSize, chunkSize int64) {
|
||||
name := fmt.Sprintf("size=%d/chunksize=%d", fileSize, chunkSize)
|
||||
b.Run(name, func(b *testing.B) { benchmarkDownload(b, fileSize, chunkSize) })
|
||||
}
|
||||
|
||||
run(100<<20, 8<<20)
|
||||
run(100<<20, 16<<20)
|
||||
run(100<<20, 32<<20)
|
||||
run(100<<20, 64<<20)
|
||||
run(100<<20, 128<<20) // 1 chunk
|
||||
}
|
||||
|
||||
func run(ctx context.Context, c *http.Client, chunk chunks.Chunk) error {
|
||||
const blobURL = "https://ollama.com/v2/x/x/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d"
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
|
||||
res, err := c.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
_, err = io.CopyN(io.Discard, res.Body, chunk.Size()) // will io.EOF on short read
|
||||
return err
|
||||
}
|
||||
|
||||
var sleepTime atomic.Int64
|
||||
|
||||
func benchmarkDownload(b *testing.B, fileSize, chunkSize int64) {
|
||||
client := &http.Client{
|
||||
Transport: func() http.RoundTripper {
|
||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||
tr.DisableKeepAlives = true
|
||||
return tr
|
||||
}(),
|
||||
}
|
||||
defer client.CloseIdleConnections()
|
||||
|
||||
// warm up the client
|
||||
run(context.Background(), client, chunks.New(0, 1<<20))
|
||||
|
||||
b.SetBytes(fileSize)
|
||||
b.ReportAllocs()
|
||||
|
||||
// Give our CDN a min to breathe between benchmarks.
|
||||
time.Sleep(time.Duration(sleepTime.Swap(3)))
|
||||
|
||||
for b.Loop() {
|
||||
g, ctx := errgroup.WithContext(b.Context())
|
||||
g.SetLimit(runtime.GOMAXPROCS(0))
|
||||
for chunk := range chunks.Of(fileSize, chunkSize) {
|
||||
g.Go(func() error { return run(ctx, client, chunk) })
|
||||
}
|
||||
if err := g.Wait(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWrite(b *testing.B) {
|
||||
b.Run("chunksize=1MiB", func(b *testing.B) { benchmarkWrite(b, 1<<20) })
|
||||
}
|
||||
|
||||
func benchmarkWrite(b *testing.B, chunkSize int) {
|
||||
b.ReportAllocs()
|
||||
|
||||
dir := b.TempDir()
|
||||
f, err := os.Create(filepath.Join(dir, "write-single"))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
data := make([]byte, chunkSize)
|
||||
b.SetBytes(int64(chunkSize))
|
||||
r := bytes.NewReader(data)
|
||||
for b.Loop() {
|
||||
r.Reset(data)
|
||||
_, err := io.Copy(f, r)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,6 +1,5 @@
|
||||
// Package registry provides an http.Handler for handling local Ollama API
|
||||
// requests for performing tasks related to the ollama.com model registry and
|
||||
// the local disk cache.
|
||||
// Package registry implements an http.Handler for handling local Ollama API
|
||||
// model management requests. See [Local] for details.
|
||||
package registry
|
||||
|
||||
import (
|
||||
@@ -10,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -18,16 +18,11 @@ import (
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
)
|
||||
|
||||
// Local is an http.Handler for handling local Ollama API requests for
|
||||
// performing tasks related to the ollama.com model registry combined with the
|
||||
// local disk cache.
|
||||
// Local implements an http.Handler for handling local Ollama API model
|
||||
// management requests, such as pushing, pulling, and deleting models.
|
||||
//
|
||||
// It is not concern of Local, or this package, to handle model creation, which
|
||||
// proceeds any registry operations for models it produces.
|
||||
//
|
||||
// NOTE: The package built for dealing with model creation should use
|
||||
// [DefaultCache] to access the blob store and not attempt to read or write
|
||||
// directly to the blob disk cache.
|
||||
// It can be arranged for all unknown requests to be passed through to a
|
||||
// fallback handler, if one is provided.
|
||||
type Local struct {
|
||||
Client *ollama.Registry // required
|
||||
Logger *slog.Logger // required
|
||||
@@ -63,6 +58,7 @@ func (e serverError) Error() string {
|
||||
var (
|
||||
errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
|
||||
errNotFound = &serverError{404, "not_found", "not found"}
|
||||
errModelNotFound = &serverError{404, "not_found", "model not found"}
|
||||
errInternalError = &serverError{500, "internal_error", "internal server error"}
|
||||
)
|
||||
|
||||
@@ -175,8 +171,16 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
|
||||
}
|
||||
|
||||
type params struct {
|
||||
DeprecatedName string `json:"name"` // Use [params.model]
|
||||
Model string `json:"model"` // Use [params.model]
|
||||
// DeprecatedName is the name of the model to push, pull, or delete,
|
||||
// but is deprecated. New clients should use [Model] instead.
|
||||
//
|
||||
// Use [model()] to get the model name for both old and new API requests.
|
||||
DeprecatedName string `json:"name"`
|
||||
|
||||
// Model is the name of the model to push, pull, or delete.
|
||||
//
|
||||
// Use [model()] to get the model name for both old and new API requests.
|
||||
Model string `json:"model"`
|
||||
|
||||
// AllowNonTLS is a flag that indicates a client using HTTP
|
||||
// is doing so, deliberately.
|
||||
@@ -189,9 +193,18 @@ type params struct {
|
||||
// confusing flags such as this.
|
||||
AllowNonTLS bool `json:"insecure"`
|
||||
|
||||
// ProgressStream is a flag that indicates the client is expecting a stream of
|
||||
// progress updates.
|
||||
ProgressStream bool `json:"stream"`
|
||||
// Stream, if true, will make the server send progress updates in a
|
||||
// streaming of JSON objects. If false, the server will send a single
|
||||
// JSON object with the final status as "success", or an error object
|
||||
// if an error occurred.
|
||||
//
|
||||
// Unfortunately, this API was designed to be a bit awkward. Stream is
|
||||
// defined to default to true if not present, so we need a way to check
|
||||
// if the client decisively it to false. So, we use a pointer to a
|
||||
// bool. Gross.
|
||||
//
|
||||
// Use [stream()] to get the correct value for this field.
|
||||
Stream *bool `json:"stream"`
|
||||
}
|
||||
|
||||
// model returns the model name for both old and new API requests.
|
||||
@@ -199,6 +212,13 @@ func (p params) model() string {
|
||||
return cmp.Or(p.Model, p.DeprecatedName)
|
||||
}
|
||||
|
||||
func (p params) stream() bool {
|
||||
if p.Stream == nil {
|
||||
return true
|
||||
}
|
||||
return *p.Stream
|
||||
}
|
||||
|
||||
func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != "DELETE" {
|
||||
return errMethodNotAllowed
|
||||
@@ -212,16 +232,16 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return &serverError{404, "not_found", "model not found"}
|
||||
return errModelNotFound
|
||||
}
|
||||
if s.Prune == nil {
|
||||
return nil
|
||||
if s.Prune != nil {
|
||||
return s.Prune()
|
||||
}
|
||||
return s.Prune()
|
||||
return nil
|
||||
}
|
||||
|
||||
type progressUpdateJSON struct {
|
||||
Status string `json:"status"`
|
||||
Status string `json:"status,omitempty,omitzero"`
|
||||
Digest blob.Digest `json:"digest,omitempty,omitzero"`
|
||||
Total int64 `json:"total,omitempty,omitzero"`
|
||||
Completed int64 `json:"completed,omitempty,omitzero"`
|
||||
@@ -237,6 +257,17 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
return err
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
if !p.stream() {
|
||||
if err := s.Client.Pull(r.Context(), p.model()); err != nil {
|
||||
if errors.Is(err, ollama.ErrModelNotFound) {
|
||||
return errModelNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
return enc.Encode(progressUpdateJSON{Status: "success"})
|
||||
}
|
||||
|
||||
maybeFlush := func() {
|
||||
fl, _ := w.(http.Flusher)
|
||||
if fl != nil {
|
||||
@@ -246,69 +277,67 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
defer maybeFlush()
|
||||
|
||||
var mu sync.Mutex
|
||||
enc := json.NewEncoder(w)
|
||||
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
|
||||
progress := make(map[*ollama.Layer]int64)
|
||||
|
||||
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
|
||||
Update: func(l *ollama.Layer, n int64, err error) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
progressCopy := make(map[*ollama.Layer]int64, len(progress))
|
||||
pushUpdate := func() {
|
||||
defer maybeFlush()
|
||||
|
||||
// TODO(bmizerany): coalesce these updates; writing per
|
||||
// update is expensive
|
||||
// TODO(bmizerany): This scales poorly with more layers due to
|
||||
// needing to flush out them all in one big update. We _could_
|
||||
// just flush on the changed ones, or just track the whole
|
||||
// download. Needs more thought. This is fine for now.
|
||||
mu.Lock()
|
||||
maps.Copy(progressCopy, progress)
|
||||
mu.Unlock()
|
||||
for l, n := range progress {
|
||||
enc.Encode(progressUpdateJSON{
|
||||
Digest: l.Digest,
|
||||
Status: "pulling",
|
||||
Total: l.Size,
|
||||
Completed: n,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
t := time.NewTicker(time.Hour) // "unstarted" timer
|
||||
start := sync.OnceFunc(func() {
|
||||
pushUpdate()
|
||||
t.Reset(100 * time.Millisecond)
|
||||
})
|
||||
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
|
||||
Update: func(l *ollama.Layer, n int64, err error) {
|
||||
if n > 0 {
|
||||
start() // flush initial state
|
||||
}
|
||||
mu.Lock()
|
||||
progress[l] = n
|
||||
mu.Unlock()
|
||||
},
|
||||
})
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
// TODO(bmizerany): continue to support non-streaming responses
|
||||
done <- s.Client.Pull(ctx, p.model())
|
||||
}()
|
||||
|
||||
func() {
|
||||
t := time.NewTicker(100 * time.Millisecond)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
mu.Lock()
|
||||
maybeFlush()
|
||||
mu.Unlock()
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
var status string
|
||||
if errors.Is(err, ollama.ErrModelNotFound) {
|
||||
status = fmt.Sprintf("error: model %q not found", p.model())
|
||||
enc.Encode(progressUpdateJSON{Status: status})
|
||||
} else {
|
||||
status = fmt.Sprintf("error: %v", err)
|
||||
enc.Encode(progressUpdateJSON{Status: status})
|
||||
}
|
||||
return
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
pushUpdate()
|
||||
case err := <-done:
|
||||
pushUpdate()
|
||||
if err != nil {
|
||||
var status string
|
||||
if errors.Is(err, ollama.ErrModelNotFound) {
|
||||
status = fmt.Sprintf("error: model %q not found", p.model())
|
||||
} else {
|
||||
status = fmt.Sprintf("error: %v", err)
|
||||
}
|
||||
|
||||
// These final updates are not strictly necessary, because they have
|
||||
// already happened at this point. Our pull handler code used to do
|
||||
// these steps after, not during, the pull, and they were slow, so we
|
||||
// wanted to provide feedback to users what was happening. For now, we
|
||||
// keep them to not jar users who are used to seeing them. We can phase
|
||||
// them out with a new and nicer UX later. One without progress bars
|
||||
// and digests that no one cares about.
|
||||
enc.Encode(progressUpdateJSON{Status: "verifying layers"})
|
||||
enc.Encode(progressUpdateJSON{Status: "writing manifest"})
|
||||
enc.Encode(progressUpdateJSON{Status: "success"})
|
||||
return
|
||||
enc.Encode(progressUpdateJSON{Status: status})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func decodeUserJSON[T any](r io.Reader) (T, error) {
|
||||
|
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
@@ -160,7 +159,6 @@ var registryFS = sync.OnceValue(func() fs.FS {
|
||||
// to \n when parsing the txtar on Windows.
|
||||
data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
|
||||
a := txtar.Parse(data)
|
||||
fmt.Printf("%q\n", a.Comment)
|
||||
fsys, err := txtar.FS(a)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -179,7 +177,7 @@ func TestServerPull(t *testing.T) {
|
||||
w.WriteHeader(404)
|
||||
io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
|
||||
default:
|
||||
t.Logf("serving file: %s", r.URL.Path)
|
||||
t.Logf("serving blob: %s", r.URL.Path)
|
||||
modelsHandler.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
@@ -188,7 +186,7 @@ func TestServerPull(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
if got.Code != 200 {
|
||||
t.Fatalf("Code = %d; want 200", got.Code)
|
||||
t.Errorf("Code = %d; want 200", got.Code)
|
||||
}
|
||||
gotlines := got.Body.String()
|
||||
t.Logf("got:\n%s", gotlines)
|
||||
@@ -197,35 +195,29 @@ func TestServerPull(t *testing.T) {
|
||||
want, unwanted := strings.CutPrefix(want, "!")
|
||||
want = strings.TrimSpace(want)
|
||||
if !unwanted && !strings.Contains(gotlines, want) {
|
||||
t.Fatalf("! missing %q in body", want)
|
||||
t.Errorf("! missing %q in body", want)
|
||||
}
|
||||
if unwanted && strings.Contains(gotlines, want) {
|
||||
t.Fatalf("! unexpected %q in body", want)
|
||||
t.Errorf("! unexpected %q in body", want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
|
||||
checkResponse(got, `
|
||||
{"status":"pulling manifest"}
|
||||
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
|
||||
`)
|
||||
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
|
||||
checkResponse(got, `
|
||||
{"status":"pulling manifest"}
|
||||
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
|
||||
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
|
||||
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
|
||||
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
|
||||
{"status":"verifying layers"}
|
||||
{"status":"writing manifest"}
|
||||
{"status":"success"}
|
||||
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
|
||||
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
|
||||
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
|
||||
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
|
||||
`)
|
||||
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
|
||||
checkResponse(got, `
|
||||
{"status":"pulling manifest"}
|
||||
{"status":"error: model \"unknown\" not found"}
|
||||
`)
|
||||
|
||||
@@ -240,19 +232,39 @@ func TestServerPull(t *testing.T) {
|
||||
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
|
||||
checkResponse(got, `
|
||||
{"status":"pulling manifest"}
|
||||
{"status":"error: invalid or missing name: \"\""}
|
||||
|
||||
!verifying
|
||||
!writing
|
||||
!success
|
||||
`)
|
||||
|
||||
// Non-streaming pulls
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "://", "stream": false}`)
|
||||
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "smol", "stream": false}`)
|
||||
checkResponse(got, `
|
||||
{"status":"success"}
|
||||
!digest
|
||||
!total
|
||||
!completed
|
||||
`)
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "unknown", "stream": false}`)
|
||||
checkErrorResponse(t, got, 404, "not_found", "model not found")
|
||||
}
|
||||
|
||||
func TestServerUnknownPath(t *testing.T) {
|
||||
s := newTestServer(t, nil)
|
||||
got := s.send(t, "DELETE", "/api/unknown", `{}`)
|
||||
checkErrorResponse(t, got, 404, "not_found", "not found")
|
||||
|
||||
var fellback bool
|
||||
s.Fallback = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fellback = true
|
||||
})
|
||||
got = s.send(t, "DELETE", "/api/unknown", `{}`)
|
||||
if !fellback {
|
||||
t.Fatal("expected Fallback to be called")
|
||||
}
|
||||
if got.Code != 200 {
|
||||
t.Fatalf("Code = %d; want 200", got.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
|
||||
|
@@ -26,7 +26,6 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
var system []api.Message
|
||||
|
||||
isMllama := checkMllamaModelFamily(m)
|
||||
isGemma3 := checkGemma3ModelFamily(m)
|
||||
|
||||
var imageNumTokens int
|
||||
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
|
||||
@@ -41,7 +40,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
n := len(msgs) - 1
|
||||
// in reverse, find all messages that fit into context window
|
||||
for i := n; i >= 0; i-- {
|
||||
if (isMllama || isGemma3) && len(msgs[i].Images) > 1 {
|
||||
if isMllama && len(msgs[i].Images) > 1 {
|
||||
return "", nil, errTooManyImages
|
||||
}
|
||||
|
||||
@@ -158,12 +157,3 @@ func checkMllamaModelFamily(m *Model) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func checkGemma3ModelFamily(m *Model) bool {
|
||||
for _, arch := range m.Config.ModelFamilies {
|
||||
if arch == "gemma3" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
21
types/bfloat16/LICENSE
Normal file
21
types/bfloat16/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2021 Tristan Rice
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
57
types/bfloat16/bfloat16.go
Normal file
57
types/bfloat16/bfloat16.go
Normal file
@@ -0,0 +1,57 @@
|
||||
// Vendored code from https://github.com/d4l3k/go-bfloat16
|
||||
// unsafe pointer replaced by "math"
|
||||
package bfloat16
|
||||
|
||||
import "math"
|
||||
|
||||
type BF16 uint16
|
||||
|
||||
func FromBytes(buf []byte) BF16 {
|
||||
return BF16(uint16(buf[0]) + uint16(buf[1])<<8)
|
||||
}
|
||||
|
||||
func ToBytes(b BF16) []byte {
|
||||
return []byte{byte(b & 0xFF), byte(b >> 8)}
|
||||
}
|
||||
|
||||
func Decode(buf []byte) []BF16 {
|
||||
var out []BF16
|
||||
for i := 0; i < len(buf); i += 2 {
|
||||
out = append(out, FromBytes(buf[i:]))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func Encode(f []BF16) []byte {
|
||||
var out []byte
|
||||
for _, a := range f {
|
||||
out = append(out, ToBytes(a)...)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func DecodeFloat32(buf []byte) []float32 {
|
||||
var out []float32
|
||||
for i := 0; i < len(buf); i += 2 {
|
||||
out = append(out, ToFloat32(FromBytes(buf[i:])))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func EncodeFloat32(f []float32) []byte {
|
||||
var out []byte
|
||||
for _, a := range f {
|
||||
out = append(out, ToBytes(FromFloat32(a))...)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func ToFloat32(b BF16) float32 {
|
||||
u32 := uint32(b) << 16
|
||||
return math.Float32frombits(u32)
|
||||
}
|
||||
|
||||
func FromFloat32(f float32) BF16 {
|
||||
u32 := math.Float32bits(f)
|
||||
return BF16(u32 >> 16)
|
||||
}
|
53
types/bfloat16/bfloat16_test.go
Normal file
53
types/bfloat16/bfloat16_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package bfloat16
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func randomBytes(n int) []byte {
|
||||
out := make([]byte, n)
|
||||
if _, err := rand.Read(out); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestEncodeDecode(t *testing.T) {
|
||||
b := randomBytes(1024)
|
||||
bf16 := Decode(b)
|
||||
out := Encode(bf16)
|
||||
if !reflect.DeepEqual(b, out) {
|
||||
t.Fatalf("%+v != %+v", b, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDecodeFloat32(t *testing.T) {
|
||||
b := randomBytes(1024)
|
||||
bf16 := DecodeFloat32(b)
|
||||
out := EncodeFloat32(bf16)
|
||||
if !reflect.DeepEqual(b, out) {
|
||||
t.Fatalf("%+v != %+v", b, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicFloat32(t *testing.T) {
|
||||
var in float32 = 1.0
|
||||
out := ToFloat32(FromFloat32(in))
|
||||
if !reflect.DeepEqual(in, out) {
|
||||
t.Fatalf("%+v != %+v", in, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplexFloat32(t *testing.T) {
|
||||
var in float32 = 123456789123456789.123456789
|
||||
var want float32 = 123286039799267328.0
|
||||
out := ToFloat32(FromFloat32(in))
|
||||
if in == out {
|
||||
t.Fatalf("no loss of precision")
|
||||
}
|
||||
if out != want {
|
||||
t.Fatalf("%.16f != %.16f", want, out)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user