Compare commits

..

1 Commits

Author SHA1 Message Date
Patrick Devine
73a1e99f8a logging: add a new customer logger and trace method
This change addresses over logging with debug in the SPM tokenizer by
adding a trace level to slog.
2025-03-13 16:10:59 -07:00
16 changed files with 675 additions and 513 deletions

View File

@@ -187,13 +187,6 @@ cloudflared tunnel --url http://localhost:11434 --http-host-header="localhost:11
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`. Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
For browser extensions, you'll need to explicitly allow the extension's origin pattern. Set `OLLAMA_ORIGINS` to include `chrome-extension://*`, `moz-extension://*`, and `safari-web-extension://*` if you wish to allow all browser extensions access, or specific extensions as needed:
```
# Allow all Chrome, Firefox, and Safari extensions
OLLAMA_ORIGINS=chrome-extension://*,moz-extension://*,safari-web-extension://* ollama serve
```
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform. Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
## Where are models stored? ## Where are models stored?

View File

@@ -583,52 +583,39 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
} }
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) { func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
if llm.KV().Uint("vision.block_count") == 0 {
return
}
for name, layer := range llm.Tensors().GroupLayers() {
if name == "v" || strings.HasPrefix(name, "v.") {
for _, tensor := range layer {
weights += tensor.Size()
}
}
}
imageSize := uint64(llm.KV().Uint("vision.image_size"))
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
if patchSize == 0 {
slog.Warn("unknown patch size for vision model")
return
}
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
switch llm.KV().Architecture() { switch llm.KV().Architecture() {
case "mllama": case "mllama":
for _, layer := range llm.Tensors().GroupLayers()["v"] {
weights += layer.Size()
}
kv := func(n string) uint64 {
if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
return uint64(v)
}
return 0
}
imageSize := kv("image_size")
maxNumTiles := kv("max_num_tiles")
embeddingLength := kv("embedding_length")
headCount := kv("attention.head_count")
numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
numPaddedPatches := numPatches + 8 - (numPatches%8)%8 numPaddedPatches := numPatches + 8 - (numPatches%8)%8
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
graphSize = 4 * (8 + graphSize = 4 * (8 +
imageSize*imageSize*numChannels*maxNumTiles + imageSize*imageSize*kv("num_channels")*maxNumTiles +
embeddingLength*numPatches*maxNumTiles + embeddingLength*numPatches*maxNumTiles +
9*embeddingLength*numPaddedPatches*maxNumTiles + 9*embeddingLength*numPaddedPatches*maxNumTiles +
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount) numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
case "gemma3":
graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize +
numPatches*numPatches*headCount)
} }
return weights, graphSize return weights, graphSize
} }

View File

@@ -218,8 +218,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok { if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
layerSize = blk.Size() layerSize = blk.Size()
layerSize += kv / f.KV().BlockCount() layerSize += kv / f.KV().BlockCount()
memoryWeights += blk.Size()
} }
memoryWeights += layerSize
if opts.NumGPU >= 0 && layerCount >= opts.NumGPU { if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
// Stop allocating on GPU(s) once we hit the users target NumGPU // Stop allocating on GPU(s) once we hit the users target NumGPU
@@ -376,7 +376,7 @@ func (m MemoryEstimate) LogValue() slog.Value {
// memory of the weights // memory of the weights
"total", format.HumanBytes2(m.memoryWeights), "total", format.HumanBytes2(m.memoryWeights),
// memory of repeating layers // memory of repeating layers
"repeating", format.HumanBytes2(m.memoryWeights), "repeating", format.HumanBytes2(m.memoryWeights-m.memoryLayerOutput),
// memory of non-repeating layers // memory of non-repeating layers
"nonrepeating", format.HumanBytes2(m.memoryLayerOutput), "nonrepeating", format.HumanBytes2(m.memoryLayerOutput),
), ),

40
logging/log.go Normal file
View File

@@ -0,0 +1,40 @@
package logging
import (
"context"
"log/slog"
"os"
)
const LevelTrace slog.Level = slog.LevelDebug - 4
type Logger struct {
logger *slog.Logger
}
func NewLogger() *Logger {
handler := slog.NewTextHandler(os.Stdout, nil)
return &Logger{
logger: slog.New(handler),
}
}
func (l *Logger) Trace(msg string, args ...any) {
l.logger.Log(context.Background(), LevelTrace, msg, args...)
}
func (l *Logger) Debug(msg string, args ...any) {
l.logger.Debug(msg, args...)
}
func (l *Logger) Info(msg string, args ...any) {
l.logger.Info(msg, args...)
}
func (l *Logger) Warn(msg string, args ...any) {
l.logger.Warn(msg, args...)
}
func (l *Logger) Error(msg string, args ...any) {
l.logger.Error(msg, args...)
}

View File

