From 3457a315b241d5d2ada9958d22cc5effb2643a7e Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Wed, 16 Apr 2025 14:33:40 -0700 Subject: [PATCH] server/internal/client/ollama: cleanup use of multiple counters (#10304) The completed and received counters must work in tandem and the code should better reflect that. Previously, the act of updating them was 2-3 lines of code duplicated in multiple places. This consolidates them into a single update closure for easy reading and maintenance. This also simplifies error handling in places where we can use a return parameter and defer to handle the error case for updates. Also, remove the old Layer field from the trackingReader struct. --- server/internal/client/ollama/registry.go | 79 +++++++++++------------ 1 file changed, 38 insertions(+), 41 deletions(-) diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index 3eb3c5c24..4d00b41e1 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -431,14 +431,13 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error { // // It always calls update with a nil error. 
type trackingReader struct { - l *Layer r io.Reader - update func(n int64) + update func(n int64, err error) // err is always nil } func (r *trackingReader) Read(p []byte) (n int, err error) { n, err = r.r.Read(p) - r.update(int64(n)) + r.update(int64(n), nil) return } @@ -483,26 +482,34 @@ func (r *Registry) Pull(ctx context.Context, name string) error { expected += l.Size } - var completed atomic.Int64 var g errgroup.Group g.SetLimit(r.maxStreams()) + + var completed atomic.Int64 for _, l := range layers { var received atomic.Int64 + update := func(n int64, err error) { + completed.Add(n) + t.update(l, received.Add(n), err) + } info, err := c.Get(l.Digest) if err == nil && info.Size == l.Size { - received.Add(l.Size) - completed.Add(l.Size) - t.update(l, l.Size, ErrCached) + update(l.Size, ErrCached) continue } - func() { + func() (err error) { + defer func() { + if err != nil { + update(0, err) + } + }() + var wg sync.WaitGroup chunked, err := c.Chunked(l.Digest, l.Size) if err != nil { - t.update(l, received.Load(), err) - return + return err } defer func() { // Close the chunked writer when all chunks are @@ -522,11 +529,13 @@ func (r *Registry) Pull(ctx context.Context, name string) error { for cs, err := range r.chunksums(ctx, name, l) { if err != nil { - // Chunksum stream interrupted. Note in trace - // log and let in-flight downloads complete. - // This will naturally trigger ErrIncomplete - // since received < expected bytes. - t.update(l, received.Load(), err) + // Note the chunksum + // interruption, but do not cancel + // in-flight downloads. We can still + // make progress on them. Once they are + // done, ErrIncomplete will be returned + // below. 
+ update(0, err) break } @@ -540,31 +549,17 @@ func (r *Registry) Pull(ctx context.Context, name string) error { cacheKeyDigest := blob.DigestFromBytes(cacheKey) _, err := c.Get(cacheKeyDigest) if err == nil { - recv := received.Add(cs.Chunk.Size()) - completed.Add(cs.Chunk.Size()) - t.update(l, recv, ErrCached) + update(cs.Chunk.Size(), ErrCached) continue } wg.Add(1) g.Go(func() (err error) { defer func() { - if err == nil { - // Ignore cache key write errors for now. We've already - // reported to trace that the chunk is complete. - // - // Ideally, we should only report completion to trace - // after successful cache commit. This current approach - // works but could trigger unnecessary redownloads if - // the checkpoint key is missing on next pull. - // - // Not incorrect, just suboptimal - fix this in a - // future update. - _ = blob.PutBytes(c, cacheKeyDigest, cacheKey) - } else { - t.update(l, received.Load(), err) + defer wg.Done() + if err != nil { + update(0, err) } - wg.Done() }() req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil) @@ -579,17 +574,19 @@ func (r *Registry) Pull(ctx context.Context, name string) error { defer res.Body.Close() tr := &trackingReader{ - l: l, - r: res.Body, - update: func(n int64) { - completed.Add(n) - recv := received.Add(n) - t.update(l, recv, nil) - }, + r: res.Body, + update: update, } - return chunked.Put(cs.Chunk, cs.Digest, tr) + if err := chunked.Put(cs.Chunk, cs.Digest, tr); err != nil { + return err + } + + // Record the downloading of this chunk. + return blob.PutBytes(c, cacheKeyDigest, cacheKey) }) } + + return nil }() } if err := g.Wait(); err != nil {