diff --git a/cmd/cmd.go b/cmd/cmd.go
index fef7242b7..79ff87ac8 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -808,13 +808,38 @@ func PullHandler(cmd *cobra.Command, args []string) error {
 
     fn := func(resp api.ProgressResponse) error {
         if resp.Digest != "" {
+            if resp.Completed == 0 {
+                // This is the initial status update for the
+                // layer, which the server sends before
+                // beginning the download, for clients to
+                // compute total size and prepare for
+                // downloads, if needed.
+                //
+                // Skipping this here to avoid showing a 0%
+                // progress bar, which *should* clue the user
+                // into the fact that many things are being
+                // downloaded and that the current active
+                // download is not the last. However, in rare
+                // cases it seems to confuse some users, and
+                // it isn't worth explaining, so just ignore it
+                // and regress to the old UI that keeps giving
+                // you the "But wait, there is more!" after
+                // each "100% done" bar, which is "better."
+                return nil
+            }
+
             if spinner != nil {
                 spinner.Stop()
             }
 
             bar, ok := bars[resp.Digest]
             if !ok {
-                bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
+                name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
+                name = strings.TrimSpace(name)
+                if isDigest {
+                    name = name[:min(12, len(name))]
+                }
+                bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
                 bars[resp.Digest] = bar
                 p.Add(resp.Digest, bar)
             }
@@ -834,11 +859,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
     }
 
     request := api.PullRequest{Name: args[0], Insecure: insecure}
-    if err := client.Pull(cmd.Context(), &request, fn); err != nil {
-        return err
-    }
-
-    return nil
+    return client.Pull(cmd.Context(), &request, fn)
 }
 
 type generateContextKey string
diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go
index 409932bfd..3eb3c5c24 100644
--- a/server/internal/client/ollama/registry.go
+++ b/server/internal/client/ollama/registry.go
@@ -107,15 +107,20 @@ func DefaultCache() (*blob.DiskCache, error) {
 //
 // In both cases, the code field is optional and may be empty.
 type Error struct {
-    Status  int    `json:"-"` // TODO(bmizerany): remove this
+    status  int    `json:"-"` // TODO(bmizerany): remove this
     Code    string `json:"code"`
     Message string `json:"message"`
 }
 
+// Temporary reports whether the error is temporary (e.g. a 5xx status code).
+func (e *Error) Temporary() bool {
+    return e.status >= 500
+}
+
 func (e *Error) Error() string {
     var b strings.Builder
     b.WriteString("registry responded with status ")
-    b.WriteString(strconv.Itoa(e.Status))
+    b.WriteString(strconv.Itoa(e.status))
     if e.Code != "" {
         b.WriteString(": code ")
         b.WriteString(e.Code)
@@ -129,7 +134,7 @@
 
 func (e *Error) LogValue() slog.Value {
     return slog.GroupValue(
-        slog.Int("status", e.Status),
+        slog.Int("status", e.status),
         slog.String("code", e.Code),
         slog.String("message", e.Message),
     )
@@ -428,12 +433,12 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
 
 type trackingReader struct {
     l      *Layer
     r      io.Reader
-    update func(l *Layer, n int64, err error)
+    update func(n int64)
 }
 
 func (r *trackingReader) Read(p []byte) (n int, err error) {
     n, err = r.r.Read(p)
-    r.update(r.l, int64(n), nil)
+    r.update(int64(n))
     return
 }
@@ -478,111 +483,120 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
         expected += l.Size
     }
 
-    var received atomic.Int64
+    var completed atomic.Int64
     var g errgroup.Group
     g.SetLimit(r.maxStreams())
     for _, l := range layers {
+        var received atomic.Int64
+
         info, err := c.Get(l.Digest)
         if err == nil && info.Size == l.Size {
            received.Add(l.Size)
+            completed.Add(l.Size)
            t.update(l, l.Size, ErrCached)
            continue
        }
 
-        var wg sync.WaitGroup
-        chunked, err := c.Chunked(l.Digest, l.Size)
-        if err != nil {
-            t.update(l, 0, err)
-            continue
-        }
-
-        for cs, err := range r.chunksums(ctx, name, l) {
+        func() {
+            var wg sync.WaitGroup
+            chunked, err := c.Chunked(l.Digest, l.Size)
             if err != nil {
-                // Chunksum stream interrupted. Note in trace
-                // log and let in-flight downloads complete.
-                // This will naturally trigger ErrIncomplete
-                // since received < expected bytes.
-                t.update(l, 0, err)
-                break
+                t.update(l, received.Load(), err)
+                return
             }
+            defer func() {
+                // Close the chunked writer when all chunks are
+                // downloaded.
+                //
+                // This is done as a background task in the
+                // group to allow the next layer to start while
+                // we wait for the final chunk in this layer to
+                // complete. It also ensures this is done
+                // before we exit Pull.
+                g.Go(func() error {
+                    wg.Wait()
+                    chunked.Close()
+                    return nil
+                })
+            }()
 
-            cacheKey := fmt.Sprintf(
-                "v1 pull chunksum %s %s %d-%d",
-                l.Digest,
-                cs.Digest,
-                cs.Chunk.Start,
-                cs.Chunk.End,
-            )
-            cacheKeyDigest := blob.DigestFromBytes(cacheKey)
-            _, err := c.Get(cacheKeyDigest)
-            if err == nil {
-                received.Add(cs.Chunk.Size())
-                t.update(l, cs.Chunk.Size(), ErrCached)
-                continue
-            }
+            for cs, err := range r.chunksums(ctx, name, l) {
+                if err != nil {
+                    // Chunksum stream interrupted. Note in trace
+                    // log and let in-flight downloads complete.
+                    // This will naturally trigger ErrIncomplete
+                    // since received < expected bytes.
+                    t.update(l, received.Load(), err)
+                    break
+                }
 
-            wg.Add(1)
-            g.Go(func() (err error) {
-                defer func() {
-                    if err == nil {
-                        // Ignore cache key write errors for now. We've already
-                        // reported to trace that the chunk is complete.
-                        //
-                        // Ideally, we should only report completion to trace
-                        // after successful cache commit. This current approach
-                        // works but could trigger unnecessary redownloads if
-                        // the checkpoint key is missing on next pull.
-                        //
-                        // Not incorrect, just suboptimal - fix this in a
-                        // future update.
-                        _ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
+                cacheKey := fmt.Sprintf(
+                    "v1 pull chunksum %s %s %d-%d",
+                    l.Digest,
+                    cs.Digest,
+                    cs.Chunk.Start,
+                    cs.Chunk.End,
+                )
+                cacheKeyDigest := blob.DigestFromBytes(cacheKey)
+                _, err := c.Get(cacheKeyDigest)
+                if err == nil {
+                    recv := received.Add(cs.Chunk.Size())
+                    completed.Add(cs.Chunk.Size())
+                    t.update(l, recv, ErrCached)
+                    continue
+                }
 
-                        received.Add(cs.Chunk.Size())
-                    } else {
-                        t.update(l, 0, err)
+                wg.Add(1)
+                g.Go(func() (err error) {
+                    defer func() {
+                        if err == nil {
+                            // Ignore cache key write errors for now. We've already
+                            // reported to trace that the chunk is complete.
+                            //
+                            // Ideally, we should only report completion to trace
+                            // after successful cache commit. This current approach
+                            // works but could trigger unnecessary redownloads if
+                            // the checkpoint key is missing on next pull.
+                            //
+                            // Not incorrect, just suboptimal - fix this in a
+                            // future update.
+                            _ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
+                        } else {
+                            t.update(l, received.Load(), err)
+                        }
+                        wg.Done()
+                    }()
+
+                    req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
+                    if err != nil {
+                        return err
                     }
-                    wg.Done()
-                }()
+                    req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
+                    res, err := sendRequest(r.client(), req)
+                    if err != nil {
+                        return err
+                    }
+                    defer res.Body.Close()
 
-                req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
-                if err != nil {
-                    return err
-                }
-                req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
-                res, err := sendRequest(r.client(), req)
-                if err != nil {
-                    return err
-                }
-                defer res.Body.Close()
-
-                body := &trackingReader{l: l, r: res.Body, update: t.update}
-                return chunked.Put(cs.Chunk, cs.Digest, body)
-            })
-        }
-
-        // Close writer immediately after downloads finish, not at Pull
-        // exit. Using defer would keep file descriptors open until all
-        // layers complete, potentially exhausting system limits with
-        // many layers.
-        //
-        // The WaitGroup tracks when all chunks finish downloading,
-        // allowing precise writer closure in a background goroutine.
-        // Each layer briefly uses one extra goroutine while at most
-        // maxStreams()-1 chunks download in parallel.
-        //
-        // This caps file descriptors at maxStreams() instead of
-        // growing with layer count.
-        g.Go(func() error {
-            wg.Wait()
-            chunked.Close()
-            return nil
-        })
+                    tr := &trackingReader{
+                        l: l,
+                        r: res.Body,
+                        update: func(n int64) {
+                            completed.Add(n)
+                            recv := received.Add(n)
+                            t.update(l, recv, nil)
+                        },
+                    }
+                    return chunked.Put(cs.Chunk, cs.Digest, tr)
+                })
+            }
+        }()
     }
 
     if err := g.Wait(); err != nil {
         return err
     }
-    if received.Load() != expected {
-        return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, received.Load(), expected)
+    if recv := completed.Load(); recv != expected {
+        return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, recv, expected)
     }
 
     md := blob.DigestFromBytes(m.Data)
@@ -973,7 +987,7 @@ func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) {
             return nil, ErrModelNotFound
         }
 
-        re.Status = res.StatusCode
+        re.status = res.StatusCode
         return nil, &re
     }
     return res, nil
diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go
index 80d39b765..8a3107356 100644
--- a/server/internal/client/ollama/registry_test.go
+++ b/server/internal/client/ollama/registry_test.go
@@ -154,7 +154,7 @@ func okHandler(w http.ResponseWriter, r *http.Request) {
 func checkErrCode(t *testing.T, err error, status int, code string) {
     t.Helper()
     var e *Error
-    if !errors.As(err, &e) || e.Status != status || e.Code != code {
+    if !errors.As(err, &e) || e.status != status || e.Code != code {
         t.Errorf("err = %v; want %v %v", err, status, code)
     }
 }
@@ -860,8 +860,8 @@ func TestPullChunksumStreaming(t *testing.T) {
     // now send the second chunksum and ensure it kicks off work immediately
     fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
 
-    if g := <-update; g != 1 {
-        t.Fatalf("got %d, want 1", g)
+    if g := <-update; g != 3 {
+        t.Fatalf("got %d, want 3", g)
     }
     csw.Close()
     testutil.Check(t, <-errc)
@@ -944,10 +944,10 @@ func TestPullChunksumsCached(t *testing.T) {
     _, err = c.Cache.Resolve("o.com/library/abc:latest")
     check(err)
 
-    if g := written.Load(); g != 3 {
+    if g := written.Load(); g != 5 {
         t.Fatalf("wrote %d bytes, want 3", g)
     }
     if g := cached.Load(); g != 2 { // "ab" should have been cached
-        t.Fatalf("cached %d bytes, want 3", g)
+        t.Fatalf("cached %d bytes, want 2", g)
     }
 }
diff --git a/server/internal/client/ollama/trace.go b/server/internal/client/ollama/trace.go
index 69435c406..a7cac0d5d 100644
--- a/server/internal/client/ollama/trace.go
+++ b/server/internal/client/ollama/trace.go
@@ -34,10 +34,27 @@ func (t *Trace) update(l *Layer, n int64, err error) {
 
 type traceKey struct{}
 
-// WithTrace returns a context derived from ctx that uses t to report trace
-// events.
+// WithTrace adds a trace to the context for transfer progress reporting.
 func WithTrace(ctx context.Context, t *Trace) context.Context {
-    return context.WithValue(ctx, traceKey{}, t)
+    old := traceFromContext(ctx)
+    if old == t {
+        // No change; return the original context. This also prevents
+        // infinite recursion below, if the caller passes the same
+        // Trace.
+        return ctx
+    }
+
+    // Create a new Trace that wraps the old one, if any. If we used the
+    // same pointer t, we would end up with a recursive structure.
+    composed := &Trace{
+        Update: func(l *Layer, n int64, err error) {
+            if old != nil {
+                old.update(l, n, err)
+            }
+            t.update(l, n, err)
+        },
+    }
+    return context.WithValue(ctx, traceKey{}, composed)
 }
 
 var emptyTrace = &Trace{}
diff --git a/server/internal/registry/server.go b/server/internal/registry/server.go
index 1910b1877..4790c80d9 100644
--- a/server/internal/registry/server.go
+++ b/server/internal/registry/server.go
@@ -9,13 +9,14 @@ import (
     "fmt"
     "io"
     "log/slog"
-    "maps"
     "net/http"
+    "slices"
     "sync"
     "time"
 
     "github.com/ollama/ollama/server/internal/cache/blob"
     "github.com/ollama/ollama/server/internal/client/ollama"
+    "github.com/ollama/ollama/server/internal/internal/backoff"
 )
 
 // Local implements an http.Handler for handling local Ollama API model
@@ -265,68 +266,81 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
             }
             return err
         }
-        return enc.Encode(progressUpdateJSON{Status: "success"})
+        enc.Encode(progressUpdateJSON{Status: "success"})
+        return nil
     }
 
-    maybeFlush := func() {
+    var mu sync.Mutex
+    var progress []progressUpdateJSON
+    flushProgress := func() {
+        mu.Lock()
+        progress := slices.Clone(progress) // make a copy and release lock before encoding to the wire
+        mu.Unlock()
+        for _, p := range progress {
+            enc.Encode(p)
+        }
         fl, _ := w.(http.Flusher)
         if fl != nil {
             fl.Flush()
         }
     }
-    defer maybeFlush()
-
-    var mu sync.Mutex
-    progress := make(map[*ollama.Layer]int64)
-
-    progressCopy := make(map[*ollama.Layer]int64, len(progress))
-    flushProgress := func() {
-        defer maybeFlush()
-
-        // TODO(bmizerany): Flushing every layer in one update doesn't
-        // scale well. We could flush only the modified layers or track
-        // the full download. Needs further consideration, though it's
-        // fine for now.
-        mu.Lock()
-        maps.Copy(progressCopy, progress)
-        mu.Unlock()
-        for l, n := range progressCopy {
-            enc.Encode(progressUpdateJSON{
-                Digest:    l.Digest,
-                Total:     l.Size,
-                Completed: n,
-            })
-        }
-    }
     defer flushProgress()
 
-    t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
+    t := time.NewTicker(1<<63 - 1) // "unstarted" timer
     start := sync.OnceFunc(func() {
         flushProgress() // flush initial state
         t.Reset(100 * time.Millisecond)
     })
 
     ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
         Update: func(l *ollama.Layer, n int64, err error) {
-            if n > 0 {
-                // Block flushing progress updates until every
-                // layer is accounted for. Clients depend on a
-                // complete model size to calculate progress
-                // correctly; if they use an incomplete total,
-                // progress indicators would erratically jump
-                // as new layers are registered.
-                start()
+            if err != nil && !errors.Is(err, ollama.ErrCached) {
+                s.Logger.Error("pulling", "model", p.model(), "error", err)
+                return
             }
-            mu.Lock()
-            progress[l] += n
-            mu.Unlock()
+
+            func() {
+                mu.Lock()
+                defer mu.Unlock()
+                for i, p := range progress {
+                    if p.Digest == l.Digest {
+                        progress[i].Completed = n
+                        return
+                    }
+                }
+                progress = append(progress, progressUpdateJSON{
+                    Digest: l.Digest,
+                    Total:  l.Size,
+                })
+            }()
+
+            // Block flushing progress updates until every
+            // layer is accounted for. Clients depend on a
+            // complete model size to calculate progress
+            // correctly; if they use an incomplete total,
+            // progress indicators would erratically jump
+            // as new layers are registered.
+            start()
         },
     })
 
     done := make(chan error, 1)
-    go func() {
-        done <- s.Client.Pull(ctx, p.model())
+    go func() (err error) {
+        defer func() { done <- err }()
+        for _, err := range backoff.Loop(ctx, 3*time.Second) {
+            if err != nil {
+                return err
+            }
+            err := s.Client.Pull(ctx, p.model())
+            var oe *ollama.Error
+            if errors.As(err, &oe) && oe.Temporary() {
+                continue // retry
+            }
+            return err
+        }
+        return nil
     }()
 
+    enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
     for {
         select {
         case <-t.C:
@@ -341,7 +355,13 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
                     status = fmt.Sprintf("error: %v", err)
                 }
                 enc.Encode(progressUpdateJSON{Status: status})
+                return nil
             }
+
+            // Emulate old client pull progress (for now):
+            enc.Encode(progressUpdateJSON{Status: "verifying sha256 digest"})
+            enc.Encode(progressUpdateJSON{Status: "writing manifest"})
+            enc.Encode(progressUpdateJSON{Status: "success"})
             return nil
         }
     }
diff --git a/server/internal/registry/server_test.go b/server/internal/registry/server_test.go
index 3f20e518a..61b57f114 100644
--- a/server/internal/registry/server_test.go
+++ b/server/internal/registry/server_test.go
@@ -78,7 +78,12 @@ func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
 
 func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
     t.Helper()
-    req := httptest.NewRequestWithContext(t.Context(), method, path, strings.NewReader(body))
+    ctx := ollama.WithTrace(t.Context(), &ollama.Trace{
+        Update: func(l *ollama.Layer, n int64, err error) {
+            t.Logf("update: %s %d %v", l.Digest, n, err)
+        },
+    })
+    req := httptest.NewRequestWithContext(ctx, method, path, strings.NewReader(body))
     return s.sendRequest(t, req)
 }
 
@@ -184,36 +189,34 @@ func TestServerPull(t *testing.T) {
     checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
         t.Helper()
-
         if got.Code != 200 {
             t.Errorf("Code = %d; want 200", got.Code)
         }
 
         gotlines := got.Body.String()
+        if strings.TrimSpace(gotlines) == "" {
+            gotlines = ""
+        }
         t.Logf("got:\n%s", gotlines)
         for want := range strings.Lines(wantlines) {
             want = strings.TrimSpace(want)
             want, unwanted := strings.CutPrefix(want, "!")
             want = strings.TrimSpace(want)
             if !unwanted && !strings.Contains(gotlines, want) {
-                t.Errorf("! missing %q in body", want)
+                t.Errorf("\t! missing %q in body", want)
             }
             if unwanted && strings.Contains(gotlines, want) {
-                t.Errorf("! unexpected %q in body", want)
+                t.Errorf("\t! unexpected %q in body", want)
             }
         }
     }
 
-    got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
+    got := s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
     checkResponse(got, `
-        {"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
-    `)
-
-    got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
-    checkResponse(got, `
-        {"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
-        {"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
+        {"status":"pulling manifest"}
         {"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
-        {"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
+        {"status":"verifying sha256 digest"}
+        {"status":"writing manifest"}
+        {"status":"success"}
     `)
 
     got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
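Reviewer note, not part of the patch: the WithTrace change above is easy to misread, so here is a minimal sketch of the new behavior, written as it might appear alongside server_test.go. The package/import header and the test name are assumptions for illustration only; the calls themselves (ollama.WithTrace, ollama.Trace.Update, t.Context, t.Logf) mirror usage already present in this diff.

package registry

import (
    "testing"

    "github.com/ollama/ollama/server/internal/client/ollama"
)

// TestTraceComposition (illustrative sketch): WithTrace now composes traces
// instead of replacing the one already stored in the context.
func TestTraceComposition(t *testing.T) {
    // First trace, e.g. server-side progress accounting.
    ctx := ollama.WithTrace(t.Context(), &ollama.Trace{
        Update: func(l *ollama.Layer, n int64, err error) {
            t.Logf("outer: %s %d %v", l.Digest, n, err)
        },
    })
    // Second trace layered on top; after this change both Update callbacks
    // fire for every layer update, where previously the second call replaced
    // the first.
    ctx = ollama.WithTrace(ctx, &ollama.Trace{
        Update: func(l *ollama.Layer, n int64, err error) {
            t.Logf("inner: %s %d %v", l.Digest, n, err)
        },
    })
    _ = ctx // hand ctx to s.Client.Pull(ctx, name), as handlePull does above
}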