@@ -2,15 +2,18 @@ package model
import ( import (
"iter" "iter"
"log/slog"
"strings" "strings"
"github.com/dlclark/regexp2" "github.com/dlclark/regexp2"
queue "github.com/emirpasic/gods/v2/queues/priorityqueue" queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
"github.com/ollama/ollama/logging"
) )
const spmWhitespaceSep = "▁" const spmWhitespaceSep = "▁"
var log = logging.NewLogger()
func replaceWhitespaceBySeperator(s string) string { func replaceWhitespaceBySeperator(s string) string {
return strings.ReplaceAll(s, " ", spmWhitespaceSep) return strings.ReplaceAll(s, " ", spmWhitespaceSep)
} }
@@ -24,7 +27,7 @@ type SentencePieceModel struct {
var _ TextProcessor = (*SentencePieceModel)(nil) var _ TextProcessor = (*SentencePieceModel)(nil)
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel { func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5]) log.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{} counter := map[int]int{}
var maxTokenLen int var maxTokenLen int
@@ -38,7 +41,7 @@ func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
} }
} }
slog.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL], log.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE], "user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen) "max token len", maxTokenLen)
@@ -91,7 +94,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...) fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
} }
} }
slog.Debug("fragments", "frags", fragments) log.Trace("fragments", "frags", fragments)
var ids []int32 var ids []int32
for _, frag := range fragments { for _, frag := range fragments {
@@ -129,7 +132,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
} }
} }
slog.Debug("tokenizer", "merges", merges) log.Trace("tokenizer", "merges", merges)
pairwise := func(a, b int) *candidate { pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) { if a < 0 || b >= len(runes) {
@@ -156,7 +159,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
pqv := pq.Values() pqv := pq.Values()
for _, v := range pqv { for _, v := range pqv {
e := v.(*candidate) e := v.(*candidate)
slog.Debug("candidate", "candidate", e) log.Trace("candidate", "candidate", e)
} }
for !pq.Empty() { for !pq.Empty() {
@@ -164,7 +167,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
pair := v.(*candidate) pair := v.(*candidate)
left, right := merges[pair.a], merges[pair.b] left, right := merges[pair.a], merges[pair.b]
slog.Debug("pair", "left", left, "right", right) log.Trace("pair", "left", left, "right", right)
if len(left.runes) == 0 || len(right.runes) == 0 { if len(left.runes) == 0 || len(right.runes) == 0 {
continue continue
} }
@@ -189,14 +192,14 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
} }
} }
slog.Debug("merges", "merges", merges) log.Trace("merges", "merges", merges)
for _, merge := range merges { for _, merge := range merges {
if len(merge.runes) > 0 { if len(merge.runes) > 0 {
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 { if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id) ids = append(ids, id)
} else { } else {
slog.Debug("missing token", "token", string(merge.runes)) log.Error("missing token", "token", string(merge.runes))
} }
} }
} }
@@ -206,19 +209,19 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
if addSpecial && len(ids) > 0 { if addSpecial && len(ids) > 0 {
if spm.vocab.AddBOS { if spm.vocab.AddBOS {
if ids[0] == spm.vocab.BOS { if ids[0] == spm.vocab.BOS {
slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS) log.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
} }
slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS) log.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
ids = append([]int32{spm.vocab.BOS}, ids...) ids = append([]int32{spm.vocab.BOS}, ids...)
} }
if spm.vocab.AddEOS { if spm.vocab.AddEOS {
if ids[len(ids)-1] == spm.vocab.EOS { if ids[len(ids)-1] == spm.vocab.EOS {
slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS) log.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
} }
slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS) log.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
ids = append(ids, spm.vocab.EOS) ids = append(ids, spm.vocab.EOS)
} }
} }
@@ -241,6 +244,6 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
} }
} }
slog.Debug("decoded", "ids", ids, "text", sb.String()) log.Debug("decoded", "ids", ids, "text", sb.String())
return sb.String(), nil return sb.String(), nil
} }

View File

