Compare commits
1 Commits
v0.6.2
...
pdevine/lo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
73a1e99f8a |
@@ -56,7 +56,7 @@
|
|||||||
"name": "ROCm 6",
|
"name": "ROCm 6",
|
||||||
"inherits": [ "ROCm" ],
|
"inherits": [ "ROCm" ],
|
||||||
"cacheVariables": {
|
"cacheVariables": {
|
||||||
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
|
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -392,8 +392,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
|
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
|
||||||
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
|
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
|
||||||
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
|
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
|
||||||
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
|
|
||||||
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
|
|
||||||
|
|
||||||
### Cloud
|
### Cloud
|
||||||
|
|
||||||
|
|||||||
129
cmd/cmd_test.go
129
cmd/cmd_test.go
@@ -757,132 +757,3 @@ func TestCreateHandler(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewCreateRequest(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
from string
|
|
||||||
opts runOptions
|
|
||||||
expected *api.CreateRequest
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"basic test",
|
|
||||||
"newmodel",
|
|
||||||
runOptions{
|
|
||||||
Model: "mymodel",
|
|
||||||
ParentModel: "",
|
|
||||||
Prompt: "You are a fun AI agent",
|
|
||||||
Messages: []api.Message{},
|
|
||||||
WordWrap: true,
|
|
||||||
},
|
|
||||||
&api.CreateRequest{
|
|
||||||
From: "mymodel",
|
|
||||||
Model: "newmodel",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"parent model test",
|
|
||||||
"newmodel",
|
|
||||||
runOptions{
|
|
||||||
Model: "mymodel",
|
|
||||||
ParentModel: "parentmodel",
|
|
||||||
Messages: []api.Message{},
|
|
||||||
WordWrap: true,
|
|
||||||
},
|
|
||||||
&api.CreateRequest{
|
|
||||||
From: "parentmodel",
|
|
||||||
Model: "newmodel",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"parent model as filepath test",
|
|
||||||
"newmodel",
|
|
||||||
runOptions{
|
|
||||||
Model: "mymodel",
|
|
||||||
ParentModel: "/some/file/like/etc/passwd",
|
|
||||||
Messages: []api.Message{},
|
|
||||||
WordWrap: true,
|
|
||||||
},
|
|
||||||
&api.CreateRequest{
|
|
||||||
From: "mymodel",
|
|
||||||
Model: "newmodel",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"parent model as windows filepath test",
|
|
||||||
"newmodel",
|
|
||||||
runOptions{
|
|
||||||
Model: "mymodel",
|
|
||||||
ParentModel: "D:\\some\\file\\like\\etc\\passwd",
|
|
||||||
Messages: []api.Message{},
|
|
||||||
WordWrap: true,
|
|
||||||
},
|
|
||||||
&api.CreateRequest{
|
|
||||||
From: "mymodel",
|
|
||||||
Model: "newmodel",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"options test",
|
|
||||||
"newmodel",
|
|
||||||
runOptions{
|
|
||||||
Model: "mymodel",
|
|
||||||
ParentModel: "parentmodel",
|
|
||||||
Options: map[string]any{
|
|
||||||
"temperature": 1.0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
&api.CreateRequest{
|
|
||||||
From: "parentmodel",
|
|
||||||
Model: "newmodel",
|
|
||||||
Parameters: map[string]any{
|
|
||||||
"temperature": 1.0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"messages test",
|
|
||||||
"newmodel",
|
|
||||||
runOptions{
|
|
||||||
Model: "mymodel",
|
|
||||||
ParentModel: "parentmodel",
|
|
||||||
System: "You are a fun AI agent",
|
|
||||||
Messages: []api.Message{
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: "hello there!",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: "hello to you!",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
WordWrap: true,
|
|
||||||
},
|
|
||||||
&api.CreateRequest{
|
|
||||||
From: "parentmodel",
|
|
||||||
Model: "newmodel",
|
|
||||||
System: "You are a fun AI agent",
|
|
||||||
Messages: []api.Message{
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: "hello there!",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: "hello to you!",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
actual := NewCreateRequest(tt.from, tt.opts)
|
|
||||||
if !cmp.Equal(actual, tt.expected) {
|
|
||||||
t.Errorf("expected output %#v, got %#v", tt.expected, actual)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/readline"
|
"github.com/ollama/ollama/readline"
|
||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type MultilineState int
|
type MultilineState int
|
||||||
@@ -460,16 +459,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
|
func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
|
||||||
parentModel := opts.ParentModel
|
|
||||||
|
|
||||||
modelName := model.ParseName(parentModel)
|
|
||||||
if !modelName.IsValid() {
|
|
||||||
parentModel = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
req := &api.CreateRequest{
|
req := &api.CreateRequest{
|
||||||
Model: name,
|
Name: name,
|
||||||
From: cmp.Or(parentModel, opts.Model),
|
From: cmp.Or(opts.ParentModel, opts.Model),
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.System != "" {
|
if opts.System != "" {
|
||||||
|
|||||||
@@ -187,13 +187,6 @@ cloudflared tunnel --url http://localhost:11434 --http-host-header="localhost:11
|
|||||||
|
|
||||||
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
|
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
|
||||||
|
|
||||||
For browser extensions, you'll need to explicitly allow the extension's origin pattern. Set `OLLAMA_ORIGINS` to include `chrome-extension://*`, `moz-extension://*`, and `safari-web-extension://*` if you wish to allow all browser extensions access, or specific extensions as needed:
|
|
||||||
|
|
||||||
```
|
|
||||||
# Allow all Chrome, Firefox, and Safari extensions
|
|
||||||
OLLAMA_ORIGINS=chrome-extension://*,moz-extension://*,safari-web-extension://* ollama serve
|
|
||||||
```
|
|
||||||
|
|
||||||
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
|
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
|
||||||
|
|
||||||
## Where are models stored?
|
## Where are models stored?
|
||||||
|
|||||||
@@ -583,52 +583,39 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
|
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
|
||||||
if llm.KV().Uint("vision.block_count") == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, layer := range llm.Tensors().GroupLayers() {
|
|
||||||
if name == "v" || strings.HasPrefix(name, "v.") {
|
|
||||||
for _, tensor := range layer {
|
|
||||||
weights += tensor.Size()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
imageSize := uint64(llm.KV().Uint("vision.image_size"))
|
|
||||||
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
|
|
||||||
if patchSize == 0 {
|
|
||||||
slog.Warn("unknown patch size for vision model")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
|
|
||||||
|
|
||||||
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
|
|
||||||
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
|
|
||||||
numPatches++
|
|
||||||
}
|
|
||||||
|
|
||||||
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
|
|
||||||
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
|
|
||||||
|
|
||||||
switch llm.KV().Architecture() {
|
switch llm.KV().Architecture() {
|
||||||
case "mllama":
|
case "mllama":
|
||||||
|
for _, layer := range llm.Tensors().GroupLayers()["v"] {
|
||||||
|
weights += layer.Size()
|
||||||
|
}
|
||||||
|
|
||||||
|
kv := func(n string) uint64 {
|
||||||
|
if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
|
||||||
|
return uint64(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
imageSize := kv("image_size")
|
||||||
|
|
||||||
|
maxNumTiles := kv("max_num_tiles")
|
||||||
|
embeddingLength := kv("embedding_length")
|
||||||
|
headCount := kv("attention.head_count")
|
||||||
|
|
||||||
|
numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
|
||||||
|
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
|
||||||
|
numPatches++
|
||||||
|
}
|
||||||
|
|
||||||
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
|
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
|
||||||
|
|
||||||
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
|
|
||||||
|
|
||||||
graphSize = 4 * (8 +
|
graphSize = 4 * (8 +
|
||||||
imageSize*imageSize*numChannels*maxNumTiles +
|
imageSize*imageSize*kv("num_channels")*maxNumTiles +
|
||||||
embeddingLength*numPatches*maxNumTiles +
|
embeddingLength*numPatches*maxNumTiles +
|
||||||
9*embeddingLength*numPaddedPatches*maxNumTiles +
|
9*embeddingLength*numPaddedPatches*maxNumTiles +
|
||||||
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
|
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
|
||||||
case "gemma3":
|
|
||||||
graphSize = 4 * (imageSize*imageSize*numChannels +
|
|
||||||
embeddingLength*patchSize +
|
|
||||||
numPatches*numPatches*headCount)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return weights, graphSize
|
return weights, graphSize
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -66,35 +66,6 @@ func TestIntegrationMllama(t *testing.T) {
|
|||||||
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIntegrationSplitBatch(t *testing.T) {
|
|
||||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
|
||||||
require.NoError(t, err)
|
|
||||||
req := api.GenerateRequest{
|
|
||||||
Model: "gemma3:4b",
|
|
||||||
// Fill up a chunk of the batch so the image will partially spill over into the next one
|
|
||||||
System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
|
|
||||||
Prompt: "what does the text in this image say?",
|
|
||||||
Stream: &stream,
|
|
||||||
Options: map[string]interface{}{
|
|
||||||
"seed": 42,
|
|
||||||
"temperature": 0.0,
|
|
||||||
},
|
|
||||||
Images: []api.ImageData{
|
|
||||||
image,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
|
||||||
resp := "the ollam"
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
|
||||||
defer cancel()
|
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
|
||||||
defer cleanup()
|
|
||||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
|
||||||
// llava models on CPU can be quite slow to start,
|
|
||||||
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
|
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
|
||||||
AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA
|
AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA
|
||||||
AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6
|
AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6
|
||||||
|
|||||||
19
llama/llama.cpp/src/llama-arch.cpp
vendored
19
llama/llama.cpp/src/llama-arch.cpp
vendored
@@ -37,7 +37,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||||||
{ LLM_ARCH_MINICPM3, "minicpm3" },
|
{ LLM_ARCH_MINICPM3, "minicpm3" },
|
||||||
{ LLM_ARCH_GEMMA, "gemma" },
|
{ LLM_ARCH_GEMMA, "gemma" },
|
||||||
{ LLM_ARCH_GEMMA2, "gemma2" },
|
{ LLM_ARCH_GEMMA2, "gemma2" },
|
||||||
{ LLM_ARCH_GEMMA3, "gemma3" },
|
|
||||||
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
||||||
{ LLM_ARCH_MAMBA, "mamba" },
|
{ LLM_ARCH_MAMBA, "mamba" },
|
||||||
{ LLM_ARCH_XVERSE, "xverse" },
|
{ LLM_ARCH_XVERSE, "xverse" },
|
||||||
@@ -805,24 +804,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
LLM_ARCH_GEMMA3,
|
|
||||||
{
|
|
||||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
|
||||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
|
||||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
|
||||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
|
||||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
|
||||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
|
||||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
|
||||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
|
||||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
|
||||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
|
||||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
|
||||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
|
||||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
LLM_ARCH_STARCODER2,
|
LLM_ARCH_STARCODER2,
|
||||||
{
|
{
|
||||||
|
|||||||
1
llama/llama.cpp/src/llama-arch.h
vendored
1
llama/llama.cpp/src/llama-arch.h
vendored
@@ -41,7 +41,6 @@ enum llm_arch {
|
|||||||
LLM_ARCH_MINICPM3,
|
LLM_ARCH_MINICPM3,
|
||||||
LLM_ARCH_GEMMA,
|
LLM_ARCH_GEMMA,
|
||||||
LLM_ARCH_GEMMA2,
|
LLM_ARCH_GEMMA2,
|
||||||
LLM_ARCH_GEMMA3,
|
|
||||||
LLM_ARCH_STARCODER2,
|
LLM_ARCH_STARCODER2,
|
||||||
LLM_ARCH_MAMBA,
|
LLM_ARCH_MAMBA,
|
||||||
LLM_ARCH_XVERSE,
|
LLM_ARCH_XVERSE,
|
||||||
|
|||||||
7
llama/llama.cpp/src/llama-model.cpp
vendored
7
llama/llama.cpp/src/llama-model.cpp
vendored
@@ -878,9 +878,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||||||
default: type = LLM_TYPE_UNKNOWN;
|
default: type = LLM_TYPE_UNKNOWN;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case LLM_ARCH_GEMMA3:
|
|
||||||
{
|
|
||||||
} break;
|
|
||||||
case LLM_ARCH_STARCODER2:
|
case LLM_ARCH_STARCODER2:
|
||||||
{
|
{
|
||||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||||
@@ -2540,9 +2537,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case LLM_ARCH_GEMMA3:
|
|
||||||
{
|
|
||||||
} break;
|
|
||||||
case LLM_ARCH_STARCODER2:
|
case LLM_ARCH_STARCODER2:
|
||||||
{
|
{
|
||||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||||
@@ -4035,7 +4029,6 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
|
|||||||
case LLM_ARCH_PHIMOE:
|
case LLM_ARCH_PHIMOE:
|
||||||
case LLM_ARCH_GEMMA:
|
case LLM_ARCH_GEMMA:
|
||||||
case LLM_ARCH_GEMMA2:
|
case LLM_ARCH_GEMMA2:
|
||||||
case LLM_ARCH_GEMMA3:
|
|
||||||
case LLM_ARCH_STARCODER2:
|
case LLM_ARCH_STARCODER2:
|
||||||
case LLM_ARCH_OPENELM:
|
case LLM_ARCH_OPENELM:
|
||||||
case LLM_ARCH_GPTNEOX:
|
case LLM_ARCH_GPTNEOX:
|
||||||
|
|||||||
9
llama/llama.cpp/src/llama-quant.cpp
vendored
9
llama/llama.cpp/src/llama-quant.cpp
vendored
@@ -737,15 +737,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||||||
// This used to be a regex, but <regex> has an extreme cost to compile times.
|
// This used to be a regex, but <regex> has an extreme cost to compile times.
|
||||||
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
|
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
|
||||||
|
|
||||||
// don't quantize vision stuff
|
|
||||||
quantize &= name.find("v.blk.") == std::string::npos;
|
|
||||||
|
|
||||||
quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos;
|
|
||||||
quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos;
|
|
||||||
quantize &= name.find("v.patch_embedding.weight") == std::string::npos;
|
|
||||||
quantize &= name.find("v.position_embedding.weight") == std::string::npos;
|
|
||||||
quantize &= name.find("v.post_layernorm.weight") == std::string::npos;
|
|
||||||
|
|
||||||
// quantize only 2D and 3D tensors (experts)
|
// quantize only 2D and 3D tensors (experts)
|
||||||
quantize &= (ggml_n_dims(tensor) >= 2);
|
quantize &= (ggml_n_dims(tensor) >= 2);
|
||||||
|
|
||||||
|
|||||||
@@ -1,113 +0,0 @@
|
|||||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
|
||||||
From: Patrick Devine <patrick@infrahq.com>
|
|
||||||
Date: Fri, 14 Mar 2025 16:33:23 -0700
|
|
||||||
Subject: [PATCH] gemma3 quantization
|
|
||||||
|
|
||||||
---
|
|
||||||
src/llama-arch.cpp | 19 +++++++++++++++++++
|
|
||||||
src/llama-arch.h | 1 +
|
|
||||||
src/llama-model.cpp | 7 +++++++
|
|
||||||
src/llama-quant.cpp | 9 +++++++++
|
|
||||||
4 files changed, 36 insertions(+)
|
|
||||||
|
|
||||||
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
|
|
||||||
index b6f20286..b443fcd3 100644
|
|
||||||
--- a/src/llama-arch.cpp
|
|
||||||
+++ b/src/llama-arch.cpp
|
|
||||||
@@ -37,6 +37,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|
||||||
{ LLM_ARCH_MINICPM3, "minicpm3" },
|
|
||||||
{ LLM_ARCH_GEMMA, "gemma" },
|
|
||||||
{ LLM_ARCH_GEMMA2, "gemma2" },
|
|
||||||
+ { LLM_ARCH_GEMMA3, "gemma3" },
|
|
||||||
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
|
||||||
{ LLM_ARCH_MAMBA, "mamba" },
|
|
||||||
{ LLM_ARCH_XVERSE, "xverse" },
|
|
||||||
@@ -804,6 +805,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|
||||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
|
||||||
},
|
|
||||||
},
|
|
||||||
+ {
|
|
||||||
+ LLM_ARCH_GEMMA3,
|
|
||||||
+ {
|
|
||||||
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
|
||||||
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
|
||||||
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
|
||||||
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
|
||||||
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
|
||||||
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
|
||||||
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
|
||||||
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
|
||||||
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
|
||||||
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
|
||||||
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
|
||||||
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
|
||||||
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
|
||||||
+ },
|
|
||||||
+ },
|
|
||||||
{
|
|
||||||
LLM_ARCH_STARCODER2,
|
|
||||||
{
|
|
||||||
diff --git a/src/llama-arch.h b/src/llama-arch.h
|
|
||||||
index ec742224..aad92a5d 100644
|
|
||||||
--- a/src/llama-arch.h
|
|
||||||
+++ b/src/llama-arch.h
|
|
||||||
@@ -41,6 +41,7 @@ enum llm_arch {
|
|
||||||
LLM_ARCH_MINICPM3,
|
|
||||||
LLM_ARCH_GEMMA,
|
|
||||||
LLM_ARCH_GEMMA2,
|
|
||||||
+ LLM_ARCH_GEMMA3,
|
|
||||||
LLM_ARCH_STARCODER2,
|
|
||||||
LLM_ARCH_MAMBA,
|
|
||||||
LLM_ARCH_XVERSE,
|
|
||||||
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
|
|
||||||
index ab1a07d1..70183041 100644
|
|
||||||
--- a/src/llama-model.cpp
|
|
||||||
+++ b/src/llama-model.cpp
|
|
||||||
@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|
||||||
default: type = LLM_TYPE_UNKNOWN;
|
|
||||||
}
|
|
||||||
} break;
|
|
||||||
+ case LLM_ARCH_GEMMA3:
|
|
||||||
+ {
|
|
||||||
+ } break;
|
|
||||||
case LLM_ARCH_STARCODER2:
|
|
||||||
{
|
|
||||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
|
||||||
@@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|
||||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
|
||||||
}
|
|
||||||
} break;
|
|
||||||
+ case LLM_ARCH_GEMMA3:
|
|
||||||
+ {
|
|
||||||
+ } break;
|
|
||||||
case LLM_ARCH_STARCODER2:
|
|
||||||
{
|
|
||||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
||||||
@@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
|
|
||||||
case LLM_ARCH_PHIMOE:
|
|
||||||
case LLM_ARCH_GEMMA:
|
|
||||||
case LLM_ARCH_GEMMA2:
|
|
||||||
+ case LLM_ARCH_GEMMA3:
|
|
||||||
case LLM_ARCH_STARCODER2:
|
|
||||||
case LLM_ARCH_OPENELM:
|
|
||||||
case LLM_ARCH_GPTNEOX:
|
|
||||||
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
|
|
||||||
index 6eb1da08..d2f3a510 100644
|
|
||||||
--- a/src/llama-quant.cpp
|
|
||||||
+++ b/src/llama-quant.cpp
|
|
||||||
@@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|
||||||
// This used to be a regex, but <regex> has an extreme cost to compile times.
|
|
||||||
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
|
|
||||||
|
|
||||||
+ // don't quantize vision stuff
|
|
||||||
+ quantize &= name.find("v.blk.") == std::string::npos;
|
|
||||||
+
|
|
||||||
+ quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos;
|
|
||||||
+ quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos;
|
|
||||||
+ quantize &= name.find("v.patch_embedding.weight") == std::string::npos;
|
|
||||||
+ quantize &= name.find("v.position_embedding.weight") == std::string::npos;
|
|
||||||
+ quantize &= name.find("v.post_layernorm.weight") == std::string::npos;
|
|
||||||
+
|
|
||||||
// quantize only 2D and 3D tensors (experts)
|
|
||||||
quantize &= (ggml_n_dims(tensor) >= 2);
|
|
||||||
|
|
||||||
@@ -218,8 +218,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
|||||||
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
|
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
|
||||||
layerSize = blk.Size()
|
layerSize = blk.Size()
|
||||||
layerSize += kv / f.KV().BlockCount()
|
layerSize += kv / f.KV().BlockCount()
|
||||||
memoryWeights += blk.Size()
|
|
||||||
}
|
}
|
||||||
|
memoryWeights += layerSize
|
||||||
|
|
||||||
if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
|
if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
|
||||||
// Stop allocating on GPU(s) once we hit the users target NumGPU
|
// Stop allocating on GPU(s) once we hit the users target NumGPU
|
||||||
@@ -376,7 +376,7 @@ func (m MemoryEstimate) LogValue() slog.Value {
|
|||||||
// memory of the weights
|
// memory of the weights
|
||||||
"total", format.HumanBytes2(m.memoryWeights),
|
"total", format.HumanBytes2(m.memoryWeights),
|
||||||
// memory of repeating layers
|
// memory of repeating layers
|
||||||
"repeating", format.HumanBytes2(m.memoryWeights),
|
"repeating", format.HumanBytes2(m.memoryWeights-m.memoryLayerOutput),
|
||||||
// memory of non-repeating layers
|
// memory of non-repeating layers
|
||||||
"nonrepeating", format.HumanBytes2(m.memoryLayerOutput),
|
"nonrepeating", format.HumanBytes2(m.memoryLayerOutput),
|
||||||
),
|
),
|
||||||
|
|||||||
136
llm/server.go
136
llm/server.go
@@ -402,7 +402,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
|
s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("starting llama server", "cmd", s.cmd)
|
slog.Info("starting llama server", "cmd", s.cmd.String())
|
||||||
if envconfig.Debug() {
|
if envconfig.Debug() {
|
||||||
filteredEnv := []string{}
|
filteredEnv := []string{}
|
||||||
for _, ev := range s.cmd.Env {
|
for _, ev := range s.cmd.Env {
|
||||||
@@ -470,7 +470,7 @@ const ( // iota is reset to 0
|
|||||||
ServerStatusError
|
ServerStatusError
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s ServerStatus) String() string {
|
func (s ServerStatus) ToString() string {
|
||||||
switch s {
|
switch s {
|
||||||
case ServerStatusReady:
|
case ServerStatusReady:
|
||||||
return "llm server ready"
|
return "llm server ready"
|
||||||
@@ -485,9 +485,12 @@ func (s ServerStatus) String() string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type ServerStatusResponse struct {
|
type ServerStatusResp struct {
|
||||||
Status ServerStatus `json:"status"`
|
Status string `json:"status"`
|
||||||
Progress float32 `json:"progress"`
|
SlotsIdle int `json:"slots_idle"`
|
||||||
|
SlotsProcessing int `json:"slots_processing"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
Progress float32 `json:"progress"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||||
@@ -499,7 +502,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
|||||||
}
|
}
|
||||||
if s.cmd.ProcessState.ExitCode() == -1 {
|
if s.cmd.ProcessState.ExitCode() == -1 {
|
||||||
// Most likely a signal killed it, log some more details to try to help troubleshoot
|
// Most likely a signal killed it, log some more details to try to help troubleshoot
|
||||||
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState)
|
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
|
||||||
}
|
}
|
||||||
return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
|
return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
|
||||||
}
|
}
|
||||||
@@ -524,19 +527,21 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
|||||||
return ServerStatusError, fmt.Errorf("read health request: %w", err)
|
return ServerStatusError, fmt.Errorf("read health request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var ssr ServerStatusResponse
|
var status ServerStatusResp
|
||||||
if err := json.Unmarshal(body, &ssr); err != nil {
|
if err := json.Unmarshal(body, &status); err != nil {
|
||||||
return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
|
return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch ssr.Status {
|
switch status.Status {
|
||||||
case ServerStatusLoadingModel:
|
case "ok":
|
||||||
s.loadProgress = ssr.Progress
|
return ServerStatusReady, nil
|
||||||
return ssr.Status, nil
|
case "no slot available":
|
||||||
case ServerStatusReady, ServerStatusNoSlotsAvailable:
|
return ServerStatusNoSlotsAvailable, nil
|
||||||
return ssr.Status, nil
|
case "loading model":
|
||||||
|
s.loadProgress = status.Progress
|
||||||
|
return ServerStatusLoadingModel, nil
|
||||||
default:
|
default:
|
||||||
return ssr.Status, fmt.Errorf("server error: %+v", ssr)
|
return ServerStatusError, fmt.Errorf("server error: %+v", status)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -611,7 +616,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
|||||||
status, _ := s.getServerStatus(ctx)
|
status, _ := s.getServerStatus(ctx)
|
||||||
if lastStatus != status && status != ServerStatusReady {
|
if lastStatus != status && status != ServerStatusReady {
|
||||||
// Only log on status changes
|
// Only log on status changes
|
||||||
slog.Info("waiting for server to become available", "status", status)
|
slog.Info("waiting for server to become available", "status", status.ToString())
|
||||||
}
|
}
|
||||||
switch status {
|
switch status {
|
||||||
case ServerStatusReady:
|
case ServerStatusReady:
|
||||||
@@ -625,7 +630,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
|||||||
slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
|
slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
|
||||||
stallTimer = time.Now().Add(stallDuration)
|
stallTimer = time.Now().Add(stallDuration)
|
||||||
} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
|
} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
|
||||||
slog.Debug("model load completed, waiting for server to become available", "status", status)
|
slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
|
||||||
stallTimer = time.Now().Add(stallDuration)
|
stallTimer = time.Now().Add(stallDuration)
|
||||||
fullyLoaded = true
|
fullyLoaded = true
|
||||||
}
|
}
|
||||||
@@ -666,26 +671,63 @@ type ImageData struct {
|
|||||||
AspectRatioID int `json:"aspect_ratio_id"`
|
AspectRatioID int `json:"aspect_ratio_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type completion struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Stop bool `json:"stop"`
|
||||||
|
StoppedLimit bool `json:"stopped_limit"`
|
||||||
|
|
||||||
|
Timings struct {
|
||||||
|
PredictedN int `json:"predicted_n"`
|
||||||
|
PredictedMS float64 `json:"predicted_ms"`
|
||||||
|
PromptN int `json:"prompt_n"`
|
||||||
|
PromptMS float64 `json:"prompt_ms"`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type CompletionRequest struct {
|
type CompletionRequest struct {
|
||||||
Prompt string
|
Prompt string
|
||||||
Format json.RawMessage
|
Format json.RawMessage
|
||||||
Images []ImageData
|
Images []ImageData
|
||||||
Options *api.Options
|
Options *api.Options
|
||||||
|
|
||||||
Grammar string // set before sending the request to the subprocess
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CompletionResponse struct {
|
type CompletionResponse struct {
|
||||||
Content string `json:"content"`
|
Content string
|
||||||
DoneReason string `json:"done_reason"`
|
DoneReason string
|
||||||
Done bool `json:"done"`
|
Done bool
|
||||||
PromptEvalCount int `json:"prompt_eval_count"`
|
PromptEvalCount int
|
||||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
PromptEvalDuration time.Duration
|
||||||
EvalCount int `json:"eval_count"`
|
EvalCount int
|
||||||
EvalDuration time.Duration `json:"eval_duration"`
|
EvalDuration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||||
|
request := map[string]any{
|
||||||
|
"prompt": req.Prompt,
|
||||||
|
"stream": true,
|
||||||
|
"n_predict": req.Options.NumPredict,
|
||||||
|
"n_keep": req.Options.NumKeep,
|
||||||
|
"main_gpu": req.Options.MainGPU,
|
||||||
|
"temperature": req.Options.Temperature,
|
||||||
|
"top_k": req.Options.TopK,
|
||||||
|
"top_p": req.Options.TopP,
|
||||||
|
"min_p": req.Options.MinP,
|
||||||
|
"typical_p": req.Options.TypicalP,
|
||||||
|
"repeat_last_n": req.Options.RepeatLastN,
|
||||||
|
"repeat_penalty": req.Options.RepeatPenalty,
|
||||||
|
"presence_penalty": req.Options.PresencePenalty,
|
||||||
|
"frequency_penalty": req.Options.FrequencyPenalty,
|
||||||
|
"mirostat": req.Options.Mirostat,
|
||||||
|
"mirostat_tau": req.Options.MirostatTau,
|
||||||
|
"mirostat_eta": req.Options.MirostatEta,
|
||||||
|
"seed": req.Options.Seed,
|
||||||
|
"stop": req.Options.Stop,
|
||||||
|
"image_data": req.Images,
|
||||||
|
"cache_prompt": true,
|
||||||
|
}
|
||||||
|
|
||||||
if len(req.Format) > 0 {
|
if len(req.Format) > 0 {
|
||||||
switch string(req.Format) {
|
switch string(req.Format) {
|
||||||
case `null`, `""`:
|
case `null`, `""`:
|
||||||
@@ -693,7 +735,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
// these as "not set".
|
// these as "not set".
|
||||||
break
|
break
|
||||||
case `"json"`:
|
case `"json"`:
|
||||||
req.Grammar = grammarJSON
|
request["grammar"] = grammarJSON
|
||||||
default:
|
default:
|
||||||
if req.Format[0] != '{' {
|
if req.Format[0] != '{' {
|
||||||
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
|
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
|
||||||
@@ -704,15 +746,10 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
if g == nil {
|
if g == nil {
|
||||||
return fmt.Errorf("invalid JSON schema in format")
|
return fmt.Errorf("invalid JSON schema in format")
|
||||||
}
|
}
|
||||||
req.Grammar = string(g)
|
request["grammar"] = string(g)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Options == nil {
|
|
||||||
opts := api.DefaultOptions()
|
|
||||||
req.Options = &opts
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
slog.Info("aborting completion request due to client closing the connection")
|
slog.Info("aborting completion request due to client closing the connection")
|
||||||
@@ -733,7 +770,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if status != ServerStatusReady {
|
} else if status != ServerStatusReady {
|
||||||
return fmt.Errorf("unexpected server status: %s", status)
|
return fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handling JSON marshaling with special characters unescaped.
|
// Handling JSON marshaling with special characters unescaped.
|
||||||
@@ -741,7 +778,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
enc := json.NewEncoder(buffer)
|
enc := json.NewEncoder(buffer)
|
||||||
enc.SetEscapeHTML(false)
|
enc.SetEscapeHTML(false)
|
||||||
|
|
||||||
if err := enc.Encode(req); err != nil {
|
if err := enc.Encode(request); err != nil {
|
||||||
return fmt.Errorf("failed to marshal data: %v", err)
|
return fmt.Errorf("failed to marshal data: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -792,7 +829,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
evt = line
|
evt = line
|
||||||
}
|
}
|
||||||
|
|
||||||
var c CompletionResponse
|
var c completion
|
||||||
if err := json.Unmarshal(evt, &c); err != nil {
|
if err := json.Unmarshal(evt, &c); err != nil {
|
||||||
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
|
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
|
||||||
}
|
}
|
||||||
@@ -816,8 +853,20 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.Done {
|
if c.Stop {
|
||||||
fn(c)
|
doneReason := "stop"
|
||||||
|
if c.StoppedLimit {
|
||||||
|
doneReason = "length"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn(CompletionResponse{
|
||||||
|
Done: true,
|
||||||
|
DoneReason: doneReason,
|
||||||
|
PromptEvalCount: c.Timings.PromptN,
|
||||||
|
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
|
||||||
|
EvalCount: c.Timings.PredictedN,
|
||||||
|
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
|
||||||
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -865,7 +914,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if status != ServerStatusReady {
|
} else if status != ServerStatusReady {
|
||||||
return nil, fmt.Errorf("unexpected server status: %s", status)
|
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(EmbeddingRequest{Content: input})
|
data, err := json.Marshal(EmbeddingRequest{Content: input})
|
||||||
@@ -1010,3 +1059,12 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
|
|||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseDurationMs(ms float64) time.Duration {
|
||||||
|
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return dur
|
||||||
|
}
|
||||||
|
|||||||
40
logging/log.go
Normal file
40
logging/log.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package logging
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
const LevelTrace slog.Level = slog.LevelDebug - 4
|
||||||
|
|
||||||
|
type Logger struct {
|
||||||
|
logger *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLogger() *Logger {
|
||||||
|
handler := slog.NewTextHandler(os.Stdout, nil)
|
||||||
|
return &Logger{
|
||||||
|
logger: slog.New(handler),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Trace(msg string, args ...any) {
|
||||||
|
l.logger.Log(context.Background(), LevelTrace, msg, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Debug(msg string, args ...any) {
|
||||||
|
l.logger.Debug(msg, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Info(msg string, args ...any) {
|
||||||
|
l.logger.Info(msg, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Warn(msg string, args ...any) {
|
||||||
|
l.logger.Warn(msg, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Error(msg string, args ...any) {
|
||||||
|
l.logger.Error(msg, args...)
|
||||||
|
}
|
||||||
@@ -312,19 +312,17 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
|||||||
return fmt.Errorf("unassigned tensor: %s", t.Name)
|
return fmt.Errorf("unassigned tensor: %s", t.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
bts := C.malloc(C.size_t(t.Size()))
|
bts := make([]byte, t.Size())
|
||||||
if bts == nil {
|
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
|
||||||
return errors.New("failed to allocate tensor buffer")
|
if err != nil {
|
||||||
}
|
return err
|
||||||
defer C.free(bts)
|
|
||||||
|
|
||||||
buf := unsafe.Slice((*byte)(bts), t.Size())
|
|
||||||
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf)
|
|
||||||
if err != nil || n != len(buf) {
|
|
||||||
return errors.New("read failed")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size()))
|
if n != len(bts) {
|
||||||
|
return errors.New("short read")
|
||||||
|
}
|
||||||
|
|
||||||
|
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -373,7 +371,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
|||||||
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
|
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
|
||||||
C.int(len(schedBackends)),
|
C.int(len(schedBackends)),
|
||||||
C.size_t(maxGraphNodes),
|
C.size_t(maxGraphNodes),
|
||||||
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
|
true,
|
||||||
),
|
),
|
||||||
input: deviceBufferTypes[input.d],
|
input: deviceBufferTypes[input.d],
|
||||||
output: deviceBufferTypes[output.d],
|
output: deviceBufferTypes[output.d],
|
||||||
|
|||||||
@@ -15,12 +15,6 @@ type Input struct {
|
|||||||
// stored in Multimodal, used for caching and comparing
|
// stored in Multimodal, used for caching and comparing
|
||||||
// equality.
|
// equality.
|
||||||
MultimodalHash uint64
|
MultimodalHash uint64
|
||||||
|
|
||||||
// SameBatch forces the following number of tokens to be processed
|
|
||||||
// in a single batch, breaking and extending batches as needed.
|
|
||||||
// Useful for things like images that must be processed in one
|
|
||||||
// shot.
|
|
||||||
SameBatch int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MultimodalIndex is a multimodal element (such as an image)
|
// MultimodalIndex is a multimodal element (such as an image)
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ type MultimodalProcessor interface {
|
|||||||
// This function is also responsible for updating MultimodalHash for any Multimodal
|
// This function is also responsible for updating MultimodalHash for any Multimodal
|
||||||
// that is modified to ensure that there is a unique hash value that accurately
|
// that is modified to ensure that there is a unique hash value that accurately
|
||||||
// represents the contents.
|
// represents the contents.
|
||||||
PostTokenize([]input.Input) ([]input.Input, error)
|
PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Base implements the common fields and methods for all models
|
// Base implements the common fields and methods for all models
|
||||||
|
|||||||
@@ -2,9 +2,10 @@ package gemma3
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"hash/fnv"
|
||||||
"image"
|
"image"
|
||||||
"math"
|
"math"
|
||||||
"slices"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/kvcache"
|
"github.com/ollama/ollama/kvcache"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
@@ -111,23 +112,36 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
|||||||
return visionOutputs, nil
|
return visionOutputs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
type imageToken struct {
|
||||||
|
embedding ml.Tensor
|
||||||
|
index int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
|
||||||
var result []input.Input
|
var result []input.Input
|
||||||
|
fnvHash := fnv.New64a()
|
||||||
|
|
||||||
for _, inp := range inputs {
|
for _, inp := range inputs {
|
||||||
if inp.Multimodal == nil {
|
if inp.Multimodal == nil {
|
||||||
result = append(result, inp)
|
result = append(result, inp)
|
||||||
} else {
|
} else {
|
||||||
|
imageInputs := []input.Input{
|
||||||
|
{Token: 108}, // "\n\n"
|
||||||
|
{Token: 255999}, // "<start_of_image>""
|
||||||
|
}
|
||||||
|
result = append(result, imageInputs...)
|
||||||
|
|
||||||
|
// add image embeddings
|
||||||
inputMultimodal := inp.Multimodal.(ml.Tensor)
|
inputMultimodal := inp.Multimodal.(ml.Tensor)
|
||||||
|
|
||||||
result = append(result,
|
for i := range inputMultimodal.Dim(1) {
|
||||||
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
|
fnvHash.Reset()
|
||||||
input.Input{Token: 255999}, // "<start_of_image>""
|
binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
|
||||||
input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
|
fnvHash.Write([]byte{byte(i)})
|
||||||
)
|
|
||||||
|
|
||||||
// add image token placeholders
|
imageToken := imageToken{embedding: inputMultimodal, index: i}
|
||||||
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
|
result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
|
||||||
|
}
|
||||||
|
|
||||||
result = append(result,
|
result = append(result,
|
||||||
input.Input{Token: 256000}, // <end_of_image>
|
input.Input{Token: 256000}, // <end_of_image>
|
||||||
|
|||||||
@@ -171,20 +171,53 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
|
|||||||
return hiddenState.Add(ctx, residual)
|
return hiddenState.Add(ctx, residual)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
|
||||||
|
var embedding ml.Tensor
|
||||||
|
var src, dst, length int
|
||||||
|
var except []int
|
||||||
|
|
||||||
|
for _, image := range multimodal {
|
||||||
|
imageToken := image.Multimodal.(imageToken)
|
||||||
|
imageSrc := imageToken.index
|
||||||
|
imageDst := image.Index
|
||||||
|
|
||||||
|
if embedding == nil {
|
||||||
|
embedding = imageToken.embedding
|
||||||
|
src = imageSrc
|
||||||
|
dst = imageDst
|
||||||
|
length = 1
|
||||||
|
} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
|
||||||
|
src = imageSrc
|
||||||
|
dst = imageDst
|
||||||
|
length++
|
||||||
|
} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
|
||||||
|
length++
|
||||||
|
} else {
|
||||||
|
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
|
||||||
|
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
|
||||||
|
|
||||||
|
embedding = imageToken.embedding
|
||||||
|
src = imageSrc
|
||||||
|
dst = imageDst
|
||||||
|
length = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
except = append(except, imageDst)
|
||||||
|
}
|
||||||
|
|
||||||
|
if embedding != nil {
|
||||||
|
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
|
||||||
|
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
|
||||||
|
}
|
||||||
|
|
||||||
|
return except
|
||||||
|
}
|
||||||
|
|
||||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
|
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
|
||||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
|
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
|
||||||
|
|
||||||
// set image embeddings
|
except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
|
||||||
var except []int
|
|
||||||
for _, image := range opts.Multimodal {
|
|
||||||
visionOutputs := image.Multimodal.(ml.Tensor)
|
|
||||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
|
||||||
|
|
||||||
for i := range visionOutputs.Dim(1) {
|
|
||||||
except = append(except, image.Index+i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, layer := range m.Layers {
|
for i, layer := range m.Layers {
|
||||||
// gemma alternates between the sliding window (local) and causal (global)
|
// gemma alternates between the sliding window (local) and causal (global)
|
||||||
|
|||||||
@@ -106,17 +106,17 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
|||||||
return m.Projector.Forward(ctx, crossAttentionStates), nil
|
return m.Projector.Forward(ctx, crossAttentionStates), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
|
||||||
var images []input.Input
|
var images []input.Input
|
||||||
fnvHash := fnv.New64a()
|
fnvHash := fnv.New64a()
|
||||||
|
|
||||||
for i := range inputs {
|
for i := range inputs {
|
||||||
if inputs[i].Multimodal == nil {
|
if inputs[i].Multimodal == nil {
|
||||||
if len(images) > 0 {
|
if len(images) > 0 {
|
||||||
inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)}
|
inputs[i].Multimodal = images[0].Multimodal
|
||||||
inputs[i].MultimodalHash = images[0].MultimodalHash
|
inputs[i].MultimodalHash = images[0].MultimodalHash
|
||||||
for j := 1; j < len(images); j++ {
|
for j := 1; j < len(images); j++ {
|
||||||
inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[0].Multimodal.(ml.Tensor))
|
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
|
||||||
fnvHash.Reset()
|
fnvHash.Reset()
|
||||||
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
|
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
|
||||||
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
|
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
|
||||||
@@ -138,10 +138,7 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
||||||
var crossAttentionStates ml.Tensor
|
var crossAttentionStates ml.Tensor
|
||||||
if len(opts.Multimodal) > 0 {
|
if len(opts.Multimodal) > 0 {
|
||||||
images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
|
crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
|
||||||
if len(images) > 0 {
|
|
||||||
crossAttentionStates = images[len(images)-1]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
||||||
|
|||||||
@@ -2,15 +2,18 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"iter"
|
"iter"
|
||||||
"log/slog"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/dlclark/regexp2"
|
"github.com/dlclark/regexp2"
|
||||||
queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
|
queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
const spmWhitespaceSep = "▁"
|
const spmWhitespaceSep = "▁"
|
||||||
|
|
||||||
|
var log = logging.NewLogger()
|
||||||
|
|
||||||
func replaceWhitespaceBySeperator(s string) string {
|
func replaceWhitespaceBySeperator(s string) string {
|
||||||
return strings.ReplaceAll(s, " ", spmWhitespaceSep)
|
return strings.ReplaceAll(s, " ", spmWhitespaceSep)
|
||||||
}
|
}
|
||||||
@@ -24,7 +27,7 @@ type SentencePieceModel struct {
|
|||||||
var _ TextProcessor = (*SentencePieceModel)(nil)
|
var _ TextProcessor = (*SentencePieceModel)(nil)
|
||||||
|
|
||||||
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
|
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
|
||||||
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
log.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||||
|
|
||||||
counter := map[int]int{}
|
counter := map[int]int{}
|
||||||
var maxTokenLen int
|
var maxTokenLen int
|
||||||
@@ -38,7 +41,7 @@ func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
|
log.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
|
||||||
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
||||||
"max token len", maxTokenLen)
|
"max token len", maxTokenLen)
|
||||||
|
|
||||||
@@ -91,7 +94,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
|||||||
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
slog.Debug("fragments", "frags", fragments)
|
log.Trace("fragments", "frags", fragments)
|
||||||
|
|
||||||
var ids []int32
|
var ids []int32
|
||||||
for _, frag := range fragments {
|
for _, frag := range fragments {
|
||||||
@@ -129,7 +132,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("tokenizer", "merges", merges)
|
log.Trace("tokenizer", "merges", merges)
|
||||||
|
|
||||||
pairwise := func(a, b int) *candidate {
|
pairwise := func(a, b int) *candidate {
|
||||||
if a < 0 || b >= len(runes) {
|
if a < 0 || b >= len(runes) {
|
||||||
@@ -156,7 +159,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
|||||||
pqv := pq.Values()
|
pqv := pq.Values()
|
||||||
for _, v := range pqv {
|
for _, v := range pqv {
|
||||||
e := v.(*candidate)
|
e := v.(*candidate)
|
||||||
slog.Debug("candidate", "candidate", e)
|
log.Trace("candidate", "candidate", e)
|
||||||
}
|
}
|
||||||
|
|
||||||
for !pq.Empty() {
|
for !pq.Empty() {
|
||||||
@@ -164,7 +167,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
|||||||
pair := v.(*candidate)
|
pair := v.(*candidate)
|
||||||
left, right := merges[pair.a], merges[pair.b]
|
left, right := merges[pair.a], merges[pair.b]
|
||||||
|
|
||||||
slog.Debug("pair", "left", left, "right", right)
|
log.Trace("pair", "left", left, "right", right)
|
||||||
if len(left.runes) == 0 || len(right.runes) == 0 {
|
if len(left.runes) == 0 || len(right.runes) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -189,14 +192,14 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("merges", "merges", merges)
|
log.Trace("merges", "merges", merges)
|
||||||
|
|
||||||
for _, merge := range merges {
|
for _, merge := range merges {
|
||||||
if len(merge.runes) > 0 {
|
if len(merge.runes) > 0 {
|
||||||
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
|
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
|
||||||
ids = append(ids, id)
|
ids = append(ids, id)
|
||||||
} else {
|
} else {
|
||||||
slog.Debug("missing token", "token", string(merge.runes))
|
log.Error("missing token", "token", string(merge.runes))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -206,19 +209,19 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
|||||||
if addSpecial && len(ids) > 0 {
|
if addSpecial && len(ids) > 0 {
|
||||||
if spm.vocab.AddBOS {
|
if spm.vocab.AddBOS {
|
||||||
if ids[0] == spm.vocab.BOS {
|
if ids[0] == spm.vocab.BOS {
|
||||||
slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
|
log.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
|
log.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
|
||||||
ids = append([]int32{spm.vocab.BOS}, ids...)
|
ids = append([]int32{spm.vocab.BOS}, ids...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if spm.vocab.AddEOS {
|
if spm.vocab.AddEOS {
|
||||||
if ids[len(ids)-1] == spm.vocab.EOS {
|
if ids[len(ids)-1] == spm.vocab.EOS {
|
||||||
slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
|
log.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
|
log.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
|
||||||
ids = append(ids, spm.vocab.EOS)
|
ids = append(ids, spm.vocab.EOS)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -241,6 +244,6 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("decoded", "ids", ids, "text", sb.String())
|
log.Debug("decoded", "ids", ids, "text", sb.String())
|
||||||
return sb.String(), nil
|
return sb.String(), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llama"
|
"github.com/ollama/ollama/llama"
|
||||||
"github.com/ollama/ollama/llm"
|
|
||||||
"github.com/ollama/ollama/runner/common"
|
"github.com/ollama/ollama/runner/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -100,7 +99,7 @@ type NewSequenceParams struct {
|
|||||||
embedding bool
|
embedding bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||||
s.ready.Wait()
|
s.ready.Wait()
|
||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
@@ -164,7 +163,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
|||||||
// inputs processes the prompt and images into a list of inputs
|
// inputs processes the prompt and images into a list of inputs
|
||||||
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
||||||
// generating image embeddings for each image
|
// generating image embeddings for each image
|
||||||
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) {
|
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
|
||||||
var inputs []input
|
var inputs []input
|
||||||
var parts []string
|
var parts []string
|
||||||
var matches [][]string
|
var matches [][]string
|
||||||
@@ -230,7 +229,7 @@ type Server struct {
|
|||||||
image *ImageContext
|
image *ImageContext
|
||||||
|
|
||||||
// status for external health reporting - loading, ready to serve, etc.
|
// status for external health reporting - loading, ready to serve, etc.
|
||||||
status llm.ServerStatus
|
status ServerStatus
|
||||||
|
|
||||||
// current progress on loading the model
|
// current progress on loading the model
|
||||||
progress float32
|
progress float32
|
||||||
@@ -542,18 +541,75 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO (jmorganca): use structs from the api package to avoid duplication
|
||||||
|
// this way the api acts as a proxy instead of using a different api for the
|
||||||
|
// runner
|
||||||
|
type Options struct {
|
||||||
|
api.Runner
|
||||||
|
|
||||||
|
NumKeep int `json:"n_keep"`
|
||||||
|
Seed int `json:"seed"`
|
||||||
|
NumPredict int `json:"n_predict"`
|
||||||
|
TopK int `json:"top_k"`
|
||||||
|
TopP float32 `json:"top_p"`
|
||||||
|
MinP float32 `json:"min_p"`
|
||||||
|
TypicalP float32 `json:"typical_p"`
|
||||||
|
RepeatLastN int `json:"repeat_last_n"`
|
||||||
|
Temperature float32 `json:"temperature"`
|
||||||
|
RepeatPenalty float32 `json:"repeat_penalty"`
|
||||||
|
PresencePenalty float32 `json:"presence_penalty"`
|
||||||
|
FrequencyPenalty float32 `json:"frequency_penalty"`
|
||||||
|
Mirostat int `json:"mirostat"`
|
||||||
|
MirostatTau float32 `json:"mirostat_tau"`
|
||||||
|
MirostatEta float32 `json:"mirostat_eta"`
|
||||||
|
Stop []string `json:"stop"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageData struct {
|
||||||
|
Data []byte `json:"data"`
|
||||||
|
ID int `json:"id"`
|
||||||
|
AspectRatioID int `json:"aspect_ratio_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompletionRequest struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Images []ImageData `json:"image_data"`
|
||||||
|
Grammar string `json:"grammar"`
|
||||||
|
CachePrompt bool `json:"cache_prompt"`
|
||||||
|
|
||||||
|
Options
|
||||||
|
}
|
||||||
|
|
||||||
|
type Timings struct {
|
||||||
|
PredictedN int `json:"predicted_n"`
|
||||||
|
PredictedMS float64 `json:"predicted_ms"`
|
||||||
|
PromptN int `json:"prompt_n"`
|
||||||
|
PromptMS float64 `json:"prompt_ms"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompletionResponse struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
Stop bool `json:"stop"`
|
||||||
|
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
StoppedLimit bool `json:"stopped_limit,omitempty"`
|
||||||
|
PredictedN int `json:"predicted_n,omitempty"`
|
||||||
|
PredictedMS float64 `json:"predicted_ms,omitempty"`
|
||||||
|
PromptN int `json:"prompt_n,omitempty"`
|
||||||
|
PromptMS float64 `json:"prompt_ms,omitempty"`
|
||||||
|
|
||||||
|
Timings Timings `json:"timings"`
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
var req llm.CompletionRequest
|
var req CompletionRequest
|
||||||
|
req.Options = Options(api.DefaultOptions())
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Options == nil {
|
|
||||||
opts := api.DefaultOptions()
|
|
||||||
req.Options = &opts
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the headers to indicate streaming
|
// Set the headers to indicate streaming
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Header().Set("Transfer-Encoding", "chunked")
|
w.Header().Set("Transfer-Encoding", "chunked")
|
||||||
@@ -564,28 +620,26 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract options from the CompletionRequest
|
var samplingParams llama.SamplingParams
|
||||||
samplingParams := llama.SamplingParams{
|
samplingParams.TopK = req.TopK
|
||||||
TopK: req.Options.TopK,
|
samplingParams.TopP = req.TopP
|
||||||
TopP: req.Options.TopP,
|
samplingParams.MinP = req.MinP
|
||||||
MinP: req.Options.MinP,
|
samplingParams.TypicalP = req.TypicalP
|
||||||
TypicalP: req.Options.TypicalP,
|
samplingParams.Temp = req.Temperature
|
||||||
Temp: req.Options.Temperature,
|
samplingParams.RepeatLastN = req.RepeatLastN
|
||||||
RepeatLastN: req.Options.RepeatLastN,
|
samplingParams.PenaltyRepeat = req.RepeatPenalty
|
||||||
PenaltyRepeat: req.Options.RepeatPenalty,
|
samplingParams.PenaltyFreq = req.FrequencyPenalty
|
||||||
PenaltyFreq: req.Options.FrequencyPenalty,
|
samplingParams.PenaltyPresent = req.PresencePenalty
|
||||||
PenaltyPresent: req.Options.PresencePenalty,
|
samplingParams.Mirostat = req.Mirostat
|
||||||
Mirostat: req.Options.Mirostat,
|
samplingParams.MirostatTau = req.MirostatTau
|
||||||
MirostatTau: req.Options.MirostatTau,
|
samplingParams.MirostatEta = req.MirostatEta
|
||||||
MirostatEta: req.Options.MirostatEta,
|
samplingParams.Seed = uint32(req.Seed)
|
||||||
Seed: uint32(req.Options.Seed),
|
samplingParams.Grammar = req.Grammar
|
||||||
Grammar: req.Grammar,
|
|
||||||
}
|
|
||||||
|
|
||||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
numPredict: req.Options.NumPredict,
|
numPredict: req.NumPredict,
|
||||||
stop: req.Options.Stop,
|
stop: req.Stop,
|
||||||
numKeep: req.Options.NumKeep,
|
numKeep: req.NumKeep,
|
||||||
samplingParams: &samplingParams,
|
samplingParams: &samplingParams,
|
||||||
embedding: false,
|
embedding: false,
|
||||||
})
|
})
|
||||||
@@ -608,7 +662,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
found := false
|
found := false
|
||||||
for i, sq := range s.seqs {
|
for i, sq := range s.seqs {
|
||||||
if sq == nil {
|
if sq == nil {
|
||||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
|
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||||
@@ -637,7 +691,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
case content, ok := <-seq.responses:
|
case content, ok := <-seq.responses:
|
||||||
if ok {
|
if ok {
|
||||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||||
Content: content,
|
Content: content,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
@@ -648,17 +702,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
} else {
|
} else {
|
||||||
// Send the final response
|
// Send the final response
|
||||||
doneReason := "stop"
|
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||||
if seq.doneReason == "limit" {
|
Stop: true,
|
||||||
doneReason = "length"
|
StoppedLimit: seq.doneReason == "limit",
|
||||||
}
|
Timings: Timings{
|
||||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
PromptN: seq.numPromptInputs,
|
||||||
Done: true,
|
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
|
||||||
DoneReason: doneReason,
|
PredictedN: seq.numDecoded,
|
||||||
PromptEvalCount: seq.numPromptInputs,
|
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
|
||||||
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
|
},
|
||||||
EvalCount: seq.numDecoded,
|
|
||||||
EvalDuration: time.Since(seq.startGenerationTime),
|
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@@ -669,8 +721,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmbeddingRequest struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
CachePrompt bool `json:"cache_prompt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingResponse struct {
|
||||||
|
Embedding []float32 `json:"embedding"`
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||||
var req llm.EmbeddingRequest
|
var req EmbeddingRequest
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
|
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
@@ -700,7 +761,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
|||||||
found := false
|
found := false
|
||||||
for i, sq := range s.seqs {
|
for i, sq := range s.seqs {
|
||||||
if sq == nil {
|
if sq == nil {
|
||||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
|
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||||
@@ -721,17 +782,41 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
embedding := <-seq.embedding
|
embedding := <-seq.embedding
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
|
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
|
||||||
Embedding: embedding,
|
Embedding: embedding,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type HealthResponse struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Progress float32 `json:"progress"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerStatus int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ServerStatusReady ServerStatus = iota
|
||||||
|
ServerStatusLoadingModel
|
||||||
|
ServerStatusError
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s ServerStatus) ToString() string {
|
||||||
|
switch s {
|
||||||
|
case ServerStatusReady:
|
||||||
|
return "ok"
|
||||||
|
case ServerStatusLoadingModel:
|
||||||
|
return "loading model"
|
||||||
|
default:
|
||||||
|
return "server error"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
|
if err := json.NewEncoder(w).Encode(&HealthResponse{
|
||||||
Status: s.status,
|
Status: s.status.ToString(),
|
||||||
Progress: s.progress,
|
Progress: s.progress,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
@@ -794,7 +879,7 @@ func (s *Server) loadModel(
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.status = llm.ServerStatusReady
|
s.status = ServerStatusReady
|
||||||
s.ready.Done()
|
s.ready.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -852,7 +937,7 @@ func Execute(args []string) error {
|
|||||||
parallel: *parallel,
|
parallel: *parallel,
|
||||||
seqs: make([]*Sequence, *parallel),
|
seqs: make([]*Sequence, *parallel),
|
||||||
seqsSem: semaphore.NewWeighted(int64(*parallel)),
|
seqsSem: semaphore.NewWeighted(int64(*parallel)),
|
||||||
status: llm.ServerStatusLoadingModel,
|
status: ServerStatusLoadingModel,
|
||||||
}
|
}
|
||||||
|
|
||||||
var tensorSplitFloats []float32
|
var tensorSplitFloats []float32
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ type InputCacheSlot struct {
|
|||||||
lastUsed time.Time
|
lastUsed time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
|
func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
|
||||||
var slot *InputCacheSlot
|
var slot *InputCacheSlot
|
||||||
var numPast int32
|
var numPast int32
|
||||||
var err error
|
var err error
|
||||||
@@ -107,6 +107,10 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !cachePrompt {
|
||||||
|
numPast = 0
|
||||||
|
}
|
||||||
|
|
||||||
slot.InUse = true
|
slot.InUse = true
|
||||||
slot.lastUsed = time.Now()
|
slot.lastUsed = time.Now()
|
||||||
|
|
||||||
|
|||||||
@@ -297,131 +297,3 @@ func TestShiftDiscard(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadCacheSlot(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
cache InputCache
|
|
||||||
prompt []input.Input
|
|
||||||
wantErr bool
|
|
||||||
expectedSlotId int
|
|
||||||
expectedPrompt int // expected length of remaining prompt
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Basic cache hit - single user",
|
|
||||||
cache: InputCache{
|
|
||||||
multiUserCache: false,
|
|
||||||
slots: []InputCacheSlot{
|
|
||||||
{
|
|
||||||
Id: 0,
|
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
|
||||||
InUse: false,
|
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: 1,
|
|
||||||
Inputs: []input.Input{},
|
|
||||||
InUse: false,
|
|
||||||
lastUsed: time.Now().Add(-2 * time.Second),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
|
||||||
wantErr: false,
|
|
||||||
expectedSlotId: 0,
|
|
||||||
expectedPrompt: 1, // Only token 3 remains
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Basic cache hit - multi user",
|
|
||||||
cache: InputCache{
|
|
||||||
multiUserCache: true,
|
|
||||||
slots: []InputCacheSlot{
|
|
||||||
{
|
|
||||||
Id: 0,
|
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
|
||||||
InUse: false,
|
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: 1,
|
|
||||||
Inputs: []input.Input{},
|
|
||||||
InUse: false,
|
|
||||||
lastUsed: time.Now().Add(-2 * time.Second),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
|
||||||
wantErr: false,
|
|
||||||
expectedSlotId: 0,
|
|
||||||
expectedPrompt: 1, // Only token 3 remains
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Exact match - leave one input",
|
|
||||||
cache: InputCache{
|
|
||||||
multiUserCache: false,
|
|
||||||
slots: []InputCacheSlot{
|
|
||||||
{
|
|
||||||
Id: 0,
|
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
|
||||||
InUse: false,
|
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
|
||||||
wantErr: false,
|
|
||||||
expectedSlotId: 0,
|
|
||||||
expectedPrompt: 1, // Should leave 1 token for sampling
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No available slots",
|
|
||||||
cache: InputCache{
|
|
||||||
multiUserCache: false,
|
|
||||||
slots: []InputCacheSlot{
|
|
||||||
{
|
|
||||||
Id: 0,
|
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
|
||||||
InUse: true,
|
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
|
||||||
wantErr: true,
|
|
||||||
expectedSlotId: -1,
|
|
||||||
expectedPrompt: -1,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
|
|
||||||
|
|
||||||
// Check error state
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.wantErr {
|
|
||||||
return // Skip further checks if we expected an error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify slot ID
|
|
||||||
if slot.Id != tt.expectedSlotId {
|
|
||||||
t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify slot is now marked in use
|
|
||||||
if !slot.InUse {
|
|
||||||
t.Errorf("LoadCacheSlot() slot not marked InUse")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify remaining prompt length
|
|
||||||
if len(remainingPrompt) != tt.expectedPrompt {
|
|
||||||
t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
|
|
||||||
len(remainingPrompt), tt.expectedPrompt)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ import (
|
|||||||
"golang.org/x/sync/semaphore"
|
"golang.org/x/sync/semaphore"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llm"
|
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
"github.com/ollama/ollama/model/input"
|
"github.com/ollama/ollama/model/input"
|
||||||
@@ -34,14 +33,10 @@ import (
|
|||||||
_ "github.com/ollama/ollama/model/models"
|
_ "github.com/ollama/ollama/model/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
type contextList struct {
|
|
||||||
list []ml.Context
|
|
||||||
}
|
|
||||||
|
|
||||||
type Sequence struct {
|
type Sequence struct {
|
||||||
// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
|
// ctx for allocating tensors that last the lifetime of the sequence, such as
|
||||||
// multimodal embeddings
|
// multimodal embeddings
|
||||||
ctxs *contextList
|
ctx ml.Context
|
||||||
|
|
||||||
// batch index
|
// batch index
|
||||||
iBatch int
|
iBatch int
|
||||||
@@ -99,12 +94,13 @@ type NewSequenceParams struct {
|
|||||||
embedding bool
|
embedding bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||||
s.ready.Wait()
|
s.ready.Wait()
|
||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
ctx := s.model.Backend().NewContext()
|
||||||
|
|
||||||
inputs, ctxs, err := s.inputs(prompt, images)
|
inputs, err := s.inputs(ctx, prompt, images)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to process inputs: %w", err)
|
return nil, fmt.Errorf("failed to process inputs: %w", err)
|
||||||
} else if len(inputs) == 0 {
|
} else if len(inputs) == 0 {
|
||||||
@@ -115,9 +111,6 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
|||||||
params.numKeep = int32(len(inputs))
|
params.numKeep = int32(len(inputs))
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
|
|
||||||
// otherwise we might truncate or split the batch against the model's wishes
|
|
||||||
|
|
||||||
// Ensure that at least 1 input can be discarded during shift
|
// Ensure that at least 1 input can be discarded during shift
|
||||||
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||||
|
|
||||||
@@ -133,7 +126,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
|||||||
// TODO(jessegross): Ingest cached history for grammar
|
// TODO(jessegross): Ingest cached history for grammar
|
||||||
|
|
||||||
return &Sequence{
|
return &Sequence{
|
||||||
ctxs: ctxs,
|
ctx: ctx,
|
||||||
inputs: inputs,
|
inputs: inputs,
|
||||||
numPromptInputs: len(inputs),
|
numPromptInputs: len(inputs),
|
||||||
startProcessingTime: startTime,
|
startProcessingTime: startTime,
|
||||||
@@ -152,7 +145,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
|||||||
// inputs processes the prompt and images into a list of inputs
|
// inputs processes the prompt and images into a list of inputs
|
||||||
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
||||||
// decoding images
|
// decoding images
|
||||||
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) {
|
func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
|
||||||
var inputs []input.Input
|
var inputs []input.Input
|
||||||
var parts []string
|
var parts []string
|
||||||
var matches [][]string
|
var matches [][]string
|
||||||
@@ -167,19 +160,12 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
|
|||||||
parts = []string{prompt}
|
parts = []string{prompt}
|
||||||
}
|
}
|
||||||
|
|
||||||
var contexts contextList
|
|
||||||
runtime.AddCleanup(&contexts, func(ctxs []ml.Context) {
|
|
||||||
for _, ctx := range ctxs {
|
|
||||||
ctx.Close()
|
|
||||||
}
|
|
||||||
}, contexts.list)
|
|
||||||
|
|
||||||
postTokenize := false
|
postTokenize := false
|
||||||
for i, part := range parts {
|
for i, part := range parts {
|
||||||
// text - tokenize
|
// text - tokenize
|
||||||
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
|
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, t := range tokens {
|
for _, t := range tokens {
|
||||||
@@ -199,14 +185,12 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
|
|||||||
}
|
}
|
||||||
|
|
||||||
if imageIndex < 0 {
|
if imageIndex < 0 {
|
||||||
return nil, nil, fmt.Errorf("invalid image index: %d", n)
|
return nil, fmt.Errorf("invalid image index: %d", n)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := s.model.Backend().NewContext()
|
|
||||||
contexts.list = append(contexts.list, ctx)
|
|
||||||
imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
|
imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.multimodalHash.Reset()
|
s.multimodalHash.Reset()
|
||||||
@@ -220,13 +204,13 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
|
|||||||
|
|
||||||
if visionModel && postTokenize {
|
if visionModel && postTokenize {
|
||||||
var err error
|
var err error
|
||||||
inputs, err = multimodalProcessor.PostTokenize(inputs)
|
inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return inputs, &contexts, nil
|
return inputs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
@@ -238,7 +222,7 @@ type Server struct {
|
|||||||
model model.Model
|
model model.Model
|
||||||
|
|
||||||
// status for external health reporting - loading, ready to serve, etc.
|
// status for external health reporting - loading, ready to serve, etc.
|
||||||
status llm.ServerStatus
|
status ServerStatus
|
||||||
|
|
||||||
// current progress on loading the model
|
// current progress on loading the model
|
||||||
progress float32
|
progress float32
|
||||||
@@ -321,6 +305,7 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
|
|||||||
close(seq.responses)
|
close(seq.responses)
|
||||||
close(seq.embedding)
|
close(seq.embedding)
|
||||||
seq.cache.InUse = false
|
seq.cache.InUse = false
|
||||||
|
seq.ctx.Close()
|
||||||
s.seqs[seqIndex] = nil
|
s.seqs[seqIndex] = nil
|
||||||
s.seqsSem.Release(1)
|
s.seqsSem.Release(1)
|
||||||
}
|
}
|
||||||
@@ -366,33 +351,20 @@ func (s *Server) processBatch() error {
|
|||||||
seq.cache.Inputs = []input.Input{}
|
seq.cache.Inputs = []input.Input{}
|
||||||
}
|
}
|
||||||
|
|
||||||
batchSize := s.batchSize
|
|
||||||
|
|
||||||
for j, inp := range seq.inputs {
|
for j, inp := range seq.inputs {
|
||||||
// If we are required to put following inputs into a single batch then extend the
|
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
|
||||||
// batch size. Since we are only extending the size the minimum amount possible, this
|
if len(seq.pendingInputs) == 0 {
|
||||||
// will cause a break if we have pending inputs.
|
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||||
minBatch := 1 + inp.SameBatch
|
if err != nil {
|
||||||
if minBatch > batchSize {
|
return err
|
||||||
batchSize = minBatch
|
}
|
||||||
}
|
} else {
|
||||||
|
|
||||||
if len(seq.pendingInputs)+minBatch > batchSize {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the sum of our working set (already processed tokens, tokens we added to this
|
|
||||||
// batch, required following tokens) exceeds the context size, then trigger a shift
|
|
||||||
// now so we don't have to do one later when we can't break the batch.
|
|
||||||
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
|
|
||||||
if len(seq.pendingInputs) != 0 {
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
if j >= s.batchSize {
|
||||||
if err != nil {
|
break
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
options.Inputs = append(options.Inputs, inp.Token)
|
options.Inputs = append(options.Inputs, inp.Token)
|
||||||
@@ -529,18 +501,75 @@ func (s *Server) processBatch() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO (jmorganca): use structs from the api package to avoid duplication
|
||||||
|
// this way the api acts as a proxy instead of using a different api for the
|
||||||
|
// runner
|
||||||
|
type Options struct {
|
||||||
|
api.Runner
|
||||||
|
|
||||||
|
NumKeep int `json:"n_keep"`
|
||||||
|
Seed int `json:"seed"`
|
||||||
|
NumPredict int `json:"n_predict"`
|
||||||
|
TopK int `json:"top_k"`
|
||||||
|
TopP float32 `json:"top_p"`
|
||||||
|
MinP float32 `json:"min_p"`
|
||||||
|
TypicalP float32 `json:"typical_p"`
|
||||||
|
RepeatLastN int `json:"repeat_last_n"`
|
||||||
|
Temperature float32 `json:"temperature"`
|
||||||
|
RepeatPenalty float32 `json:"repeat_penalty"`
|
||||||
|
PresencePenalty float32 `json:"presence_penalty"`
|
||||||
|
FrequencyPenalty float32 `json:"frequency_penalty"`
|
||||||
|
Mirostat int `json:"mirostat"`
|
||||||
|
MirostatTau float32 `json:"mirostat_tau"`
|
||||||
|
MirostatEta float32 `json:"mirostat_eta"`
|
||||||
|
Stop []string `json:"stop"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageData struct {
|
||||||
|
Data []byte `json:"data"`
|
||||||
|
ID int `json:"id"`
|
||||||
|
AspectRatioID int `json:"aspect_ratio_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompletionRequest struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Images []ImageData `json:"image_data"`
|
||||||
|
Grammar string `json:"grammar"`
|
||||||
|
CachePrompt bool `json:"cache_prompt"`
|
||||||
|
|
||||||
|
Options
|
||||||
|
}
|
||||||
|
|
||||||
|
type Timings struct {
|
||||||
|
PredictedN int `json:"predicted_n"`
|
||||||
|
PredictedMS float64 `json:"predicted_ms"`
|
||||||
|
PromptN int `json:"prompt_n"`
|
||||||
|
PromptMS float64 `json:"prompt_ms"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompletionResponse struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
Stop bool `json:"stop"`
|
||||||
|
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
StoppedLimit bool `json:"stopped_limit,omitempty"`
|
||||||
|
PredictedN int `json:"predicted_n,omitempty"`
|
||||||
|
PredictedMS float64 `json:"predicted_ms,omitempty"`
|
||||||
|
PromptN int `json:"prompt_n,omitempty"`
|
||||||
|
PromptMS float64 `json:"prompt_ms,omitempty"`
|
||||||
|
|
||||||
|
Timings Timings `json:"timings"`
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
var req llm.CompletionRequest
|
var req CompletionRequest
|
||||||
|
req.Options = Options(api.DefaultOptions())
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Options == nil {
|
|
||||||
opts := api.DefaultOptions()
|
|
||||||
req.Options = &opts
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the headers to indicate streaming
|
// Set the headers to indicate streaming
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Header().Set("Transfer-Encoding", "chunked")
|
w.Header().Set("Transfer-Encoding", "chunked")
|
||||||
@@ -562,18 +591,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
sampler := sample.NewSampler(
|
sampler := sample.NewSampler(
|
||||||
req.Options.Temperature,
|
req.Temperature,
|
||||||
req.Options.TopK,
|
req.TopK,
|
||||||
req.Options.TopP,
|
req.TopP,
|
||||||
req.Options.MinP,
|
req.MinP,
|
||||||
req.Options.Seed,
|
req.Seed,
|
||||||
grammar,
|
grammar,
|
||||||
)
|
)
|
||||||
|
|
||||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
numPredict: req.Options.NumPredict,
|
numPredict: req.NumPredict,
|
||||||
stop: req.Options.Stop,
|
stop: req.Stop,
|
||||||
numKeep: int32(req.Options.NumKeep),
|
numKeep: int32(req.NumKeep),
|
||||||
sampler: sampler,
|
sampler: sampler,
|
||||||
embedding: false,
|
embedding: false,
|
||||||
})
|
})
|
||||||
@@ -596,7 +625,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
found := false
|
found := false
|
||||||
for i, sq := range s.seqs {
|
for i, sq := range s.seqs {
|
||||||
if sq == nil {
|
if sq == nil {
|
||||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
|
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||||
@@ -623,7 +652,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
case content, ok := <-seq.responses:
|
case content, ok := <-seq.responses:
|
||||||
if ok {
|
if ok {
|
||||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||||
Content: content,
|
Content: content,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
@@ -634,17 +663,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
} else {
|
} else {
|
||||||
// Send the final response
|
// Send the final response
|
||||||
doneReason := "stop"
|
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||||
if seq.doneReason == "limit" {
|
Stop: true,
|
||||||
doneReason = "length"
|
StoppedLimit: seq.doneReason == "limit",
|
||||||
}
|
Timings: Timings{
|
||||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
PromptN: seq.numPromptInputs,
|
||||||
Done: true,
|
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
|
||||||
DoneReason: doneReason,
|
PredictedN: seq.numPredicted,
|
||||||
PromptEvalCount: seq.numPromptInputs,
|
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
|
||||||
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
|
},
|
||||||
EvalCount: seq.numPredicted,
|
|
||||||
EvalDuration: time.Since(seq.startGenerationTime),
|
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@@ -655,10 +682,43 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmbeddingRequest struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
CachePrompt bool `json:"cache_prompt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingResponse struct {
|
||||||
|
Embedding []float32 `json:"embedding"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HealthResponse struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Progress float32 `json:"progress"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerStatus int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ServerStatusReady ServerStatus = iota
|
||||||
|
ServerStatusLoadingModel
|
||||||
|
ServerStatusError
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s ServerStatus) ToString() string {
|
||||||
|
switch s {
|
||||||
|
case ServerStatusReady:
|
||||||
|
return "ok"
|
||||||
|
case ServerStatusLoadingModel:
|
||||||
|
return "loading model"
|
||||||
|
default:
|
||||||
|
return "server error"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
|
if err := json.NewEncoder(w).Encode(&HealthResponse{
|
||||||
Status: s.status,
|
Status: s.status.ToString(),
|
||||||
Progress: s.progress,
|
Progress: s.progress,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
@@ -712,7 +772,7 @@ func (s *Server) loadModel(
|
|||||||
s.seqs = make([]*Sequence, s.parallel)
|
s.seqs = make([]*Sequence, s.parallel)
|
||||||
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
|
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
|
||||||
|
|
||||||
s.status = llm.ServerStatusReady
|
s.status = ServerStatusReady
|
||||||
s.ready.Done()
|
s.ready.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -764,7 +824,7 @@ func Execute(args []string) error {
|
|||||||
|
|
||||||
server := &Server{
|
server := &Server{
|
||||||
batchSize: *batchSize,
|
batchSize: *batchSize,
|
||||||
status: llm.ServerStatusLoadingModel,
|
status: ServerStatusLoadingModel,
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(jessegross): Parameters that need to be implemented:
|
// TODO(jessegross): Parameters that need to be implemented:
|
||||||
|
|||||||
@@ -87,9 +87,8 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
|||||||
// topK also sorts the tokens in descending order of logits
|
// topK also sorts the tokens in descending order of logits
|
||||||
tokens = topK(tokens, s.topK)
|
tokens = topK(tokens, s.topK)
|
||||||
|
|
||||||
// scale and normalize the tokens in place
|
tokens = temperature(tokens, s.temperature)
|
||||||
temperature(tokens, s.temperature)
|
tokens = softmax(tokens)
|
||||||
softmax(tokens)
|
|
||||||
|
|
||||||
tokens = topP(tokens, s.topP)
|
tokens = topP(tokens, s.topP)
|
||||||
tokens = minP(tokens, s.minP)
|
tokens = minP(tokens, s.minP)
|
||||||
|
|||||||
@@ -26,16 +26,17 @@ func (h *tokenHeap) Pop() any {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// temperature applies scaling to the logits
|
// temperature applies scaling to the logits
|
||||||
func temperature(ts []token, temp float32) {
|
func temperature(ts []token, temp float32) []token {
|
||||||
// Ensure temperature clipping near 0 to avoid numerical instability
|
// Ensure temperature clipping near 0 to avoid numerical instability
|
||||||
temp = max(temp, 1e-7)
|
temp = max(temp, 1e-7)
|
||||||
for i := range ts {
|
for i := range ts {
|
||||||
ts[i].value = ts[i].value / temp
|
ts[i].value = ts[i].value / temp
|
||||||
}
|
}
|
||||||
|
return ts
|
||||||
}
|
}
|
||||||
|
|
||||||
// softmax applies normalization to the logits
|
// softmax applies normalization to the logits
|
||||||
func softmax(ts []token) {
|
func softmax(ts []token) []token {
|
||||||
// Find max logit for numerical stability
|
// Find max logit for numerical stability
|
||||||
maxLogit := float32(math.Inf(-1))
|
maxLogit := float32(math.Inf(-1))
|
||||||
for _, t := range ts {
|
for _, t := range ts {
|
||||||
@@ -55,6 +56,8 @@ func softmax(ts []token) {
|
|||||||
for i := range ts {
|
for i := range ts {
|
||||||
ts[i].value /= sum
|
ts[i].value /= sum
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return ts
|
||||||
}
|
}
|
||||||
|
|
||||||
// topK limits the number of tokens considered to the k highest logits
|
// topK limits the number of tokens considered to the k highest logits
|
||||||
@@ -96,7 +99,6 @@ func topK(ts []token, k int) []token {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// topP limits tokens to those with cumulative probability p
|
// topP limits tokens to those with cumulative probability p
|
||||||
// requires ts to be sorted in descending order of probabilities
|
|
||||||
func topP(ts []token, p float32) []token {
|
func topP(ts []token, p float32) []token {
|
||||||
if p == 1.0 {
|
if p == 1.0 {
|
||||||
return ts
|
return ts
|
||||||
@@ -107,24 +109,37 @@ func topP(ts []token, p float32) []token {
|
|||||||
for i, t := range ts {
|
for i, t := range ts {
|
||||||
sum += t.value
|
sum += t.value
|
||||||
if sum > float32(p) {
|
if sum > float32(p) {
|
||||||
return ts[:i+1]
|
ts = ts[:i+1]
|
||||||
|
return ts
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ts
|
return ts
|
||||||
}
|
}
|
||||||
|
|
||||||
// minP filters tokens with probabilities >= p * max_prob
|
// minP limits tokens to those with cumulative probability p
|
||||||
// requires ts to be sorted in descending order of probabilities
|
|
||||||
func minP(ts []token, p float32) []token {
|
func minP(ts []token, p float32) []token {
|
||||||
maxProb := ts[0].value
|
if p == 1.0 {
|
||||||
|
return ts
|
||||||
|
}
|
||||||
|
|
||||||
threshold := maxProb * p
|
maxProb := float32(math.Inf(-1))
|
||||||
|
for _, token := range ts {
|
||||||
for i, t := range ts {
|
if token.value > maxProb {
|
||||||
if t.value < threshold {
|
maxProb = token.value
|
||||||
return ts[:i]
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
threshold := maxProb * float32(p)
|
||||||
|
|
||||||
|
// Filter tokens in-place
|
||||||
|
validTokens := ts[:0]
|
||||||
|
for i, token := range ts {
|
||||||
|
if token.value >= threshold {
|
||||||
|
validTokens = append(validTokens, ts[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ts = validTokens
|
||||||
return ts
|
return ts
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,22 +34,17 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {
|
|||||||
|
|
||||||
func TestTemperature(t *testing.T) {
|
func TestTemperature(t *testing.T) {
|
||||||
input := []float32{1.0, 4.0, -2.0, 0.0}
|
input := []float32{1.0, 4.0, -2.0, 0.0}
|
||||||
tokens := toTokens(input)
|
got := temperature(toTokens(input), 0.5)
|
||||||
temperature(tokens, 0.5)
|
|
||||||
want := []float32{2.0, 8.0, -4.0, 0.0}
|
want := []float32{2.0, 8.0, -4.0, 0.0}
|
||||||
compareLogits(t, "temperature(0.5)", want, tokens)
|
compareLogits(t, "temperature(0.5)", want, got)
|
||||||
|
|
||||||
input = []float32{1.0, 4.0, -2.0, 0.0}
|
got = temperature(toTokens(input), 1.0)
|
||||||
tokens = toTokens(input)
|
|
||||||
temperature(tokens, 1.0)
|
|
||||||
want = []float32{1.0, 4.0, -2.0, 0.0}
|
want = []float32{1.0, 4.0, -2.0, 0.0}
|
||||||
compareLogits(t, "temperature(1)", want, tokens)
|
compareLogits(t, "temperature(1)", want, got)
|
||||||
|
|
||||||
input = []float32{1.0, 4.0, -2.0, 0.0}
|
got = temperature(toTokens(input), 0.0)
|
||||||
tokens = toTokens(input)
|
|
||||||
temperature(tokens, 0.0)
|
|
||||||
want = []float32{1e7, 4e7, -2e7, 0.0}
|
want = []float32{1e7, 4e7, -2e7, 0.0}
|
||||||
compareLogits(t, "temperature(0)", want, tokens)
|
compareLogits(t, "temperature(0)", want, got)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSoftmax(t *testing.T) {
|
func TestSoftmax(t *testing.T) {
|
||||||
@@ -95,17 +90,16 @@ func TestSoftmax(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tokens := toTokens(tt.input)
|
got := softmax(toTokens(tt.input))
|
||||||
softmax(tokens)
|
|
||||||
|
|
||||||
if tt.expected != nil {
|
if tt.expected != nil {
|
||||||
compareLogits(t, tt.name, tt.expected, tokens)
|
compareLogits(t, tt.name, tt.expected, got)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check probabilities sum to 1
|
// Check probabilities sum to 1
|
||||||
var sum float32
|
var sum float32
|
||||||
for _, token := range tokens {
|
for _, token := range got {
|
||||||
sum += token.value
|
sum += token.value
|
||||||
if token.value < 0 || token.value > 1 {
|
if token.value < 0 || token.value > 1 {
|
||||||
t.Errorf("probability out of range [0,1]: got %f", token.value)
|
t.Errorf("probability out of range [0,1]: got %f", token.value)
|
||||||
@@ -120,44 +114,38 @@ func TestSoftmax(t *testing.T) {
|
|||||||
|
|
||||||
func TestTopK(t *testing.T) {
|
func TestTopK(t *testing.T) {
|
||||||
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
||||||
tokens := toTokens(input)
|
|
||||||
tokens = topK(tokens, 5)
|
// Test k=5
|
||||||
if len(tokens) != 5 {
|
got := topK(toTokens(input), 5)
|
||||||
t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens))
|
if len(got) != 5 {
|
||||||
|
t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
|
||||||
}
|
}
|
||||||
|
// Should keep highest 3 values in descending order
|
||||||
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
|
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
|
||||||
compareLogits(t, "topK(3)", want, tokens)
|
compareLogits(t, "topK(3)", want, got)
|
||||||
|
|
||||||
tokens = toTokens(input)
|
got = topK(toTokens(input), 20)
|
||||||
tokens = topK(tokens, 20)
|
if len(got) != len(input) {
|
||||||
if len(tokens) != len(input) {
|
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
|
||||||
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test k=-1
|
||||||
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
||||||
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
||||||
tokens = toTokens(input)
|
got = topK(toTokens(input), -1)
|
||||||
tokens = topK(tokens, -1)
|
if len(got) != len(input) {
|
||||||
if len(tokens) != len(input) {
|
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
|
||||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
|
|
||||||
}
|
}
|
||||||
compareLogits(t, "topK(-1)", want, tokens)
|
compareLogits(t, "topK(-1)", want, got)
|
||||||
|
|
||||||
|
// Test k=0
|
||||||
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
||||||
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
||||||
tokens = toTokens(input)
|
got = topK(toTokens(input), 0)
|
||||||
tokens = topK(tokens, 0)
|
if len(got) != len(input) {
|
||||||
if len(tokens) != len(input) {
|
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
|
||||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
|
|
||||||
}
|
|
||||||
compareLogits(t, "topK(-1)", want, tokens)
|
|
||||||
|
|
||||||
input = []float32{-1e7, -2e7, -3e7, -4e7}
|
|
||||||
tokens = toTokens(input)
|
|
||||||
tokens = topK(tokens, 1)
|
|
||||||
if len(tokens) < 1 {
|
|
||||||
t.Error("topK should keep at least one token")
|
|
||||||
}
|
}
|
||||||
|
compareLogits(t, "topK(-1)", want, got)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTopP(t *testing.T) {
|
func TestTopP(t *testing.T) {
|
||||||
@@ -165,25 +153,16 @@ func TestTopP(t *testing.T) {
|
|||||||
tokens := toTokens(input)
|
tokens := toTokens(input)
|
||||||
|
|
||||||
// First apply temperature and softmax to get probabilities
|
// First apply temperature and softmax to get probabilities
|
||||||
softmax(tokens)
|
tokens = softmax(tokens)
|
||||||
tokens = topK(tokens, 20)
|
tokens = topK(tokens, 20)
|
||||||
|
|
||||||
// Then apply topP
|
// Then apply topP
|
||||||
tokens = topP(tokens, 0.95)
|
got := topP(tokens, 0.95)
|
||||||
|
|
||||||
// Should keep tokens until cumsum > 0.95
|
// Should keep tokens until cumsum > 0.95
|
||||||
if len(tokens) > 3 {
|
if len(got) > 3 {
|
||||||
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
|
t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
|
||||||
t.Logf("got: %v", tokens)
|
t.Logf("got: %v", got)
|
||||||
}
|
|
||||||
|
|
||||||
// Test edge case - ensure at least one token remains
|
|
||||||
input = []float32{-1e6, -1e6, -1e6} // One dominant token
|
|
||||||
tokens = toTokens(input)
|
|
||||||
softmax(tokens)
|
|
||||||
tokens = topP(tokens, 0.0) // Very small p
|
|
||||||
if len(tokens) < 1 {
|
|
||||||
t.Error("topP should keep at least one token")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,45 +171,14 @@ func TestMinP(t *testing.T) {
|
|||||||
tokens := toTokens(input)
|
tokens := toTokens(input)
|
||||||
|
|
||||||
// First apply temperature and softmax
|
// First apply temperature and softmax
|
||||||
tokens = topK(tokens, 20)
|
tokens = softmax(tokens)
|
||||||
softmax(tokens)
|
|
||||||
|
|
||||||
tokens = minP(tokens, 1.0)
|
// Then apply minP
|
||||||
|
got := minP(tokens, 0.2)
|
||||||
if len(tokens) != 1 {
|
|
||||||
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test with normal p value
|
|
||||||
tokens = toTokens(input) // Reset tokens
|
|
||||||
tokens = topK(tokens, 20)
|
|
||||||
softmax(tokens)
|
|
||||||
tokens = minP(tokens, 0.2)
|
|
||||||
|
|
||||||
// Should keep tokens with prob >= 0.2 * max_prob
|
// Should keep tokens with prob >= 0.2 * max_prob
|
||||||
if len(tokens) > 3 {
|
if len(got) > 3 {
|
||||||
t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
|
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
|
||||||
t.Logf("got: %v", tokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test with zero p value
|
|
||||||
tokens = toTokens(input) // Reset tokens
|
|
||||||
tokens = topK(tokens, 20)
|
|
||||||
softmax(tokens)
|
|
||||||
tokens = minP(tokens, 0.0)
|
|
||||||
|
|
||||||
// Should keep only the highest probability token
|
|
||||||
if len(tokens) != len(input) {
|
|
||||||
t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens))
|
|
||||||
t.Logf("got: %v", tokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
input = []float32{1e-10, 1e-10, 1e-10}
|
|
||||||
tokens = toTokens(input)
|
|
||||||
softmax(tokens)
|
|
||||||
tokens = minP(tokens, 1.0)
|
|
||||||
if len(tokens) < 1 {
|
|
||||||
t.Error("minP should keep at least one token even with extreme probabilities")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -283,7 +231,7 @@ func BenchmarkTransforms(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
copy(tokensCopy, tokens)
|
copy(tokensCopy, tokens)
|
||||||
tokens = topK(tokensCopy, 10)
|
topK(tokensCopy, 10)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -291,7 +239,7 @@ func BenchmarkTransforms(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
copy(tokensCopy, tokens)
|
copy(tokensCopy, tokens)
|
||||||
tokens = topP(tokensCopy, 0.9)
|
topP(tokensCopy, 0.9)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -299,7 +247,7 @@ func BenchmarkTransforms(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
copy(tokensCopy, tokens)
|
copy(tokensCopy, tokens)
|
||||||
tokens = minP(tokensCopy, 0.2)
|
minP(tokensCopy, 0.2)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -307,7 +255,7 @@ func BenchmarkTransforms(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
copy(tokensCopy, tokens)
|
copy(tokensCopy, tokens)
|
||||||
tokens = topK(tokensCopy, 200000)
|
topK(tokensCopy, 200000)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ usage() {
|
|||||||
exit 1
|
exit 1
|
||||||
}
|
}
|
||||||
|
|
||||||
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
|
export VERSION=${VERSION:-$(git describe --tags --dirty)}
|
||||||
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'"
|
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'"
|
||||||
export CGO_CPPFLAGS='-mmacosx-version-min=11.3'
|
export CGO_CPPFLAGS='-mmacosx-version-min=11.3'
|
||||||
|
|
||||||
|
|||||||
54
server/internal/cache/blob/cache.go
vendored
54
server/internal/cache/blob/cache.go
vendored
@@ -146,7 +146,7 @@ func debugger(err *error) func(step string) {
|
|||||||
// be in either of the following forms:
|
// be in either of the following forms:
|
||||||
//
|
//
|
||||||
// @<digest>
|
// @<digest>
|
||||||
// <name>@<digest>
|
// <name>
|
||||||
// <name>
|
// <name>
|
||||||
//
|
//
|
||||||
// If a digest is provided, it is returned as is and nothing else happens.
|
// If a digest is provided, it is returned as is and nothing else happens.
|
||||||
@@ -160,6 +160,8 @@ func debugger(err *error) func(step string) {
|
|||||||
// hashed is passed to a PutBytes call to ensure that the manifest is in the
|
// hashed is passed to a PutBytes call to ensure that the manifest is in the
|
||||||
// blob store. This is done to ensure that future calls to [Get] succeed in
|
// blob store. This is done to ensure that future calls to [Get] succeed in
|
||||||
// these cases.
|
// these cases.
|
||||||
|
//
|
||||||
|
// TODO(bmizerany): Move Links/Resolve/etc. out of this package.
|
||||||
func (c *DiskCache) Resolve(name string) (Digest, error) {
|
func (c *DiskCache) Resolve(name string) (Digest, error) {
|
||||||
name, digest := splitNameDigest(name)
|
name, digest := splitNameDigest(name)
|
||||||
if digest != "" {
|
if digest != "" {
|
||||||
@@ -277,6 +279,18 @@ func (c *DiskCache) Get(d Digest) (Entry, error) {
|
|||||||
// It returns an error if either the name or digest is invalid, or if link
|
// It returns an error if either the name or digest is invalid, or if link
|
||||||
// creation encounters any issues.
|
// creation encounters any issues.
|
||||||
func (c *DiskCache) Link(name string, d Digest) error {
|
func (c *DiskCache) Link(name string, d Digest) error {
|
||||||
|
// TODO(bmizerany): Move link handling from cache to registry.
|
||||||
|
//
|
||||||
|
// We originally placed links in the cache due to its storage
|
||||||
|
// knowledge. However, the registry likely offers better context for
|
||||||
|
// naming concerns, and our API design shouldn't be tightly coupled to
|
||||||
|
// our on-disk format.
|
||||||
|
//
|
||||||
|
// Links work effectively when independent from physical location -
|
||||||
|
// they can reference content with matching SHA regardless of storage
|
||||||
|
// location. In an upcoming change, we plan to shift this
|
||||||
|
// responsibility to the registry where it better aligns with the
|
||||||
|
// system's conceptual model.
|
||||||
manifest, err := c.manifestPath(name)
|
manifest, err := c.manifestPath(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -327,9 +341,7 @@ func (c *DiskCache) GetFile(d Digest) string {
|
|||||||
return absJoin(c.dir, "blobs", filename)
|
return absJoin(c.dir, "blobs", filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Links returns a sequence of link names. The sequence is in lexical order.
|
// Links returns a sequence of links in the cache in lexical order.
|
||||||
// Names are converted from their relative path form to their name form but are
|
|
||||||
// not guaranteed to be valid. Callers should validate the names before using.
|
|
||||||
func (c *DiskCache) Links() iter.Seq2[string, error] {
|
func (c *DiskCache) Links() iter.Seq2[string, error] {
|
||||||
return func(yield func(string, error) bool) {
|
return func(yield func(string, error) bool) {
|
||||||
for path, err := range c.links() {
|
for path, err := range c.links() {
|
||||||
@@ -402,14 +414,12 @@ func (c *DiskCache) links() iter.Seq2[string, error] {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type checkWriter struct {
|
type checkWriter struct {
|
||||||
size int64
|
|
||||||
d Digest
|
d Digest
|
||||||
f *os.File
|
size int64
|
||||||
|
n int64
|
||||||
h hash.Hash
|
h hash.Hash
|
||||||
|
f *os.File
|
||||||
w io.Writer // underlying writer; set by creator
|
err error
|
||||||
n int64
|
|
||||||
err error
|
|
||||||
|
|
||||||
testHookBeforeFinalWrite func(*os.File)
|
testHookBeforeFinalWrite func(*os.File)
|
||||||
}
|
}
|
||||||
@@ -425,10 +435,6 @@ func (w *checkWriter) seterr(err error) error {
|
|||||||
// underlying writer is guaranteed to be the last byte of p as verified by the
|
// underlying writer is guaranteed to be the last byte of p as verified by the
|
||||||
// hash.
|
// hash.
|
||||||
func (w *checkWriter) Write(p []byte) (int, error) {
|
func (w *checkWriter) Write(p []byte) (int, error) {
|
||||||
if w.err != nil {
|
|
||||||
return 0, w.err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := w.h.Write(p)
|
_, err := w.h.Write(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, w.seterr(err)
|
return 0, w.seterr(err)
|
||||||
@@ -447,7 +453,7 @@ func (w *checkWriter) Write(p []byte) (int, error) {
|
|||||||
if nextSize > w.size {
|
if nextSize > w.size {
|
||||||
return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size))
|
return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size))
|
||||||
}
|
}
|
||||||
n, err := w.w.Write(p)
|
n, err := w.f.Write(p)
|
||||||
w.n += int64(n)
|
w.n += int64(n)
|
||||||
return n, w.seterr(err)
|
return n, w.seterr(err)
|
||||||
}
|
}
|
||||||
@@ -487,12 +493,10 @@ func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size
|
|||||||
|
|
||||||
// Copy file to f, but also into h to double-check hash.
|
// Copy file to f, but also into h to double-check hash.
|
||||||
cw := &checkWriter{
|
cw := &checkWriter{
|
||||||
d: out,
|
d: out,
|
||||||
size: size,
|
size: size,
|
||||||
h: sha256.New(),
|
h: sha256.New(),
|
||||||
f: f,
|
f: f,
|
||||||
w: f,
|
|
||||||
|
|
||||||
testHookBeforeFinalWrite: c.testHookBeforeFinalWrite,
|
testHookBeforeFinalWrite: c.testHookBeforeFinalWrite,
|
||||||
}
|
}
|
||||||
n, err := io.Copy(cw, file)
|
n, err := io.Copy(cw, file)
|
||||||
@@ -528,6 +532,11 @@ func splitNameDigest(s string) (name, digest string) {
|
|||||||
var errInvalidName = errors.New("invalid name")
|
var errInvalidName = errors.New("invalid name")
|
||||||
|
|
||||||
func nameToPath(name string) (_ string, err error) {
|
func nameToPath(name string) (_ string, err error) {
|
||||||
|
if strings.Contains(name, "@") {
|
||||||
|
// TODO(bmizerany): HACK: Fix names.Parse to validate.
|
||||||
|
// TODO(bmizerany): merge with default parts (maybe names.Merge(a, b))
|
||||||
|
return "", errInvalidName
|
||||||
|
}
|
||||||
n := names.Parse(name)
|
n := names.Parse(name)
|
||||||
if !n.IsFullyQualified() {
|
if !n.IsFullyQualified() {
|
||||||
return "", errInvalidName
|
return "", errInvalidName
|
||||||
@@ -538,7 +547,8 @@ func nameToPath(name string) (_ string, err error) {
|
|||||||
func absJoin(pp ...string) string {
|
func absJoin(pp ...string) string {
|
||||||
abs, err := filepath.Abs(filepath.Join(pp...))
|
abs, err := filepath.Abs(filepath.Join(pp...))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err) // this should never happen
|
// Likely a bug bug or a bad OS problem. Just panic.
|
||||||
|
panic(err)
|
||||||
}
|
}
|
||||||
return abs
|
return abs
|
||||||
}
|
}
|
||||||
|
|||||||
73
server/internal/cache/blob/chunked.go
vendored
73
server/internal/cache/blob/chunked.go
vendored
@@ -1,73 +0,0 @@
|
|||||||
package blob
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/sha256"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Chunk represents a range of bytes in a blob.
|
|
||||||
type Chunk struct {
|
|
||||||
Start int64
|
|
||||||
End int64
|
|
||||||
}
|
|
||||||
|
|
||||||
// Size returns end minus start plus one.
|
|
||||||
func (c Chunk) Size() int64 {
|
|
||||||
return c.End - c.Start + 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Chunker writes to a blob in chunks.
|
|
||||||
// Its zero value is invalid. Use [DiskCache.Chunked] to create a new Chunker.
|
|
||||||
type Chunker struct {
|
|
||||||
digest Digest
|
|
||||||
size int64
|
|
||||||
f *os.File // nil means pre-validated
|
|
||||||
}
|
|
||||||
|
|
||||||
// Chunked returns a new Chunker, ready for use storing a blob of the given
|
|
||||||
// size in chunks.
|
|
||||||
//
|
|
||||||
// Use [Chunker.Put] to write data to the blob at specific offsets.
|
|
||||||
func (c *DiskCache) Chunked(d Digest, size int64) (*Chunker, error) {
|
|
||||||
name := c.GetFile(d)
|
|
||||||
info, err := os.Stat(name)
|
|
||||||
if err == nil && info.Size() == size {
|
|
||||||
return &Chunker{}, nil
|
|
||||||
}
|
|
||||||
f, err := os.OpenFile(name, os.O_CREATE|os.O_WRONLY, 0o666)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &Chunker{digest: d, size: size, f: f}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put copies chunk.Size() bytes from r to the blob at the given offset,
|
|
||||||
// merging the data with the existing blob. It returns an error if any. As a
|
|
||||||
// special case, if r has less than chunk.Size() bytes, Put returns
|
|
||||||
// io.ErrUnexpectedEOF.
|
|
||||||
func (c *Chunker) Put(chunk Chunk, d Digest, r io.Reader) error {
|
|
||||||
if c.f == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cw := &checkWriter{
|
|
||||||
d: d,
|
|
||||||
size: chunk.Size(),
|
|
||||||
h: sha256.New(),
|
|
||||||
f: c.f,
|
|
||||||
w: io.NewOffsetWriter(c.f, chunk.Start),
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := io.CopyN(cw, r, chunk.Size())
|
|
||||||
if err != nil && errors.Is(err, io.EOF) {
|
|
||||||
return io.ErrUnexpectedEOF
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the underlying file.
|
|
||||||
func (c *Chunker) Close() error {
|
|
||||||
return c.f.Close()
|
|
||||||
}
|
|
||||||
4
server/internal/cache/blob/digest.go
vendored
4
server/internal/cache/blob/digest.go
vendored
@@ -63,10 +63,6 @@ func (d Digest) Short() string {
|
|||||||
return fmt.Sprintf("%x", d.sum[:4])
|
return fmt.Sprintf("%x", d.sum[:4])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d Digest) Sum() [32]byte {
|
|
||||||
return d.sum
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d Digest) Compare(other Digest) int {
|
func (d Digest) Compare(other Digest) int {
|
||||||
return slices.Compare(d.sum[:], other.sum[:])
|
return slices.Compare(d.sum[:], other.sum[:])
|
||||||
}
|
}
|
||||||
|
|||||||
78
server/internal/chunks/chunks.go
Normal file
78
server/internal/chunks/chunks.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
package chunks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"iter"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Chunk struct {
|
||||||
|
Start, End int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(start, end int64) Chunk {
|
||||||
|
return Chunk{start, end}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseRange parses a string in the form "unit=range" where unit is a string
|
||||||
|
// and range is a string in the form "start-end". It returns the unit and the
|
||||||
|
// range as a Chunk.
|
||||||
|
func ParseRange(s string) (unit string, _ Chunk, _ error) {
|
||||||
|
unit, r, _ := strings.Cut(s, "=")
|
||||||
|
if r == "" {
|
||||||
|
return unit, Chunk{}, nil
|
||||||
|
}
|
||||||
|
c, err := Parse(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", Chunk{}, err
|
||||||
|
}
|
||||||
|
return unit, c, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse parses a string in the form "start-end" and returns the Chunk.
|
||||||
|
func Parse(s string) (Chunk, error) {
|
||||||
|
startStr, endStr, _ := strings.Cut(s, "-")
|
||||||
|
start, err := strconv.ParseInt(startStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return Chunk{}, fmt.Errorf("invalid start: %v", err)
|
||||||
|
}
|
||||||
|
end, err := strconv.ParseInt(endStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return Chunk{}, fmt.Errorf("invalid end: %v", err)
|
||||||
|
}
|
||||||
|
if start > end {
|
||||||
|
return Chunk{}, fmt.Errorf("invalid range %d-%d: start > end", start, end)
|
||||||
|
}
|
||||||
|
return Chunk{start, end}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Of returns a sequence of contiguous Chunks of size chunkSize that cover
|
||||||
|
// the range [0, size), in order.
|
||||||
|
func Of(size, chunkSize int64) iter.Seq[Chunk] {
|
||||||
|
return func(yield func(Chunk) bool) {
|
||||||
|
for start := int64(0); start < size; start += chunkSize {
|
||||||
|
end := min(start+chunkSize-1, size-1)
|
||||||
|
if !yield(Chunk{start, end}) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the number of Chunks of size chunkSize needed to cover the
|
||||||
|
// range [0, size).
|
||||||
|
func Count(size, chunkSize int64) int64 {
|
||||||
|
return (size + chunkSize - 1) / chunkSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size returns end minus start plus one.
|
||||||
|
func (c Chunk) Size() int64 {
|
||||||
|
return c.End - c.Start + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the string representation of the Chunk in the form
|
||||||
|
// "{start}-{end}".
|
||||||
|
func (c Chunk) String() string {
|
||||||
|
return fmt.Sprintf("%d-%d", c.Start, c.End)
|
||||||
|
}
|
||||||
65
server/internal/chunks/chunks_test.go
Normal file
65
server/internal/chunks/chunks_test.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package chunks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOf(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
total int64
|
||||||
|
chunkSize int64
|
||||||
|
want []Chunk
|
||||||
|
}{
|
||||||
|
{0, 1, nil},
|
||||||
|
{1, 1, []Chunk{{0, 0}}},
|
||||||
|
{1, 2, []Chunk{{0, 0}}},
|
||||||
|
{2, 1, []Chunk{{0, 0}, {1, 1}}},
|
||||||
|
{10, 9, []Chunk{{0, 8}, {9, 9}}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
got := slices.Collect(Of(tt.total, tt.chunkSize))
|
||||||
|
if !slices.Equal(got, tt.want) {
|
||||||
|
t.Errorf("[%d/%d]: got %v; want %v", tt.total, tt.chunkSize, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSize(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
c Chunk
|
||||||
|
want int64
|
||||||
|
}{
|
||||||
|
{Chunk{0, 0}, 1},
|
||||||
|
{Chunk{0, 1}, 2},
|
||||||
|
{Chunk{3, 4}, 2},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
got := tt.c.Size()
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("%v: got %d; want %d", tt.c, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCount(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
total int64
|
||||||
|
chunkSize int64
|
||||||
|
want int64
|
||||||
|
}{
|
||||||
|
{0, 1, 0},
|
||||||
|
{1, 1, 1},
|
||||||
|
{1, 2, 1},
|
||||||
|
{2, 1, 2},
|
||||||
|
{10, 9, 2},
|
||||||
|
}
|
||||||
|
for _, tt := range cases {
|
||||||
|
got := Count(tt.total, tt.chunkSize)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("[%d/%d]: got %d; want %d", tt.total, tt.chunkSize, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -19,13 +19,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"iter"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/debug"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -37,8 +35,10 @@ import (
|
|||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||||
|
"github.com/ollama/ollama/server/internal/chunks"
|
||||||
"github.com/ollama/ollama/server/internal/internal/backoff"
|
"github.com/ollama/ollama/server/internal/internal/backoff"
|
||||||
"github.com/ollama/ollama/server/internal/internal/names"
|
"github.com/ollama/ollama/server/internal/internal/names"
|
||||||
|
"github.com/ollama/ollama/server/internal/internal/syncs"
|
||||||
|
|
||||||
_ "embed"
|
_ "embed"
|
||||||
)
|
)
|
||||||
@@ -66,7 +66,12 @@ var (
|
|||||||
const (
|
const (
|
||||||
// DefaultChunkingThreshold is the threshold at which a layer should be
|
// DefaultChunkingThreshold is the threshold at which a layer should be
|
||||||
// split up into chunks when downloading.
|
// split up into chunks when downloading.
|
||||||
DefaultChunkingThreshold = 64 << 20
|
DefaultChunkingThreshold = 128 << 20
|
||||||
|
|
||||||
|
// DefaultMaxChunkSize is the default maximum size of a chunk to
|
||||||
|
// download. It is configured based on benchmarks and aims to strike a
|
||||||
|
// balance between download speed and memory usage.
|
||||||
|
DefaultMaxChunkSize = 8 << 20
|
||||||
)
|
)
|
||||||
|
|
||||||
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
|
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
|
||||||
@@ -206,7 +211,8 @@ type Registry struct {
|
|||||||
// pushing or pulling models. If zero, the number of streams is
|
// pushing or pulling models. If zero, the number of streams is
|
||||||
// determined by [runtime.GOMAXPROCS].
|
// determined by [runtime.GOMAXPROCS].
|
||||||
//
|
//
|
||||||
// A negative value means no limit.
|
// Clients that want "unlimited" streams should set this to a large
|
||||||
|
// number.
|
||||||
MaxStreams int
|
MaxStreams int
|
||||||
|
|
||||||
// ChunkingThreshold is the maximum size of a layer to download in a single
|
// ChunkingThreshold is the maximum size of a layer to download in a single
|
||||||
@@ -260,7 +266,6 @@ func DefaultRegistry() (*Registry, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var rc Registry
|
var rc Registry
|
||||||
rc.UserAgent = UserAgent()
|
|
||||||
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
|
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -276,24 +281,25 @@ func DefaultRegistry() (*Registry, error) {
|
|||||||
return &rc, nil
|
return &rc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UserAgent() string {
|
|
||||||
buildinfo, _ := debug.ReadBuildInfo()
|
|
||||||
return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
|
|
||||||
buildinfo.Main.Version,
|
|
||||||
runtime.GOARCH,
|
|
||||||
runtime.GOOS,
|
|
||||||
runtime.Version(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Registry) maxStreams() int {
|
func (r *Registry) maxStreams() int {
|
||||||
return cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
|
n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
|
||||||
|
|
||||||
|
// Large downloads require a writter stream, so ensure we have at least
|
||||||
|
// two streams to avoid a deadlock.
|
||||||
|
return max(n, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Registry) maxChunkingThreshold() int64 {
|
func (r *Registry) maxChunkingThreshold() int64 {
|
||||||
return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
|
return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// chunkSizeFor returns the chunk size for a layer of the given size. If the
|
||||||
|
// size is less than or equal to the max chunking threshold, the size is
|
||||||
|
// returned; otherwise, the max chunk size is returned.
|
||||||
|
func (r *Registry) maxChunkSize() int64 {
|
||||||
|
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
|
||||||
|
}
|
||||||
|
|
||||||
type PushParams struct {
|
type PushParams struct {
|
||||||
// From is an optional destination name for the model. If empty, the
|
// From is an optional destination name for the model. If empty, the
|
||||||
// destination name is the same as the source name.
|
// destination name is the same as the source name.
|
||||||
@@ -420,21 +426,6 @@ func canRetry(err error) bool {
|
|||||||
return re.Status >= 500
|
return re.Status >= 500
|
||||||
}
|
}
|
||||||
|
|
||||||
// trackingReader is an io.Reader that tracks the number of bytes read and
|
|
||||||
// calls the update function with the layer, the number of bytes read.
|
|
||||||
//
|
|
||||||
// It always calls update with a nil error.
|
|
||||||
type trackingReader struct {
|
|
||||||
r io.Reader
|
|
||||||
n *atomic.Int64
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *trackingReader) Read(p []byte) (n int, err error) {
|
|
||||||
n, err = r.r.Read(p)
|
|
||||||
r.n.Add(int64(n))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pull pulls the model with the given name from the remote registry into the
|
// Pull pulls the model with the given name from the remote registry into the
|
||||||
// cache.
|
// cache.
|
||||||
//
|
//
|
||||||
@@ -443,6 +434,11 @@ func (r *trackingReader) Read(p []byte) (n int, err error) {
|
|||||||
// typically slower than splitting the model up across layers, and is mostly
|
// typically slower than splitting the model up across layers, and is mostly
|
||||||
// utilized for layers of type equal to "application/vnd.ollama.image".
|
// utilized for layers of type equal to "application/vnd.ollama.image".
|
||||||
func (r *Registry) Pull(ctx context.Context, name string) error {
|
func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||||
|
scheme, n, _, err := r.parseNameExtended(name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
m, err := r.Resolve(ctx, name)
|
m, err := r.Resolve(ctx, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -461,95 +457,126 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
return err == nil && info.Size == l.Size
|
return err == nil && info.Size == l.Size
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t := traceFromContext(ctx)
|
||||||
|
|
||||||
|
g, ctx := errgroup.WithContext(ctx)
|
||||||
|
g.SetLimit(r.maxStreams())
|
||||||
|
|
||||||
layers := m.Layers
|
layers := m.Layers
|
||||||
if m.Config != nil && m.Config.Digest.IsValid() {
|
if m.Config != nil && m.Config.Digest.IsValid() {
|
||||||
layers = append(layers, m.Config)
|
layers = append(layers, m.Config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send initial layer trace events to allow clients to have an
|
for _, l := range layers {
|
||||||
// understanding of work to be done before work starts.
|
|
||||||
t := traceFromContext(ctx)
|
|
||||||
skip := make([]bool, len(layers))
|
|
||||||
for i, l := range layers {
|
|
||||||
t.update(l, 0, nil)
|
|
||||||
if exists(l) {
|
if exists(l) {
|
||||||
skip[i] = true
|
|
||||||
t.update(l, l.Size, ErrCached)
|
t.update(l, l.Size, ErrCached)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
g, ctx := errgroup.WithContext(ctx)
|
|
||||||
g.SetLimit(r.maxStreams())
|
|
||||||
for i, l := range layers {
|
|
||||||
if skip[i] {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
chunked, err := c.Chunked(l.Digest, l.Size)
|
blobURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), l.Digest)
|
||||||
|
req, err := r.newRequest(ctx, "GET", blobURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.update(l, 0, err)
|
t.update(l, 0, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
defer chunked.Close()
|
|
||||||
|
|
||||||
var progress atomic.Int64
|
t.update(l, 0, nil)
|
||||||
for cs, err := range r.chunksums(ctx, name, l) {
|
|
||||||
if err != nil {
|
if l.Size <= r.maxChunkingThreshold() {
|
||||||
t.update(l, progress.Load(), err)
|
g.Go(func() error {
|
||||||
break
|
// TODO(bmizerany): retry/backoff like below in
|
||||||
}
|
// the chunking case
|
||||||
|
res, err := sendRequest(r.client(), req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
err = c.Put(l.Digest, res.Body, l.Size)
|
||||||
|
if err == nil {
|
||||||
|
t.update(l, l.Size, nil)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
q := syncs.NewRelayReader()
|
||||||
|
|
||||||
g.Go(func() (err error) {
|
g.Go(func() (err error) {
|
||||||
defer func() { t.update(l, progress.Load(), err) }()
|
defer func() { q.CloseWithError(err) }()
|
||||||
|
return c.Put(l.Digest, q, l.Size)
|
||||||
for _, err := range backoff.Loop(ctx, 3*time.Second) {
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err := func() error {
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
|
|
||||||
res, err := sendRequest(r.client(), req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer res.Body.Close()
|
|
||||||
|
|
||||||
// Count bytes towards
|
|
||||||
// progress, as they arrive, so
|
|
||||||
// that our bytes piggyback
|
|
||||||
// other chunk updates on
|
|
||||||
// completion.
|
|
||||||
//
|
|
||||||
// This tactic is enough to
|
|
||||||
// show "smooth" progress given
|
|
||||||
// the current CLI client. In
|
|
||||||
// the near future, the server
|
|
||||||
// should report download rate
|
|
||||||
// since it knows better than
|
|
||||||
// a client that is measuring
|
|
||||||
// rate based on wall-clock
|
|
||||||
// time-since-last-update.
|
|
||||||
body := &trackingReader{r: res.Body, n: &progress}
|
|
||||||
|
|
||||||
err = chunked.Put(cs.Chunk, cs.Digest, body)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}()
|
|
||||||
if !canRetry(err) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
var progress atomic.Int64
|
||||||
|
|
||||||
|
// We want to avoid extra round trips per chunk due to
|
||||||
|
// redirects from the registry to the blob store, so
|
||||||
|
// fire an initial request to get the final URL and
|
||||||
|
// then use that URL for the chunk requests.
|
||||||
|
req.Header.Set("Range", "bytes=0-0")
|
||||||
|
res, err := sendRequest(r.client(), req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res.Body.Close()
|
||||||
|
req = res.Request.WithContext(req.Context())
|
||||||
|
|
||||||
|
wp := writerPool{size: r.maxChunkSize()}
|
||||||
|
|
||||||
|
for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
ticket := q.Take()
|
||||||
|
g.Go(func() (err error) {
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
q.CloseWithError(err)
|
||||||
|
}
|
||||||
|
ticket.Close()
|
||||||
|
t.update(l, progress.Load(), err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for _, err := range backoff.Loop(ctx, 3*time.Second) {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err := func() error {
|
||||||
|
req := req.Clone(req.Context())
|
||||||
|
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
|
||||||
|
res, err := sendRequest(r.client(), req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
tw := wp.get()
|
||||||
|
tw.Reset(ticket)
|
||||||
|
defer wp.put(tw)
|
||||||
|
|
||||||
|
_, err = io.CopyN(tw, res.Body, chunk.Size())
|
||||||
|
if err != nil {
|
||||||
|
return maybeUnexpectedEOF(err)
|
||||||
|
}
|
||||||
|
if err := tw.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
total := progress.Add(chunk.Size())
|
||||||
|
if total >= l.Size {
|
||||||
|
q.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}()
|
||||||
|
if !canRetry(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.Wait(); err != nil {
|
if err := g.Wait(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -588,6 +615,8 @@ type Manifest struct {
|
|||||||
Config *Layer `json:"config"`
|
Config *Layer `json:"config"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
|
||||||
|
|
||||||
// Layer returns the layer with the given
|
// Layer returns the layer with the given
|
||||||
// digest, or nil if not found.
|
// digest, or nil if not found.
|
||||||
func (m *Manifest) Layer(d blob.Digest) *Layer {
|
func (m *Manifest) Layer(d blob.Digest) *Layer {
|
||||||
@@ -614,9 +643,10 @@ func (m Manifest) MarshalJSON() ([]byte, error) {
|
|||||||
// last phase of the commit which expects it, but does nothing
|
// last phase of the commit which expects it, but does nothing
|
||||||
// with it. This will be fixed in a future release of
|
// with it. This will be fixed in a future release of
|
||||||
// ollama.com.
|
// ollama.com.
|
||||||
Config Layer `json:"config"`
|
Config *Layer `json:"config"`
|
||||||
}{
|
}{
|
||||||
M: M(m),
|
M: M(m),
|
||||||
|
Config: &Layer{Digest: emptyDigest},
|
||||||
}
|
}
|
||||||
return json.Marshal(v)
|
return json.Marshal(v)
|
||||||
}
|
}
|
||||||
@@ -706,123 +736,6 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error)
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type chunksum struct {
|
|
||||||
URL string
|
|
||||||
Chunk blob.Chunk
|
|
||||||
Digest blob.Digest
|
|
||||||
}
|
|
||||||
|
|
||||||
// chunksums returns a sequence of chunksums for the given layer. If the layer is under the
|
|
||||||
// chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
|
|
||||||
// is over the chunking threshold, the chunksums are read from the chunksums endpoint.
|
|
||||||
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
|
|
||||||
return func(yield func(chunksum, error) bool) {
|
|
||||||
scheme, n, _, err := r.parseNameExtended(name)
|
|
||||||
if err != nil {
|
|
||||||
yield(chunksum{}, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if l.Size < r.maxChunkingThreshold() {
|
|
||||||
// any layer under the threshold should be downloaded
|
|
||||||
// in one go.
|
|
||||||
cs := chunksum{
|
|
||||||
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
|
|
||||||
scheme,
|
|
||||||
n.Host(),
|
|
||||||
n.Namespace(),
|
|
||||||
n.Model(),
|
|
||||||
l.Digest,
|
|
||||||
),
|
|
||||||
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
|
|
||||||
Digest: l.Digest,
|
|
||||||
}
|
|
||||||
yield(cs, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// A chunksums response is a sequence of chunksums in a
|
|
||||||
// simple, easy to parse line-oriented format.
|
|
||||||
//
|
|
||||||
// Example:
|
|
||||||
//
|
|
||||||
// >> GET /v2/<namespace>/<model>/chunksums/<digest>
|
|
||||||
//
|
|
||||||
// << HTTP/1.1 200 OK
|
|
||||||
// << Content-Location: <blobURL>
|
|
||||||
// <<
|
|
||||||
// << <digest> <start>-<end>
|
|
||||||
// << ...
|
|
||||||
//
|
|
||||||
// The blobURL is the URL to download the chunks from.
|
|
||||||
|
|
||||||
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
|
|
||||||
scheme,
|
|
||||||
n.Host(),
|
|
||||||
n.Namespace(),
|
|
||||||
n.Model(),
|
|
||||||
l.Digest,
|
|
||||||
)
|
|
||||||
|
|
||||||
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
yield(chunksum{}, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
res, err := sendRequest(r.client(), req)
|
|
||||||
if err != nil {
|
|
||||||
yield(chunksum{}, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer res.Body.Close()
|
|
||||||
if res.StatusCode != 200 {
|
|
||||||
err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
|
|
||||||
yield(chunksum{}, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
blobURL := res.Header.Get("Content-Location")
|
|
||||||
|
|
||||||
s := bufio.NewScanner(res.Body)
|
|
||||||
s.Split(bufio.ScanWords)
|
|
||||||
for {
|
|
||||||
if !s.Scan() {
|
|
||||||
if s.Err() != nil {
|
|
||||||
yield(chunksum{}, s.Err())
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
d, err := blob.ParseDigest(s.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !s.Scan() {
|
|
||||||
err := s.Err()
|
|
||||||
if err == nil {
|
|
||||||
err = fmt.Errorf("missing chunk range for digest %s", d)
|
|
||||||
}
|
|
||||||
yield(chunksum{}, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
chunk, err := parseChunk(s.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
cs := chunksum{
|
|
||||||
URL: blobURL,
|
|
||||||
Chunk: chunk,
|
|
||||||
Digest: d,
|
|
||||||
}
|
|
||||||
if !yield(cs, nil) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Registry) client() *http.Client {
|
func (r *Registry) client() *http.Client {
|
||||||
if r.HTTPClient != nil {
|
if r.HTTPClient != nil {
|
||||||
return r.HTTPClient
|
return r.HTTPClient
|
||||||
@@ -985,6 +898,13 @@ func checkData(url string) string {
|
|||||||
return fmt.Sprintf("GET,%s,%s", url, zeroSum)
|
return fmt.Sprintf("GET,%s,%s", url, zeroSum)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func maybeUnexpectedEOF(err error) error {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
return io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
type publicError struct {
|
type publicError struct {
|
||||||
wrapped error
|
wrapped error
|
||||||
message string
|
message string
|
||||||
@@ -1071,22 +991,27 @@ func splitExtended(s string) (scheme, name, digest string) {
|
|||||||
return scheme, s, digest
|
return scheme, s, digest
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseChunk parses a string in the form "start-end" and returns the Chunk.
|
type writerPool struct {
|
||||||
func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
|
size int64 // set by the caller
|
||||||
startPart, endPart, found := strings.Cut(string(s), "-")
|
|
||||||
if !found {
|
mu sync.Mutex
|
||||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s)
|
ws []*bufio.Writer
|
||||||
}
|
}
|
||||||
start, err := strconv.ParseInt(startPart, 10, 64)
|
|
||||||
if err != nil {
|
func (p *writerPool) get() *bufio.Writer {
|
||||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err)
|
p.mu.Lock()
|
||||||
}
|
defer p.mu.Unlock()
|
||||||
end, err := strconv.ParseInt(endPart, 10, 64)
|
if len(p.ws) == 0 {
|
||||||
if err != nil {
|
return bufio.NewWriterSize(nil, int(p.size))
|
||||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err)
|
}
|
||||||
}
|
w := p.ws[len(p.ws)-1]
|
||||||
if start > end {
|
p.ws = p.ws[:len(p.ws)-1]
|
||||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s)
|
return w
|
||||||
}
|
}
|
||||||
return blob.Chunk{Start: start, End: end}, nil
|
|
||||||
|
func (p *writerPool) put(w *bufio.Writer) {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
w.Reset(nil)
|
||||||
|
p.ws = append(p.ws, w)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||||
|
"github.com/ollama/ollama/server/internal/chunks"
|
||||||
"github.com/ollama/ollama/server/internal/testutil"
|
"github.com/ollama/ollama/server/internal/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -427,7 +428,7 @@ func TestRegistryPullCached(t *testing.T) {
|
|||||||
err := rc.Pull(ctx, "single")
|
err := rc.Pull(ctx, "single")
|
||||||
testutil.Check(t, err)
|
testutil.Check(t, err)
|
||||||
|
|
||||||
want := []int64{0, 6}
|
want := []int64{6}
|
||||||
if !errors.Is(errors.Join(errs...), ErrCached) {
|
if !errors.Is(errors.Join(errs...), ErrCached) {
|
||||||
t.Errorf("errs = %v; want %v", errs, ErrCached)
|
t.Errorf("errs = %v; want %v", errs, ErrCached)
|
||||||
}
|
}
|
||||||
@@ -530,6 +531,54 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRegistryPullChunking(t *testing.T) {
|
||||||
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
|
||||||
|
if r.URL.Host != "blob.store" {
|
||||||
|
// The production registry redirects to the blob store.
|
||||||
|
http.Redirect(w, r, "http://blob.store"+r.URL.Path, http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||||
|
rng := r.Header.Get("Range")
|
||||||
|
if rng == "" {
|
||||||
|
http.Error(w, "missing range", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, c, err := chunks.ParseRange(r.Header.Get("Range"))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
io.WriteString(w, "remote"[c.Start:c.End+1])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, blob.DigestFromBytes("remote"))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Force chunking by setting the threshold to less than the size of the
|
||||||
|
// layer.
|
||||||
|
rc.ChunkingThreshold = 3
|
||||||
|
rc.MaxChunkSize = 3
|
||||||
|
|
||||||
|
var reads []int64
|
||||||
|
ctx := WithTrace(t.Context(), &Trace{
|
||||||
|
Update: func(d *Layer, n int64, err error) {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("update %v %d %v", d, n, err)
|
||||||
|
}
|
||||||
|
reads = append(reads, n)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
err := rc.Pull(ctx, "remote")
|
||||||
|
testutil.Check(t, err)
|
||||||
|
|
||||||
|
want := []int64{0, 3, 6}
|
||||||
|
if !slices.Equal(reads, want) {
|
||||||
|
t.Errorf("reads = %v; want %v", reads, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRegistryResolveByDigest(t *testing.T) {
|
func TestRegistryResolveByDigest(t *testing.T) {
|
||||||
check := testutil.Checker(t)
|
check := testutil.Checker(t)
|
||||||
|
|
||||||
|
|||||||
11
server/internal/cmd/oppbench/oppbench.go
Normal file
11
server/internal/cmd/oppbench/oppbench.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
fmt.Println("Run as 'go test -bench=.' to run the benchmarks")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
107
server/internal/cmd/oppbench/oppbench_test.go
Normal file
107
server/internal/cmd/oppbench/oppbench_test.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/server/internal/chunks"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkDownload(b *testing.B) {
|
||||||
|
run := func(fileSize, chunkSize int64) {
|
||||||
|
name := fmt.Sprintf("size=%d/chunksize=%d", fileSize, chunkSize)
|
||||||
|
b.Run(name, func(b *testing.B) { benchmarkDownload(b, fileSize, chunkSize) })
|
||||||
|
}
|
||||||
|
|
||||||
|
run(100<<20, 8<<20)
|
||||||
|
run(100<<20, 16<<20)
|
||||||
|
run(100<<20, 32<<20)
|
||||||
|
run(100<<20, 64<<20)
|
||||||
|
run(100<<20, 128<<20) // 1 chunk
|
||||||
|
}
|
||||||
|
|
||||||
|
func run(ctx context.Context, c *http.Client, chunk chunks.Chunk) error {
|
||||||
|
const blobURL = "https://ollama.com/v2/x/x/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d"
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
|
||||||
|
res, err := c.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
_, err = io.CopyN(io.Discard, res.Body, chunk.Size()) // will io.EOF on short read
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var sleepTime atomic.Int64
|
||||||
|
|
||||||
|
func benchmarkDownload(b *testing.B, fileSize, chunkSize int64) {
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: func() http.RoundTripper {
|
||||||
|
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
tr.DisableKeepAlives = true
|
||||||
|
return tr
|
||||||
|
}(),
|
||||||
|
}
|
||||||
|
defer client.CloseIdleConnections()
|
||||||
|
|
||||||
|
// warm up the client
|
||||||
|
run(context.Background(), client, chunks.New(0, 1<<20))
|
||||||
|
|
||||||
|
b.SetBytes(fileSize)
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
// Give our CDN a min to breathe between benchmarks.
|
||||||
|
time.Sleep(time.Duration(sleepTime.Swap(3)))
|
||||||
|
|
||||||
|
for b.Loop() {
|
||||||
|
g, ctx := errgroup.WithContext(b.Context())
|
||||||
|
g.SetLimit(runtime.GOMAXPROCS(0))
|
||||||
|
for chunk := range chunks.Of(fileSize, chunkSize) {
|
||||||
|
g.Go(func() error { return run(ctx, client, chunk) })
|
||||||
|
}
|
||||||
|
if err := g.Wait(); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWrite(b *testing.B) {
|
||||||
|
b.Run("chunksize=1MiB", func(b *testing.B) { benchmarkWrite(b, 1<<20) })
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkWrite(b *testing.B, chunkSize int) {
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
dir := b.TempDir()
|
||||||
|
f, err := os.Create(filepath.Join(dir, "write-single"))
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
data := make([]byte, chunkSize)
|
||||||
|
b.SetBytes(int64(chunkSize))
|
||||||
|
r := bytes.NewReader(data)
|
||||||
|
for b.Loop() {
|
||||||
|
r.Reset(data)
|
||||||
|
_, err := io.Copy(f, r)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
// Package registry implements an http.Handler for handling local Ollama API
|
// Package registry provides an http.Handler for handling local Ollama API
|
||||||
// model management requests. See [Local] for details.
|
// requests for performing tasks related to the ollama.com model registry and
|
||||||
|
// the local disk cache.
|
||||||
package registry
|
package registry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -9,7 +10,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"maps"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -18,11 +18,16 @@ import (
|
|||||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Local implements an http.Handler for handling local Ollama API model
|
// Local is an http.Handler for handling local Ollama API requests for
|
||||||
// management requests, such as pushing, pulling, and deleting models.
|
// performing tasks related to the ollama.com model registry combined with the
|
||||||
|
// local disk cache.
|
||||||
//
|
//
|
||||||
// It can be arranged for all unknown requests to be passed through to a
|
// It is not concern of Local, or this package, to handle model creation, which
|
||||||
// fallback handler, if one is provided.
|
// proceeds any registry operations for models it produces.
|
||||||
|
//
|
||||||
|
// NOTE: The package built for dealing with model creation should use
|
||||||
|
// [DefaultCache] to access the blob store and not attempt to read or write
|
||||||
|
// directly to the blob disk cache.
|
||||||
type Local struct {
|
type Local struct {
|
||||||
Client *ollama.Registry // required
|
Client *ollama.Registry // required
|
||||||
Logger *slog.Logger // required
|
Logger *slog.Logger // required
|
||||||
@@ -58,7 +63,6 @@ func (e serverError) Error() string {
|
|||||||
var (
|
var (
|
||||||
errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
|
errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
|
||||||
errNotFound = &serverError{404, "not_found", "not found"}
|
errNotFound = &serverError{404, "not_found", "not found"}
|
||||||
errModelNotFound = &serverError{404, "not_found", "model not found"}
|
|
||||||
errInternalError = &serverError{500, "internal_error", "internal server error"}
|
errInternalError = &serverError{500, "internal_error", "internal server error"}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -171,16 +175,8 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type params struct {
|
type params struct {
|
||||||
// DeprecatedName is the name of the model to push, pull, or delete,
|
DeprecatedName string `json:"name"` // Use [params.model]
|
||||||
// but is deprecated. New clients should use [Model] instead.
|
Model string `json:"model"` // Use [params.model]
|
||||||
//
|
|
||||||
// Use [model()] to get the model name for both old and new API requests.
|
|
||||||
DeprecatedName string `json:"name"`
|
|
||||||
|
|
||||||
// Model is the name of the model to push, pull, or delete.
|
|
||||||
//
|
|
||||||
// Use [model()] to get the model name for both old and new API requests.
|
|
||||||
Model string `json:"model"`
|
|
||||||
|
|
||||||
// AllowNonTLS is a flag that indicates a client using HTTP
|
// AllowNonTLS is a flag that indicates a client using HTTP
|
||||||
// is doing so, deliberately.
|
// is doing so, deliberately.
|
||||||
@@ -193,18 +189,9 @@ type params struct {
|
|||||||
// confusing flags such as this.
|
// confusing flags such as this.
|
||||||
AllowNonTLS bool `json:"insecure"`
|
AllowNonTLS bool `json:"insecure"`
|
||||||
|
|
||||||
// Stream, if true, will make the server send progress updates in a
|
// ProgressStream is a flag that indicates the client is expecting a stream of
|
||||||
// streaming of JSON objects. If false, the server will send a single
|
// progress updates.
|
||||||
// JSON object with the final status as "success", or an error object
|
ProgressStream bool `json:"stream"`
|
||||||
// if an error occurred.
|
|
||||||
//
|
|
||||||
// Unfortunately, this API was designed to be a bit awkward. Stream is
|
|
||||||
// defined to default to true if not present, so we need a way to check
|
|
||||||
// if the client decisively it to false. So, we use a pointer to a
|
|
||||||
// bool. Gross.
|
|
||||||
//
|
|
||||||
// Use [stream()] to get the correct value for this field.
|
|
||||||
Stream *bool `json:"stream"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// model returns the model name for both old and new API requests.
|
// model returns the model name for both old and new API requests.
|
||||||
@@ -212,13 +199,6 @@ func (p params) model() string {
|
|||||||
return cmp.Or(p.Model, p.DeprecatedName)
|
return cmp.Or(p.Model, p.DeprecatedName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p params) stream() bool {
|
|
||||||
if p.Stream == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return *p.Stream
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
||||||
if r.Method != "DELETE" {
|
if r.Method != "DELETE" {
|
||||||
return errMethodNotAllowed
|
return errMethodNotAllowed
|
||||||
@@ -232,16 +212,16 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
return errModelNotFound
|
return &serverError{404, "not_found", "model not found"}
|
||||||
}
|
}
|
||||||
if s.Prune != nil {
|
if s.Prune == nil {
|
||||||
return s.Prune()
|
return nil
|
||||||
}
|
}
|
||||||
return nil
|
return s.Prune()
|
||||||
}
|
}
|
||||||
|
|
||||||
type progressUpdateJSON struct {
|
type progressUpdateJSON struct {
|
||||||
Status string `json:"status,omitempty,omitzero"`
|
Status string `json:"status"`
|
||||||
Digest blob.Digest `json:"digest,omitempty,omitzero"`
|
Digest blob.Digest `json:"digest,omitempty,omitzero"`
|
||||||
Total int64 `json:"total,omitempty,omitzero"`
|
Total int64 `json:"total,omitempty,omitzero"`
|
||||||
Completed int64 `json:"completed,omitempty,omitzero"`
|
Completed int64 `json:"completed,omitempty,omitzero"`
|
||||||
@@ -257,17 +237,6 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
enc := json.NewEncoder(w)
|
|
||||||
if !p.stream() {
|
|
||||||
if err := s.Client.Pull(r.Context(), p.model()); err != nil {
|
|
||||||
if errors.Is(err, ollama.ErrModelNotFound) {
|
|
||||||
return errModelNotFound
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return enc.Encode(progressUpdateJSON{Status: "success"})
|
|
||||||
}
|
|
||||||
|
|
||||||
maybeFlush := func() {
|
maybeFlush := func() {
|
||||||
fl, _ := w.(http.Flusher)
|
fl, _ := w.(http.Flusher)
|
||||||
if fl != nil {
|
if fl != nil {
|
||||||
@@ -277,67 +246,69 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
|||||||
defer maybeFlush()
|
defer maybeFlush()
|
||||||
|
|
||||||
var mu sync.Mutex
|
var mu sync.Mutex
|
||||||
progress := make(map[*ollama.Layer]int64)
|
enc := json.NewEncoder(w)
|
||||||
|
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
|
||||||
|
|
||||||
progressCopy := make(map[*ollama.Layer]int64, len(progress))
|
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
|
||||||
pushUpdate := func() {
|
Update: func(l *ollama.Layer, n int64, err error) {
|
||||||
defer maybeFlush()
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
|
||||||
// TODO(bmizerany): This scales poorly with more layers due to
|
// TODO(bmizerany): coalesce these updates; writing per
|
||||||
// needing to flush out them all in one big update. We _could_
|
// update is expensive
|
||||||
// just flush on the changed ones, or just track the whole
|
|
||||||
// download. Needs more thought. This is fine for now.
|
|
||||||
mu.Lock()
|
|
||||||
maps.Copy(progressCopy, progress)
|
|
||||||
mu.Unlock()
|
|
||||||
for l, n := range progress {
|
|
||||||
enc.Encode(progressUpdateJSON{
|
enc.Encode(progressUpdateJSON{
|
||||||
Digest: l.Digest,
|
Digest: l.Digest,
|
||||||
|
Status: "pulling",
|
||||||
Total: l.Size,
|
Total: l.Size,
|
||||||
Completed: n,
|
Completed: n,
|
||||||
})
|
})
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
t := time.NewTicker(time.Hour) // "unstarted" timer
|
|
||||||
start := sync.OnceFunc(func() {
|
|
||||||
pushUpdate()
|
|
||||||
t.Reset(100 * time.Millisecond)
|
|
||||||
})
|
|
||||||
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
|
|
||||||
Update: func(l *ollama.Layer, n int64, err error) {
|
|
||||||
if n > 0 {
|
|
||||||
start() // flush initial state
|
|
||||||
}
|
|
||||||
mu.Lock()
|
|
||||||
progress[l] = n
|
|
||||||
mu.Unlock()
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
done := make(chan error, 1)
|
done := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
|
// TODO(bmizerany): continue to support non-streaming responses
|
||||||
done <- s.Client.Pull(ctx, p.model())
|
done <- s.Client.Pull(ctx, p.model())
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
func() {
|
||||||
select {
|
t := time.NewTicker(100 * time.Millisecond)
|
||||||
case <-t.C:
|
defer t.Stop()
|
||||||
pushUpdate()
|
for {
|
||||||
case err := <-done:
|
select {
|
||||||
pushUpdate()
|
case <-t.C:
|
||||||
if err != nil {
|
mu.Lock()
|
||||||
var status string
|
maybeFlush()
|
||||||
if errors.Is(err, ollama.ErrModelNotFound) {
|
mu.Unlock()
|
||||||
status = fmt.Sprintf("error: model %q not found", p.model())
|
case err := <-done:
|
||||||
} else {
|
if err != nil {
|
||||||
status = fmt.Sprintf("error: %v", err)
|
var status string
|
||||||
|
if errors.Is(err, ollama.ErrModelNotFound) {
|
||||||
|
status = fmt.Sprintf("error: model %q not found", p.model())
|
||||||
|
enc.Encode(progressUpdateJSON{Status: status})
|
||||||
|
} else {
|
||||||
|
status = fmt.Sprintf("error: %v", err)
|
||||||
|
enc.Encode(progressUpdateJSON{Status: status})
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
enc.Encode(progressUpdateJSON{Status: status})
|
|
||||||
|
// These final updates are not strictly necessary, because they have
|
||||||
|
// already happened at this point. Our pull handler code used to do
|
||||||
|
// these steps after, not during, the pull, and they were slow, so we
|
||||||
|
// wanted to provide feedback to users what was happening. For now, we
|
||||||
|
// keep them to not jar users who are used to seeing them. We can phase
|
||||||
|
// them out with a new and nicer UX later. One without progress bars
|
||||||
|
// and digests that no one cares about.
|
||||||
|
enc.Encode(progressUpdateJSON{Status: "verifying layers"})
|
||||||
|
enc.Encode(progressUpdateJSON{Status: "writing manifest"})
|
||||||
|
enc.Encode(progressUpdateJSON{Status: "success"})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
}
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeUserJSON[T any](r io.Reader) (T, error) {
|
func decodeUserJSON[T any](r io.Reader) (T, error) {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net"
|
"net"
|
||||||
@@ -159,6 +160,7 @@ var registryFS = sync.OnceValue(func() fs.FS {
|
|||||||
// to \n when parsing the txtar on Windows.
|
// to \n when parsing the txtar on Windows.
|
||||||
data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
|
data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
|
||||||
a := txtar.Parse(data)
|
a := txtar.Parse(data)
|
||||||
|
fmt.Printf("%q\n", a.Comment)
|
||||||
fsys, err := txtar.FS(a)
|
fsys, err := txtar.FS(a)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
@@ -177,7 +179,7 @@ func TestServerPull(t *testing.T) {
|
|||||||
w.WriteHeader(404)
|
w.WriteHeader(404)
|
||||||
io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
|
io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
|
||||||
default:
|
default:
|
||||||
t.Logf("serving blob: %s", r.URL.Path)
|
t.Logf("serving file: %s", r.URL.Path)
|
||||||
modelsHandler.ServeHTTP(w, r)
|
modelsHandler.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -186,7 +188,7 @@ func TestServerPull(t *testing.T) {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
if got.Code != 200 {
|
if got.Code != 200 {
|
||||||
t.Errorf("Code = %d; want 200", got.Code)
|
t.Fatalf("Code = %d; want 200", got.Code)
|
||||||
}
|
}
|
||||||
gotlines := got.Body.String()
|
gotlines := got.Body.String()
|
||||||
t.Logf("got:\n%s", gotlines)
|
t.Logf("got:\n%s", gotlines)
|
||||||
@@ -195,29 +197,35 @@ func TestServerPull(t *testing.T) {
|
|||||||
want, unwanted := strings.CutPrefix(want, "!")
|
want, unwanted := strings.CutPrefix(want, "!")
|
||||||
want = strings.TrimSpace(want)
|
want = strings.TrimSpace(want)
|
||||||
if !unwanted && !strings.Contains(gotlines, want) {
|
if !unwanted && !strings.Contains(gotlines, want) {
|
||||||
t.Errorf("! missing %q in body", want)
|
t.Fatalf("! missing %q in body", want)
|
||||||
}
|
}
|
||||||
if unwanted && strings.Contains(gotlines, want) {
|
if unwanted && strings.Contains(gotlines, want) {
|
||||||
t.Errorf("! unexpected %q in body", want)
|
t.Fatalf("! unexpected %q in body", want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
|
got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
|
||||||
checkResponse(got, `
|
checkResponse(got, `
|
||||||
|
{"status":"pulling manifest"}
|
||||||
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
|
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
|
||||||
`)
|
`)
|
||||||
|
|
||||||
got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
|
got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
|
||||||
checkResponse(got, `
|
checkResponse(got, `
|
||||||
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
|
{"status":"pulling manifest"}
|
||||||
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
|
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
|
||||||
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
|
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
|
||||||
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
|
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
|
||||||
|
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
|
||||||
|
{"status":"verifying layers"}
|
||||||
|
{"status":"writing manifest"}
|
||||||
|
{"status":"success"}
|
||||||
`)
|
`)
|
||||||
|
|
||||||
got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
|
got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
|
||||||
checkResponse(got, `
|
checkResponse(got, `
|
||||||
|
{"status":"pulling manifest"}
|
||||||
{"status":"error: model \"unknown\" not found"}
|
{"status":"error: model \"unknown\" not found"}
|
||||||
`)
|
`)
|
||||||
|
|
||||||
@@ -232,39 +240,19 @@ func TestServerPull(t *testing.T) {
|
|||||||
|
|
||||||
got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
|
got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
|
||||||
checkResponse(got, `
|
checkResponse(got, `
|
||||||
|
{"status":"pulling manifest"}
|
||||||
{"status":"error: invalid or missing name: \"\""}
|
{"status":"error: invalid or missing name: \"\""}
|
||||||
`)
|
|
||||||
|
|
||||||
// Non-streaming pulls
|
!verifying
|
||||||
got = s.send(t, "POST", "/api/pull", `{"model": "://", "stream": false}`)
|
!writing
|
||||||
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
|
!success
|
||||||
got = s.send(t, "POST", "/api/pull", `{"model": "smol", "stream": false}`)
|
|
||||||
checkResponse(got, `
|
|
||||||
{"status":"success"}
|
|
||||||
!digest
|
|
||||||
!total
|
|
||||||
!completed
|
|
||||||
`)
|
`)
|
||||||
got = s.send(t, "POST", "/api/pull", `{"model": "unknown", "stream": false}`)
|
|
||||||
checkErrorResponse(t, got, 404, "not_found", "model not found")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerUnknownPath(t *testing.T) {
|
func TestServerUnknownPath(t *testing.T) {
|
||||||
s := newTestServer(t, nil)
|
s := newTestServer(t, nil)
|
||||||
got := s.send(t, "DELETE", "/api/unknown", `{}`)
|
got := s.send(t, "DELETE", "/api/unknown", `{}`)
|
||||||
checkErrorResponse(t, got, 404, "not_found", "not found")
|
checkErrorResponse(t, got, 404, "not_found", "not found")
|
||||||
|
|
||||||
var fellback bool
|
|
||||||
s.Fallback = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
fellback = true
|
|
||||||
})
|
|
||||||
got = s.send(t, "DELETE", "/api/unknown", `{}`)
|
|
||||||
if !fellback {
|
|
||||||
t.Fatal("expected Fallback to be called")
|
|
||||||
}
|
|
||||||
if got.Code != 200 {
|
|
||||||
t.Fatalf("Code = %d; want 200", got.Code)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
|
func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
var system []api.Message
|
var system []api.Message
|
||||||
|
|
||||||
isMllama := checkMllamaModelFamily(m)
|
isMllama := checkMllamaModelFamily(m)
|
||||||
|
isGemma3 := checkGemma3ModelFamily(m)
|
||||||
|
|
||||||
var imageNumTokens int
|
var imageNumTokens int
|
||||||
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
|
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
|
||||||
@@ -40,7 +41,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
n := len(msgs) - 1
|
n := len(msgs) - 1
|
||||||
// in reverse, find all messages that fit into context window
|
// in reverse, find all messages that fit into context window
|
||||||
for i := n; i >= 0; i-- {
|
for i := n; i >= 0; i-- {
|
||||||
if isMllama && len(msgs[i].Images) > 1 {
|
if (isMllama || isGemma3) && len(msgs[i].Images) > 1 {
|
||||||
return "", nil, errTooManyImages
|
return "", nil, errTooManyImages
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,3 +158,12 @@ func checkMllamaModelFamily(m *Model) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkGemma3ModelFamily(m *Model) bool {
|
||||||
|
for _, arch := range m.Config.ModelFamilies {
|
||||||
|
if arch == "gemma3" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user