diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go
index fdac71bbf..590418674 100644
--- a/server/internal/client/ollama/registry.go
+++ b/server/internal/client/ollama/registry.go
@@ -59,6 +59,11 @@ var (
 	// ErrCached is passed to [Trace.PushUpdate] when a layer already
 	// exists. It is a non-fatal error and is never returned by [Registry.Push].
 	ErrCached = errors.New("cached")
+
+	// ErrIncomplete is returned by [Registry.Pull] when a model pull was
+	// incomplete due to one or more layer download failures. Users that
+	// want specific errors should use [WithTrace].
+	ErrIncomplete = errors.New("incomplete")
 )
 
 // Defaults
@@ -271,8 +276,19 @@ func DefaultRegistry() (*Registry, error) {
 
 func UserAgent() string {
 	buildinfo, _ := debug.ReadBuildInfo()
+
+	version := buildinfo.Main.Version
+	if version == "(devel)" {
+		// When using `go run .` the version is "(devel)". This is seen
+		// as an invalid version by ollama.com and so it defaults to
+		// "needs upgrade" for some requests, such as pulls. These
+		// checks can be skipped by using the special version "v0.0.0",
+		// so we set it to that here.
+		version = "v0.0.0"
+	}
+
 	return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
-		buildinfo.Main.Version,
+		version,
 		runtime.GOARCH,
 		runtime.GOOS,
 		runtime.Version(),
@@ -418,13 +434,14 @@ func canRetry(err error) bool {
 //
 // It always calls update with a nil error.
 type trackingReader struct {
-	r io.Reader
-	n *atomic.Int64
+	l      *Layer
+	r      io.Reader
+	update func(l *Layer, n int64, err error)
 }
 
 func (r *trackingReader) Read(p []byte) (n int, err error) {
 	n, err = r.r.Read(p)
-	r.n.Add(int64(n))
+	r.update(r.l, int64(n), nil)
 	return
 }
 
@@ -462,16 +479,20 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 
 	// Send initial layer trace events to allow clients to have an
 	// understanding of work to be done before work starts.
+	var expected int64
 	t := traceFromContext(ctx)
 	for _, l := range layers {
 		t.update(l, 0, nil)
+		expected += l.Size
 	}
 
+	var total atomic.Int64
 	var g errgroup.Group
 	g.SetLimit(r.maxStreams())
 	for _, l := range layers {
 		info, err := c.Get(l.Digest)
 		if err == nil && info.Size == l.Size {
+			total.Add(l.Size)
 			t.update(l, l.Size, ErrCached)
 			continue
 		}
@@ -484,21 +505,25 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		// TODO(bmizerany): fix this unbounded use of defer
 		defer chunked.Close()
 
-		var progress atomic.Int64
 		for cs, err := range r.chunksums(ctx, name, l) {
 			if err != nil {
-				// Bad chunksums response, update tracing
-				// clients and then bail.
-				t.update(l, progress.Load(), err)
-				return err
+				// Chunksum stream was interrupted, so tell
+				// trace about it, and let in-flight chunk
+				// downloads finish. Once they finish, return
+				// ErrIncomplete, which is triggered by the
+				// fact that the total bytes received is less
+				// than the expected bytes.
+				t.update(l, 0, err)
+				break
 			}
 
 			g.Go(func() (err error) {
 				defer func() {
-					if err != nil {
+					if err == nil || errors.Is(err, ErrCached) {
+						total.Add(cs.Chunk.Size())
+					} else {
 						err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
 					}
-					t.update(l, progress.Load(), err)
 				}()
 
 				req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
@@ -522,7 +547,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 				// download rate since it knows better than a
 				// client that is measuring rate based on
 				// wall-clock time-since-last-update.
-				body := &trackingReader{r: res.Body, n: &progress}
+				body := &trackingReader{l: l, r: res.Body, update: t.update}
 				return chunked.Put(cs.Chunk, cs.Digest, body)
 			})
 		}
@@ -531,6 +556,9 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 	if err := g.Wait(); err != nil {
 		return err
 	}
+	if total.Load() != expected {
+		return fmt.Errorf("%w: received %d/%d", ErrIncomplete, total.Load(), expected)
+	}
 
 	md := blob.DigestFromBytes(m.Data)
 	if err := blob.PutBytes(c, md, m.Data); err != nil {
@@ -757,15 +785,12 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
 		}
 		blobURL := res.Header.Get("Content-Location")
 
-		var size int64
 		s := bufio.NewScanner(res.Body)
 		s.Split(bufio.ScanWords)
 		for {
 			if !s.Scan() {
 				if s.Err() != nil {
 					yield(chunksum{}, s.Err())
-				} else if size != l.Size {
-					yield(chunksum{}, fmt.Errorf("size mismatch: layer size %d != sum of chunks %d", size, l.Size))
 				}
 				return
 			}
@@ -789,12 +814,6 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
 				return
 			}
 
-			size += chunk.Size()
-			if size > l.Size {
-				yield(chunksum{}, fmt.Errorf("chunk size %d exceeds layer size %d", size, l.Size))
-				return
-			}
-
 			cs := chunksum{
 				URL:    blobURL,
 				Chunk:  chunk,
diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go
index 305295435..f8136c06f 100644
--- a/server/internal/client/ollama/registry_test.go
+++ b/server/internal/client/ollama/registry_test.go
@@ -25,6 +25,28 @@ import (
 	"github.com/ollama/ollama/server/internal/testutil"
 )
 
+func ExampleRegistry_cancelOnFirstError() {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ctx = WithTrace(ctx, &Trace{
+		Update: func(l *Layer, n int64, err error) {
+			if err != nil {
+				// Discontinue pulling layers if there is an
+				// error instead of continuing to pull more
+				// data.
+				cancel()
+			}
+		},
+	})
+
+	var r Registry
+	if err := r.Pull(ctx, "model"); err != nil {
+		// panic for demo purposes
+		panic(err)
+	}
+}
+
 func TestManifestMarshalJSON(t *testing.T) {
 	// All manifests should contain an "empty" config object.
 	var m Manifest
@@ -813,8 +835,13 @@ func TestPullChunksums(t *testing.T) {
 	)
 	err := rc.Pull(ctx, "test")
 	check(err)
-	if !slices.Equal(reads, []int64{0, 3, 5}) {
-		t.Errorf("reads = %v; want %v", reads, []int64{0, 3, 5})
+	wantReads := []int64{
+		0, // initial signaling of layer pull starting
+		3, // first chunk read
+		2, // second chunk read
+	}
+	if !slices.Equal(reads, wantReads) {
+		t.Errorf("reads = %v; want %v", reads, wantReads)
 	}
 
 	mw, err := rc.Resolve(t.Context(), "test")
diff --git a/server/internal/registry/server.go b/server/internal/registry/server.go
index 2a935b525..1910b1877 100644
--- a/server/internal/registry/server.go
+++ b/server/internal/registry/server.go
@@ -200,7 +200,7 @@ type params struct {
 	//
 	// Unfortunately, this API was designed to be a bit awkward. Stream is
 	// defined to default to true if not present, so we need a way to check
-	// if the client decisively it to false. So, we use a pointer to a
+	// if the client decisively set it to false. So, we use a pointer to a
 	// bool. Gross.
 	//
 	// Use [stream()] to get the correct value for this field.
@@ -280,17 +280,17 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
 	progress := make(map[*ollama.Layer]int64)
 
 	progressCopy := make(map[*ollama.Layer]int64, len(progress))
-	pushUpdate := func() {
+	flushProgress := func() {
 		defer maybeFlush()
 
-		// TODO(bmizerany): This scales poorly with more layers due to
-		// needing to flush out them all in one big update. We _could_
-		// just flush on the changed ones, or just track the whole
-		// download. Needs more thought. This is fine for now.
+		// TODO(bmizerany): Flushing every layer in one update doesn't
+		// scale well. We could flush only the modified layers or track
+		// the full download. Needs further consideration, though it's
+		// fine for now.
 		mu.Lock()
 		maps.Copy(progressCopy, progress)
 		mu.Unlock()
-		for l, n := range progress {
+		for l, n := range progressCopy {
 			enc.Encode(progressUpdateJSON{
 				Digest:    l.Digest,
 				Total:     l.Size,
@@ -298,19 +298,26 @@
 			})
 		}
 	}
+	defer flushProgress()
 
-	t := time.NewTicker(time.Hour) // "unstarted" timer
+	t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
 	start := sync.OnceFunc(func() {
-		pushUpdate()
+		flushProgress() // flush initial state
 		t.Reset(100 * time.Millisecond)
 	})
 
 	ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
 		Update: func(l *ollama.Layer, n int64, err error) {
 			if n > 0 {
-				start() // flush initial state
+				// Block flushing progress updates until every
+				// layer is accounted for. Clients depend on a
+				// complete model size to calculate progress
+				// correctly; if they use an incomplete total,
+				// progress indicators would erratically jump
+				// as new layers are registered.
+				start()
 			}
 			mu.Lock()
-			progress[l] = n
+			progress[l] += n
 			mu.Unlock()
 		},
 	})
@@ -323,9 +330,9 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
 	for {
 		select {
 		case <-t.C:
-			pushUpdate()
+			flushProgress()
 		case err := <-done:
-			pushUpdate()
+			flushProgress()
 			if err != nil {
 				var status string
 				if errors.Is(err, ollama.ErrModelNotFound) {