@@ -146,7 +146,7 @@ func debugger(err *error) func(step string) {
// be in either of the following forms: // be in either of the following forms:
// //
// @<digest> // @<digest>
// <name>@<digest> // <name>
// <name> // <name>
// //
// If a digest is provided, it is returned as is and nothing else happens. // If a digest is provided, it is returned as is and nothing else happens.
@@ -160,6 +160,8 @@ func debugger(err *error) func(step string) {
// hashed is passed to a PutBytes call to ensure that the manifest is in the // hashed is passed to a PutBytes call to ensure that the manifest is in the
// blob store. This is done to ensure that future calls to [Get] succeed in // blob store. This is done to ensure that future calls to [Get] succeed in
// these cases. // these cases.
//
// TODO(bmizerany): Move Links/Resolve/etc. out of this package.
func (c *DiskCache) Resolve(name string) (Digest, error) { func (c *DiskCache) Resolve(name string) (Digest, error) {
name, digest := splitNameDigest(name) name, digest := splitNameDigest(name)
if digest != "" { if digest != "" {
@@ -277,6 +279,18 @@ func (c *DiskCache) Get(d Digest) (Entry, error) {
// It returns an error if either the name or digest is invalid, or if link // It returns an error if either the name or digest is invalid, or if link
// creation encounters any issues. // creation encounters any issues.
func (c *DiskCache) Link(name string, d Digest) error { func (c *DiskCache) Link(name string, d Digest) error {
// TODO(bmizerany): Move link handling from cache to registry.
//
// We originally placed links in the cache due to its storage
// knowledge. However, the registry likely offers better context for
// naming concerns, and our API design shouldn't be tightly coupled to
// our on-disk format.
//
// Links work effectively when independent from physical location -
// they can reference content with matching SHA regardless of storage
// location. In an upcoming change, we plan to shift this
// responsibility to the registry where it better aligns with the
// system's conceptual model.
manifest, err := c.manifestPath(name) manifest, err := c.manifestPath(name)
if err != nil { if err != nil {
return err return err
@@ -327,9 +341,7 @@ func (c *DiskCache) GetFile(d Digest) string {
return absJoin(c.dir, "blobs", filename) return absJoin(c.dir, "blobs", filename)
} }
// Links returns a sequence of link names. The sequence is in lexical order. // Links returns a sequence of links in the cache in lexical order.
// Names are converted from their relative path form to their name form but are
// not guaranteed to be valid. Callers should validate the names before using.
func (c *DiskCache) Links() iter.Seq2[string, error] { func (c *DiskCache) Links() iter.Seq2[string, error] {
return func(yield func(string, error) bool) { return func(yield func(string, error) bool) {
for path, err := range c.links() { for path, err := range c.links() {
@@ -402,14 +414,12 @@ func (c *DiskCache) links() iter.Seq2[string, error] {
} }
type checkWriter struct { type checkWriter struct {
size int64
d Digest d Digest
f *os.File size int64
n int64
h hash.Hash h hash.Hash
f *os.File
w io.Writer // underlying writer; set by creator err error
n int64
err error
testHookBeforeFinalWrite func(*os.File) testHookBeforeFinalWrite func(*os.File)
} }
@@ -425,10 +435,6 @@ func (w *checkWriter) seterr(err error) error {
// underlying writer is guaranteed to be the last byte of p as verified by the // underlying writer is guaranteed to be the last byte of p as verified by the
// hash. // hash.
func (w *checkWriter) Write(p []byte) (int, error) { func (w *checkWriter) Write(p []byte) (int, error) {
if w.err != nil {
return 0, w.err
}
_, err := w.h.Write(p) _, err := w.h.Write(p)
if err != nil { if err != nil {
return 0, w.seterr(err) return 0, w.seterr(err)
@@ -447,7 +453,7 @@ func (w *checkWriter) Write(p []byte) (int, error) {
if nextSize > w.size { if nextSize > w.size {
return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size)) return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size))
} }
n, err := w.w.Write(p) n, err := w.f.Write(p)
w.n += int64(n) w.n += int64(n)
return n, w.seterr(err) return n, w.seterr(err)
} }
@@ -487,12 +493,10 @@ func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size
// Copy file to f, but also into h to double-check hash. // Copy file to f, but also into h to double-check hash.
cw := &checkWriter{ cw := &checkWriter{
d: out, d: out,
size: size, size: size,
h: sha256.New(), h: sha256.New(),
f: f, f: f,
w: f,
testHookBeforeFinalWrite: c.testHookBeforeFinalWrite, testHookBeforeFinalWrite: c.testHookBeforeFinalWrite,
} }
n, err := io.Copy(cw, file) n, err := io.Copy(cw, file)
@@ -528,6 +532,11 @@ func splitNameDigest(s string) (name, digest string) {
var errInvalidName = errors.New("invalid name") var errInvalidName = errors.New("invalid name")
func nameToPath(name string) (_ string, err error) { func nameToPath(name string) (_ string, err error) {
if strings.Contains(name, "@") {
// TODO(bmizerany): HACK: Fix names.Parse to validate.
// TODO(bmizerany): merge with default parts (maybe names.Merge(a, b))
return "", errInvalidName
}
n := names.Parse(name) n := names.Parse(name)
if !n.IsFullyQualified() { if !n.IsFullyQualified() {
return "", errInvalidName return "", errInvalidName
@@ -538,7 +547,8 @@ func nameToPath(name string) (_ string, err error) {
func absJoin(pp ...string) string { func absJoin(pp ...string) string {
abs, err := filepath.Abs(filepath.Join(pp...)) abs, err := filepath.Abs(filepath.Join(pp...))
if err != nil { if err != nil {
panic(err) // this should never happen // Likely a bug bug or a bad OS problem. Just panic.
panic(err)
} }
return abs return abs
} }

View File

@@ -1,73 +0,0 @@
package blob
import (
"crypto/sha256"
"errors"
"io"
"os"
)
// Chunk represents a range of bytes in a blob.
type Chunk struct {
Start int64
End int64
}
// Size returns end minus start plus one.
func (c Chunk) Size() int64 {
return c.End - c.Start + 1
}
// Chunker writes to a blob in chunks.
// Its zero value is invalid. Use [DiskCache.Chunked] to create a new Chunker.
type Chunker struct {
digest Digest
size int64
f *os.File // nil means pre-validated
}
// Chunked returns a new Chunker, ready for use storing a blob of the given
// size in chunks.
//
// Use [Chunker.Put] to write data to the blob at specific offsets.
func (c *DiskCache) Chunked(d Digest, size int64) (*Chunker, error) {
name := c.GetFile(d)
info, err := os.Stat(name)
if err == nil && info.Size() == size {
return &Chunker{}, nil
}
f, err := os.OpenFile(name, os.O_CREATE|os.O_WRONLY, 0o666)
if err != nil {
return nil, err
}
return &Chunker{digest: d, size: size, f: f}, nil
}
// Put copies chunk.Size() bytes from r to the blob at the given offset,
// merging the data with the existing blob. It returns an error if any. As a
// special case, if r has less than chunk.Size() bytes, Put returns
// io.ErrUnexpectedEOF.
func (c *Chunker) Put(chunk Chunk, d Digest, r io.Reader) error {
if c.f == nil {
return nil
}
cw := &checkWriter{
d: d,
size: chunk.Size(),
h: sha256.New(),
f: c.f,
w: io.NewOffsetWriter(c.f, chunk.Start),
}
_, err := io.CopyN(cw, r, chunk.Size())
if err != nil && errors.Is(err, io.EOF) {
return io.ErrUnexpectedEOF
}
return err
}
// Close closes the underlying file.
func (c *Chunker) Close() error {
return c.f.Close()
}

View File

@@ -63,10 +63,6 @@ func (d Digest) Short() string {
return fmt.Sprintf("%x", d.sum[:4]) return fmt.Sprintf("%x", d.sum[:4])
} }
func (d Digest) Sum() [32]byte {
return d.sum
}
func (d Digest) Compare(other Digest) int { func (d Digest) Compare(other Digest) int {
return slices.Compare(d.sum[:], other.sum[:]) return slices.Compare(d.sum[:], other.sum[:])
} }

View File

@@ -0,0 +1,78 @@
package chunks
import (
"fmt"
"iter"
"strconv"
"strings"
)
type Chunk struct {
Start, End int64
}
func New(start, end int64) Chunk {
return Chunk{start, end}
}
// ParseRange parses a string in the form "unit=range" where unit is a string
// and range is a string in the form "start-end". It returns the unit and the
// range as a Chunk.
func ParseRange(s string) (unit string, _ Chunk, _ error) {
unit, r, _ := strings.Cut(s, "=")
if r == "" {
return unit, Chunk{}, nil
}
c, err := Parse(r)
if err != nil {
return "", Chunk{}, err
}
return unit, c, err
}
// Parse parses a string in the form "start-end" and returns the Chunk.
func Parse(s string) (Chunk, error) {
startStr, endStr, _ := strings.Cut(s, "-")
start, err := strconv.ParseInt(startStr, 10, 64)
if err != nil {
return Chunk{}, fmt.Errorf("invalid start: %v", err)
}
end, err := strconv.ParseInt(endStr, 10, 64)
if err != nil {
return Chunk{}, fmt.Errorf("invalid end: %v", err)
}
if start > end {
return Chunk{}, fmt.Errorf("invalid range %d-%d: start > end", start, end)
}
return Chunk{start, end}, nil
}
// Of returns a sequence of contiguous Chunks of size chunkSize that cover
// the range [0, size), in order.
func Of(size, chunkSize int64) iter.Seq[Chunk] {
return func(yield func(Chunk) bool) {
for start := int64(0); start < size; start += chunkSize {
end := min(start+chunkSize-1, size-1)
if !yield(Chunk{start, end}) {
break
}
}
}
}
// Count returns the number of Chunks of size chunkSize needed to cover the
// range [0, size).
func Count(size, chunkSize int64) int64 {
return (size + chunkSize - 1) / chunkSize
}
// Size returns end minus start plus one.
func (c Chunk) Size() int64 {
return c.End - c.Start + 1
}
// String returns the string representation of the Chunk in the form
// "{start}-{end}".
func (c Chunk) String() string {
return fmt.Sprintf("%d-%d", c.Start, c.End)
}

View File

@@ -0,0 +1,65 @@
package chunks
import (
"slices"
"testing"
)
func TestOf(t *testing.T) {
cases := []struct {
total int64
chunkSize int64
want []Chunk
}{
{0, 1, nil},
{1, 1, []Chunk{{0, 0}}},
{1, 2, []Chunk{{0, 0}}},
{2, 1, []Chunk{{0, 0}, {1, 1}}},
{10, 9, []Chunk{{0, 8}, {9, 9}}},
}
for _, tt := range cases {
got := slices.Collect(Of(tt.total, tt.chunkSize))
if !slices.Equal(got, tt.want) {
t.Errorf("[%d/%d]: got %v; want %v", tt.total, tt.chunkSize, got, tt.want)
}
}
}
func TestSize(t *testing.T) {
cases := []struct {
c Chunk
want int64
}{
{Chunk{0, 0}, 1},
{Chunk{0, 1}, 2},
{Chunk{3, 4}, 2},
}
for _, tt := range cases {
got := tt.c.Size()
if got != tt.want {
t.Errorf("%v: got %d; want %d", tt.c, got, tt.want)
}
}
}
func TestCount(t *testing.T) {
cases := []struct {
total int64
chunkSize int64
want int64
}{
{0, 1, 0},
{1, 1, 1},
{1, 2, 1},
{2, 1, 2},
{10, 9, 2},
}
for _, tt := range cases {
got := Count(tt.total, tt.chunkSize)
if got != tt.want {
t.Errorf("[%d/%d]: got %d; want %d", tt.total, tt.chunkSize, got, tt.want)
}
}
}

View File

@@ -19,7 +19,6 @@ import (
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
"iter"
"log/slog" "log/slog"
"net/http" "net/http"
"os" "os"
@@ -36,8 +35,10 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/internal/backoff" "github.com/ollama/ollama/server/internal/internal/backoff"
"github.com/ollama/ollama/server/internal/internal/names" "github.com/ollama/ollama/server/internal/internal/names"
"github.com/ollama/ollama/server/internal/internal/syncs"
_ "embed" _ "embed"
) )
@@ -65,7 +66,12 @@ var (
const ( const (
// DefaultChunkingThreshold is the threshold at which a layer should be // DefaultChunkingThreshold is the threshold at which a layer should be
// split up into chunks when downloading. // split up into chunks when downloading.
DefaultChunkingThreshold = 64 << 20 DefaultChunkingThreshold = 128 << 20
// DefaultMaxChunkSize is the default maximum size of a chunk to
// download. It is configured based on benchmarks and aims to strike a
// balance between download speed and memory usage.
DefaultMaxChunkSize = 8 << 20
) )
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) { var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
@@ -205,7 +211,8 @@ type Registry struct {
// pushing or pulling models. If zero, the number of streams is // pushing or pulling models. If zero, the number of streams is
// determined by [runtime.GOMAXPROCS]. // determined by [runtime.GOMAXPROCS].
// //
// A negative value means no limit. // Clients that want "unlimited" streams should set this to a large
// number.
MaxStreams int MaxStreams int
// ChunkingThreshold is the maximum size of a layer to download in a single // ChunkingThreshold is the maximum size of a layer to download in a single
@@ -275,13 +282,24 @@ func DefaultRegistry() (*Registry, error) {
} }
func (r *Registry) maxStreams() int { func (r *Registry) maxStreams() int {
return cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0)) n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
// Large downloads require a writter stream, so ensure we have at least
// two streams to avoid a deadlock.
return max(n, 2)
} }
func (r *Registry) maxChunkingThreshold() int64 { func (r *Registry) maxChunkingThreshold() int64 {
return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold) return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
} }
// chunkSizeFor returns the chunk size for a layer of the given size. If the
// size is less than or equal to the max chunking threshold, the size is
// returned; otherwise, the max chunk size is returned.
func (r *Registry) maxChunkSize() int64 {
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
}
type PushParams struct { type PushParams struct {
// From is an optional destination name for the model. If empty, the // From is an optional destination name for the model. If empty, the
// destination name is the same as the source name. // destination name is the same as the source name.
@@ -408,21 +426,6 @@ func canRetry(err error) bool {
return re.Status >= 500 return re.Status >= 500
} }
// trackingReader is an io.Reader that tracks the number of bytes read and
// calls the update function with the layer, the number of bytes read.
//
// It always calls update with a nil error.
type trackingReader struct {
r io.Reader
n *atomic.Int64
}
func (r *trackingReader) Read(p []byte) (n int, err error) {
n, err = r.r.Read(p)
r.n.Add(int64(n))
return
}
// Pull pulls the model with the given name from the remote registry into the // Pull pulls the model with the given name from the remote registry into the
// cache. // cache.
// //
@@ -431,6 +434,11 @@ func (r *trackingReader) Read(p []byte) (n int, err error) {
// typically slower than splitting the model up across layers, and is mostly // typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image". // utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, name string) error { func (r *Registry) Pull(ctx context.Context, name string) error {
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
return err
}
m, err := r.Resolve(ctx, name) m, err := r.Resolve(ctx, name)
if err != nil { if err != nil {
return err return err
@@ -449,95 +457,126 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
return err == nil && info.Size == l.Size return err == nil && info.Size == l.Size
} }
t := traceFromContext(ctx)
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(r.maxStreams())
layers := m.Layers layers := m.Layers
if m.Config != nil && m.Config.Digest.IsValid() { if m.Config != nil && m.Config.Digest.IsValid() {
layers = append(layers, m.Config) layers = append(layers, m.Config)
} }
// Send initial layer trace events to allow clients to have an for _, l := range layers {
// understanding of work to be done before work starts.
t := traceFromContext(ctx)
skip := make([]bool, len(layers))
for i, l := range layers {
t.update(l, 0, nil)
if exists(l) { if exists(l) {
skip[i] = true
t.update(l, l.Size, ErrCached) t.update(l, l.Size, ErrCached)
}
}
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(r.maxStreams())
for i, l := range layers {
if skip[i] {
continue continue
} }
chunked, err := c.Chunked(l.Digest, l.Size) blobURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), l.Digest)
req, err := r.newRequest(ctx, "GET", blobURL, nil)
if err != nil { if err != nil {
t.update(l, 0, err) t.update(l, 0, err)
continue continue
} }
defer chunked.Close()
var progress atomic.Int64 t.update(l, 0, nil)
for cs, err := range r.chunksums(ctx, name, l) {
if err != nil { if l.Size <= r.maxChunkingThreshold() {
t.update(l, progress.Load(), err) g.Go(func() error {
break // TODO(bmizerany): retry/backoff like below in
} // the chunking case
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
err = c.Put(l.Digest, res.Body, l.Size)
if err == nil {
t.update(l, l.Size, nil)
}
return err
})
} else {
q := syncs.NewRelayReader()
g.Go(func() (err error) { g.Go(func() (err error) {
defer func() { t.update(l, progress.Load(), err) }() defer func() { q.CloseWithError(err) }()
return c.Put(l.Digest, q, l.Size)
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil {
return err
}
err := func() error {
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
// Count bytes towards
// progress, as they arrive, so
// that our bytes piggyback
// other chunk updates on
// completion.
//
// This tactic is enough to
// show "smooth" progress given
// the current CLI client. In
// the near future, the server
// should report download rate
// since it knows better than
// a client that is measuring
// rate based on wall-clock
// time-since-last-update.
body := &trackingReader{r: res.Body, n: &progress}
err = chunked.Put(cs.Chunk, cs.Digest, body)
if err != nil {
return err
}
return nil
}()
if !canRetry(err) {
return err
}
}
return nil
}) })
var progress atomic.Int64
// We want to avoid extra round trips per chunk due to
// redirects from the registry to the blob store, so
// fire an initial request to get the final URL and
// then use that URL for the chunk requests.
req.Header.Set("Range", "bytes=0-0")
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
res.Body.Close()
req = res.Request.WithContext(req.Context())
wp := writerPool{size: r.maxChunkSize()}
for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
if ctx.Err() != nil {
break
}
ticket := q.Take()
g.Go(func() (err error) {
defer func() {
if err != nil {
q.CloseWithError(err)
}
ticket.Close()
t.update(l, progress.Load(), err)
}()
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil {
return err
}
err := func() error {
req := req.Clone(req.Context())
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
tw := wp.get()
tw.Reset(ticket)
defer wp.put(tw)
_, err = io.CopyN(tw, res.Body, chunk.Size())
if err != nil {
return maybeUnexpectedEOF(err)
}
if err := tw.Flush(); err != nil {
return err
}
total := progress.Add(chunk.Size())
if total >= l.Size {
q.Close()
}
return nil
}()
if !canRetry(err) {
return err
}
}
return nil
})
}
} }
} }
if err := g.Wait(); err != nil { if err := g.Wait(); err != nil {
return err return err
} }
@@ -576,6 +615,8 @@ type Manifest struct {
Config *Layer `json:"config"` Config *Layer `json:"config"`
} }
var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
// Layer returns the layer with the given // Layer returns the layer with the given
// digest, or nil if not found. // digest, or nil if not found.
func (m *Manifest) Layer(d blob.Digest) *Layer { func (m *Manifest) Layer(d blob.Digest) *Layer {
@@ -602,9 +643,10 @@ func (m Manifest) MarshalJSON() ([]byte, error) {
// last phase of the commit which expects it, but does nothing // last phase of the commit which expects it, but does nothing
// with it. This will be fixed in a future release of // with it. This will be fixed in a future release of
// ollama.com. // ollama.com.
Config Layer `json:"config"` Config *Layer `json:"config"`
}{ }{
M: M(m), M: M(m),
Config: &Layer{Digest: emptyDigest},
} }
return json.Marshal(v) return json.Marshal(v)
} }
@@ -694,123 +736,6 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error)
return m, nil return m, nil
} }
type chunksum struct {
URL string
Chunk blob.Chunk
Digest blob.Digest
}
// chunksums returns a sequence of chunksums for the given layer. If the layer is under the
// chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
// is over the chunking threshold, the chunksums are read from the chunksums endpoint.
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
return func(yield func(chunksum, error) bool) {
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
yield(chunksum{}, err)
return
}
if l.Size < r.maxChunkingThreshold() {
// any layer under the threshold should be downloaded
// in one go.
cs := chunksum{
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
),
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
Digest: l.Digest,
}
yield(cs, nil)
return
}
// A chunksums response is a sequence of chunksums in a
// simple, easy to parse line-oriented format.
//
// Example:
//
// >> GET /v2/<namespace>/<model>/chunksums/<digest>
//
// << HTTP/1.1 200 OK
// << Content-Location: <blobURL>
// <<
// << <digest> <start>-<end>
// << ...
//
// The blobURL is the URL to download the chunks from.
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
)
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
if err != nil {
yield(chunksum{}, err)
return
}
res, err := sendRequest(r.client(), req)
if err != nil {
yield(chunksum{}, err)
return
}
defer res.Body.Close()
if res.StatusCode != 200 {
err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
yield(chunksum{}, err)
return
}
blobURL := res.Header.Get("Content-Location")
s := bufio.NewScanner(res.Body)
s.Split(bufio.ScanWords)
for {
if !s.Scan() {
if s.Err() != nil {
yield(chunksum{}, s.Err())
}
return
}
d, err := blob.ParseDigest(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
return
}
if !s.Scan() {
err := s.Err()
if err == nil {
err = fmt.Errorf("missing chunk range for digest %s", d)
}
yield(chunksum{}, err)
return
}
chunk, err := parseChunk(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
return
}
cs := chunksum{
URL: blobURL,
Chunk: chunk,
Digest: d,
}
if !yield(cs, nil) {
return
}
}
}
}
func (r *Registry) client() *http.Client { func (r *Registry) client() *http.Client {
if r.HTTPClient != nil { if r.HTTPClient != nil {
return r.HTTPClient return r.HTTPClient
@@ -973,6 +898,13 @@ func checkData(url string) string {
return fmt.Sprintf("GET,%s,%s", url, zeroSum) return fmt.Sprintf("GET,%s,%s", url, zeroSum)
} }
func maybeUnexpectedEOF(err error) error {
if errors.Is(err, io.EOF) {
return io.ErrUnexpectedEOF
}
return err
}
type publicError struct { type publicError struct {
wrapped error wrapped error
message string message string
@@ -1059,22 +991,27 @@ func splitExtended(s string) (scheme, name, digest string) {
return scheme, s, digest return scheme, s, digest
} }
// parseChunk parses a string in the form "start-end" and returns the Chunk. type writerPool struct {
func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) { size int64 // set by the caller
startPart, endPart, found := strings.Cut(string(s), "-")
if !found { mu sync.Mutex
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s) ws []*bufio.Writer
} }
start, err := strconv.ParseInt(startPart, 10, 64)
if err != nil { func (p *writerPool) get() *bufio.Writer {
return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err) p.mu.Lock()
} defer p.mu.Unlock()
end, err := strconv.ParseInt(endPart, 10, 64) if len(p.ws) == 0 {
if err != nil { return bufio.NewWriterSize(nil, int(p.size))
return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err) }
} w := p.ws[len(p.ws)-1]
if start > end { p.ws = p.ws[:len(p.ws)-1]
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s) return w
} }
return blob.Chunk{Start: start, End: end}, nil
func (p *writerPool) put(w *bufio.Writer) {
p.mu.Lock()
defer p.mu.Unlock()
w.Reset(nil)
p.ws = append(p.ws, w)
} }

View File

@@ -21,6 +21,7 @@ import (
"time" "time"
"github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/testutil" "github.com/ollama/ollama/server/internal/testutil"
) )
@@ -427,7 +428,7 @@ func TestRegistryPullCached(t *testing.T) {
err := rc.Pull(ctx, "single") err := rc.Pull(ctx, "single")
testutil.Check(t, err) testutil.Check(t, err)
want := []int64{0, 6} want := []int64{6}
if !errors.Is(errors.Join(errs...), ErrCached) { if !errors.Is(errors.Join(errs...), ErrCached) {
t.Errorf("errs = %v; want %v", errs, ErrCached) t.Errorf("errs = %v; want %v", errs, ErrCached)
} }
@@ -530,6 +531,54 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
} }
} }
func TestRegistryPullChunking(t *testing.T) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
if r.URL.Host != "blob.store" {
// The production registry redirects to the blob store.
http.Redirect(w, r, "http://blob.store"+r.URL.Path, http.StatusFound)
return
}
if strings.Contains(r.URL.Path, "/blobs/") {
rng := r.Header.Get("Range")
if rng == "" {
http.Error(w, "missing range", http.StatusBadRequest)
return
}
_, c, err := chunks.ParseRange(r.Header.Get("Range"))
if err != nil {
panic(err)
}
io.WriteString(w, "remote"[c.Start:c.End+1])
return
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, blob.DigestFromBytes("remote"))
})
// Force chunking by setting the threshold to less than the size of the
// layer.
rc.ChunkingThreshold = 3
rc.MaxChunkSize = 3
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(d *Layer, n int64, err error) {
if err != nil {
t.Errorf("update %v %d %v", d, n, err)
}
reads = append(reads, n)
},
})
err := rc.Pull(ctx, "remote")
testutil.Check(t, err)
want := []int64{0, 3, 6}
if !slices.Equal(reads, want) {
t.Errorf("reads = %v; want %v", reads, want)
}
}
func TestRegistryResolveByDigest(t *testing.T) { func TestRegistryResolveByDigest(t *testing.T) {
check := testutil.Checker(t) check := testutil.Checker(t)

View File

@@ -0,0 +1,11 @@
package main
import (
"fmt"
"os"
)
func main() {
fmt.Println("Run as 'go test -bench=.' to run the benchmarks")
os.Exit(1)
}

View File

@@ -0,0 +1,107 @@
package main
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"sync/atomic"
"testing"
"time"
"github.com/ollama/ollama/server/internal/chunks"
"golang.org/x/sync/errgroup"
)
func BenchmarkDownload(b *testing.B) {
run := func(fileSize, chunkSize int64) {
name := fmt.Sprintf("size=%d/chunksize=%d", fileSize, chunkSize)
b.Run(name, func(b *testing.B) { benchmarkDownload(b, fileSize, chunkSize) })
}
run(100<<20, 8<<20)
run(100<<20, 16<<20)
run(100<<20, 32<<20)
run(100<<20, 64<<20)
run(100<<20, 128<<20) // 1 chunk
}
func run(ctx context.Context, c *http.Client, chunk chunks.Chunk) error {
const blobURL = "https://ollama.com/v2/x/x/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d"
req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
res, err := c.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
_, err = io.CopyN(io.Discard, res.Body, chunk.Size()) // will io.EOF on short read
return err
}
var sleepTime atomic.Int64
func benchmarkDownload(b *testing.B, fileSize, chunkSize int64) {
client := &http.Client{
Transport: func() http.RoundTripper {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DisableKeepAlives = true
return tr
}(),
}
defer client.CloseIdleConnections()
// warm up the client
run(context.Background(), client, chunks.New(0, 1<<20))
b.SetBytes(fileSize)
b.ReportAllocs()
// Give our CDN a min to breathe between benchmarks.
time.Sleep(time.Duration(sleepTime.Swap(3)))
for b.Loop() {
g, ctx := errgroup.WithContext(b.Context())
g.SetLimit(runtime.GOMAXPROCS(0))
for chunk := range chunks.Of(fileSize, chunkSize) {
g.Go(func() error { return run(ctx, client, chunk) })
}
if err := g.Wait(); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkWrite(b *testing.B) {
b.Run("chunksize=1MiB", func(b *testing.B) { benchmarkWrite(b, 1<<20) })
}
func benchmarkWrite(b *testing.B, chunkSize int) {
b.ReportAllocs()
dir := b.TempDir()
f, err := os.Create(filepath.Join(dir, "write-single"))
if err != nil {
b.Fatal(err)
}
defer f.Close()
data := make([]byte, chunkSize)
b.SetBytes(int64(chunkSize))
r := bytes.NewReader(data)
for b.Loop() {
r.Reset(data)
_, err := io.Copy(f, r)
if err != nil {
b.Fatal(err)
}
}
}

View File

@@ -1,5 +1,6 @@
// Package registry implements an http.Handler for handling local Ollama API // Package registry provides an http.Handler for handling local Ollama API
// model management requests. See [Local] for details. // requests for performing tasks related to the ollama.com model registry and
// the local disk cache.
package registry package registry
import ( import (
@@ -9,7 +10,6 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"maps"
"net/http" "net/http"
"sync" "sync"
"time" "time"
@@ -18,11 +18,16 @@ import (
"github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/client/ollama"
) )
// Local implements an http.Handler for handling local Ollama API model // Local is an http.Handler for handling local Ollama API requests for
// management requests, such as pushing, pulling, and deleting models. // performing tasks related to the ollama.com model registry combined with the
// local disk cache.
// //
// It can be arranged for all unknown requests to be passed through to a // It is not concern of Local, or this package, to handle model creation, which
// fallback handler, if one is provided. // proceeds any registry operations for models it produces.
//
// NOTE: The package built for dealing with model creation should use
// [DefaultCache] to access the blob store and not attempt to read or write
// directly to the blob disk cache.
type Local struct { type Local struct {
Client *ollama.Registry // required Client *ollama.Registry // required
Logger *slog.Logger // required Logger *slog.Logger // required
@@ -58,7 +63,6 @@ func (e serverError) Error() string {
var ( var (
errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"} errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
errNotFound = &serverError{404, "not_found", "not found"} errNotFound = &serverError{404, "not_found", "not found"}
errModelNotFound = &serverError{404, "not_found", "model not found"}
errInternalError = &serverError{500, "internal_error", "internal server error"} errInternalError = &serverError{500, "internal_error", "internal server error"}
) )
@@ -171,16 +175,8 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
} }
type params struct { type params struct {
// DeprecatedName is the name of the model to push, pull, or delete, DeprecatedName string `json:"name"` // Use [params.model]
// but is deprecated. New clients should use [Model] instead. Model string `json:"model"` // Use [params.model]
//
// Use [model()] to get the model name for both old and new API requests.
DeprecatedName string `json:"name"`
// Model is the name of the model to push, pull, or delete.
//
// Use [model()] to get the model name for both old and new API requests.
Model string `json:"model"`
// AllowNonTLS is a flag that indicates a client using HTTP // AllowNonTLS is a flag that indicates a client using HTTP
// is doing so, deliberately. // is doing so, deliberately.
@@ -193,18 +189,9 @@ type params struct {
// confusing flags such as this. // confusing flags such as this.
AllowNonTLS bool `json:"insecure"` AllowNonTLS bool `json:"insecure"`
// Stream, if true, will make the server send progress updates in a // ProgressStream is a flag that indicates the client is expecting a stream of
// streaming of JSON objects. If false, the server will send a single // progress updates.
// JSON object with the final status as "success", or an error object ProgressStream bool `json:"stream"`
// if an error occurred.
//
// Unfortunately, this API was designed to be a bit awkward. Stream is
// defined to default to true if not present, so we need a way to check
// if the client decisively it to false. So, we use a pointer to a
// bool. Gross.
//
// Use [stream()] to get the correct value for this field.
Stream *bool `json:"stream"`
} }
// model returns the model name for both old and new API requests. // model returns the model name for both old and new API requests.
@@ -212,13 +199,6 @@ func (p params) model() string {
return cmp.Or(p.Model, p.DeprecatedName) return cmp.Or(p.Model, p.DeprecatedName)
} }
func (p params) stream() bool {
if p.Stream == nil {
return true
}
return *p.Stream
}
func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error { func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
if r.Method != "DELETE" { if r.Method != "DELETE" {
return errMethodNotAllowed return errMethodNotAllowed
@@ -232,16 +212,16 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
return err return err
} }
if !ok { if !ok {
return errModelNotFound return &serverError{404, "not_found", "model not found"}
} }
if s.Prune != nil { if s.Prune == nil {
return s.Prune() return nil
} }
return nil return s.Prune()
} }
type progressUpdateJSON struct { type progressUpdateJSON struct {
Status string `json:"status,omitempty,omitzero"` Status string `json:"status"`
Digest blob.Digest `json:"digest,omitempty,omitzero"` Digest blob.Digest `json:"digest,omitempty,omitzero"`
Total int64 `json:"total,omitempty,omitzero"` Total int64 `json:"total,omitempty,omitzero"`
Completed int64 `json:"completed,omitempty,omitzero"` Completed int64 `json:"completed,omitempty,omitzero"`
@@ -257,17 +237,6 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
return err return err
} }
enc := json.NewEncoder(w)
if !p.stream() {
if err := s.Client.Pull(r.Context(), p.model()); err != nil {
if errors.Is(err, ollama.ErrModelNotFound) {
return errModelNotFound
}
return err
}
return enc.Encode(progressUpdateJSON{Status: "success"})
}
maybeFlush := func() { maybeFlush := func() {
fl, _ := w.(http.Flusher) fl, _ := w.(http.Flusher)
if fl != nil { if fl != nil {
@@ -277,67 +246,69 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
defer maybeFlush() defer maybeFlush()
var mu sync.Mutex var mu sync.Mutex
progress := make(map[*ollama.Layer]int64) enc := json.NewEncoder(w)
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
progressCopy := make(map[*ollama.Layer]int64, len(progress)) ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
pushUpdate := func() { Update: func(l *ollama.Layer, n int64, err error) {
defer maybeFlush() mu.Lock()
defer mu.Unlock()
// TODO(bmizerany): This scales poorly with more layers due to // TODO(bmizerany): coalesce these updates; writing per
// needing to flush out them all in one big update. We _could_ // update is expensive
// just flush on the changed ones, or just track the whole
// download. Needs more thought. This is fine for now.
mu.Lock()
maps.Copy(progressCopy, progress)
mu.Unlock()
for l, n := range progress {
enc.Encode(progressUpdateJSON{ enc.Encode(progressUpdateJSON{
Digest: l.Digest, Digest: l.Digest,
Status: "pulling",
Total: l.Size, Total: l.Size,
Completed: n, Completed: n,
}) })
}
}
t := time.NewTicker(time.Hour) // "unstarted" timer
start := sync.OnceFunc(func() {
pushUpdate()
t.Reset(100 * time.Millisecond)
})
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
if n > 0 {
start() // flush initial state
}
mu.Lock()
progress[l] = n
mu.Unlock()
}, },
}) })
done := make(chan error, 1) done := make(chan error, 1)
go func() { go func() {
// TODO(bmizerany): continue to support non-streaming responses
done <- s.Client.Pull(ctx, p.model()) done <- s.Client.Pull(ctx, p.model())
}() }()
for { func() {
select { t := time.NewTicker(100 * time.Millisecond)
case <-t.C: defer t.Stop()
pushUpdate() for {
case err := <-done: select {
pushUpdate() case <-t.C:
if err != nil { mu.Lock()
var status string maybeFlush()
if errors.Is(err, ollama.ErrModelNotFound) { mu.Unlock()
status = fmt.Sprintf("error: model %q not found", p.model()) case err := <-done:
} else { if err != nil {
status = fmt.Sprintf("error: %v", err) var status string
if errors.Is(err, ollama.ErrModelNotFound) {
status = fmt.Sprintf("error: model %q not found", p.model())
enc.Encode(progressUpdateJSON{Status: status})
} else {
status = fmt.Sprintf("error: %v", err)
enc.Encode(progressUpdateJSON{Status: status})
}
return
} }
enc.Encode(progressUpdateJSON{Status: status})
// These final updates are not strictly necessary, because they have
// already happened at this point. Our pull handler code used to do
// these steps after, not during, the pull, and they were slow, so we
// wanted to provide feedback to users what was happening. For now, we
// keep them to not jar users who are used to seeing them. We can phase
// them out with a new and nicer UX later. One without progress bars
// and digests that no one cares about.
enc.Encode(progressUpdateJSON{Status: "verifying layers"})
enc.Encode(progressUpdateJSON{Status: "writing manifest"})
enc.Encode(progressUpdateJSON{Status: "success"})
return
} }
return nil
} }
} }()
return nil
} }
func decodeUserJSON[T any](r io.Reader) (T, error) { func decodeUserJSON[T any](r io.Reader) (T, error) {

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"io/fs" "io/fs"
"net" "net"
@@ -159,6 +160,7 @@ var registryFS = sync.OnceValue(func() fs.FS {
// to \n when parsing the txtar on Windows. // to \n when parsing the txtar on Windows.
data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n")) data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
a := txtar.Parse(data) a := txtar.Parse(data)
fmt.Printf("%q\n", a.Comment)
fsys, err := txtar.FS(a) fsys, err := txtar.FS(a)
if err != nil { if err != nil {
panic(err) panic(err)
@@ -177,7 +179,7 @@ func TestServerPull(t *testing.T) {
w.WriteHeader(404) w.WriteHeader(404)
io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`) io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
default: default:
t.Logf("serving blob: %s", r.URL.Path) t.Logf("serving file: %s", r.URL.Path)
modelsHandler.ServeHTTP(w, r) modelsHandler.ServeHTTP(w, r)
} }
}) })
@@ -186,7 +188,7 @@ func TestServerPull(t *testing.T) {
t.Helper() t.Helper()
if got.Code != 200 { if got.Code != 200 {
t.Errorf("Code = %d; want 200", got.Code) t.Fatalf("Code = %d; want 200", got.Code)
} }
gotlines := got.Body.String() gotlines := got.Body.String()
t.Logf("got:\n%s", gotlines) t.Logf("got:\n%s", gotlines)
@@ -195,29 +197,35 @@ func TestServerPull(t *testing.T) {
want, unwanted := strings.CutPrefix(want, "!") want, unwanted := strings.CutPrefix(want, "!")
want = strings.TrimSpace(want) want = strings.TrimSpace(want)
if !unwanted && !strings.Contains(gotlines, want) { if !unwanted && !strings.Contains(gotlines, want) {
t.Errorf("! missing %q in body", want) t.Fatalf("! missing %q in body", want)
} }
if unwanted && strings.Contains(gotlines, want) { if unwanted && strings.Contains(gotlines, want) {
t.Errorf("! unexpected %q in body", want) t.Fatalf("! unexpected %q in body", want)
} }
} }
} }
got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`) got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
checkResponse(got, ` checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"} {"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
`) `)
got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`) got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
checkResponse(got, ` checkResponse(got, `
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5} {"status":"pulling manifest"}
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3} {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5} {"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3} {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
{"status":"verifying layers"}
{"status":"writing manifest"}
{"status":"success"}
`) `)
got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`) got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
checkResponse(got, ` checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: model \"unknown\" not found"} {"status":"error: model \"unknown\" not found"}
`) `)
@@ -232,39 +240,19 @@ func TestServerPull(t *testing.T) {
got = s.send(t, "POST", "/api/pull", `{"model": "://"}`) got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
checkResponse(got, ` checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: invalid or missing name: \"\""} {"status":"error: invalid or missing name: \"\""}
`)
// Non-streaming pulls !verifying
got = s.send(t, "POST", "/api/pull", `{"model": "://", "stream": false}`) !writing
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name") !success
got = s.send(t, "POST", "/api/pull", `{"model": "smol", "stream": false}`)
checkResponse(got, `
{"status":"success"}
!digest
!total
!completed
`) `)
got = s.send(t, "POST", "/api/pull", `{"model": "unknown", "stream": false}`)
checkErrorResponse(t, got, 404, "not_found", "model not found")
} }
func TestServerUnknownPath(t *testing.T) { func TestServerUnknownPath(t *testing.T) {
s := newTestServer(t, nil) s := newTestServer(t, nil)
got := s.send(t, "DELETE", "/api/unknown", `{}`) got := s.send(t, "DELETE", "/api/unknown", `{}`)
checkErrorResponse(t, got, 404, "not_found", "not found") checkErrorResponse(t, got, 404, "not_found", "not found")
var fellback bool
s.Fallback = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fellback = true
})
got = s.send(t, "DELETE", "/api/unknown", `{}`)
if !fellback {
t.Fatal("expected Fallback to be called")
}
if got.Code != 200 {
t.Fatalf("Code = %d; want 200", got.Code)
}
} }
func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) { func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {