server/internal/client/ollama: cleanup use of multiple counters (#10304)

The completed and received counters must work in tandem, and the code
should reflect that more clearly. Previously, updating them took 2-3
lines of code duplicated in multiple places. This consolidates those
updates into a single update closure for easier reading and maintenance.
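
As a rough, self-contained sketch of that pattern (the real trace and
layer types are replaced here with a fmt.Printf stand-in, so this is not
the actual Registry code):

package main

import (
	"fmt"
	"sync/atomic"
)

func main() {
	var completed atomic.Int64 // total bytes across all layers
	var received atomic.Int64  // bytes received for the current layer

	// The consolidated update closure: both counters always move
	// together, and progress is reported in exactly one place.
	update := func(n int64, err error) {
		completed.Add(n)
		fmt.Printf("layer progress: %d bytes (err=%v)\n", received.Add(n), err)
	}

	update(1024, nil)                                    // a chunk finished downloading
	update(0, fmt.Errorf("chunksum stream interrupted")) // report an error without progress

	fmt.Println("total completed:", completed.Load())
}

Because every call site goes through the same closure, the two counters
cannot drift apart and the progress callback is invoked consistently.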

This also simplifies error handling in places where a named return
parameter and a deferred function can handle the error case for updates.
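
A minimal sketch of that pattern, using a hypothetical report helper in
place of the real update closure:

package main

import (
	"errors"
	"fmt"
)

// report stands in for the update closure from the diff; it only
// illustrates where the error would be surfaced.
func report(n int64, err error) {
	fmt.Printf("update: n=%d err=%v\n", n, err)
}

// pullLayer sketches the named-return-plus-defer pattern: every early
// "return err" is reported exactly once by the deferred function, so
// the body no longer needs an explicit update call on each error path.
func pullLayer() (err error) {
	defer func() {
		if err != nil {
			report(0, err)
		}
	}()

	// ... do the work; failure paths reduce to "return err".
	return errors.New("example failure")
}

func main() {
	_ = pullLayer()
}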

Also, remove the old Layer field from the trackingReader struct.
Blake Mizerany, 2025-04-16 14:33:40 -07:00 (committed by GitHub)
commit 3457a315b2 (parent ed4e139314)


@@ -431,14 +431,13 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
 //
 // It always calls update with a nil error.
 type trackingReader struct {
-	l      *Layer
 	r      io.Reader
-	update func(n int64)
+	update func(n int64, err error) // err is always nil
 }
 
 func (r *trackingReader) Read(p []byte) (n int, err error) {
 	n, err = r.r.Read(p)
-	r.update(int64(n))
+	r.update(int64(n), nil)
 	return
 }
@@ -483,26 +482,34 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		expected += l.Size
 	}
 
+	var completed atomic.Int64
+
 	var g errgroup.Group
 	g.SetLimit(r.maxStreams())
-
-	var completed atomic.Int64
 	for _, l := range layers {
 		var received atomic.Int64
+		update := func(n int64, err error) {
+			completed.Add(n)
+			t.update(l, received.Add(n), err)
+		}
+
 		info, err := c.Get(l.Digest)
 		if err == nil && info.Size == l.Size {
-			received.Add(l.Size)
-			completed.Add(l.Size)
-			t.update(l, l.Size, ErrCached)
+			update(l.Size, ErrCached)
 			continue
 		}
 
-		func() {
+		func() (err error) {
+			defer func() {
+				if err != nil {
+					update(0, err)
+				}
+			}()
+
 			var wg sync.WaitGroup
 			chunked, err := c.Chunked(l.Digest, l.Size)
 			if err != nil {
-				t.update(l, received.Load(), err)
-				return
+				return err
 			}
 			defer func() {
 				// Close the chunked writer when all chunks are
@@ -522,11 +529,13 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 			for cs, err := range r.chunksums(ctx, name, l) {
 				if err != nil {
-					// Chunksum stream interrupted. Note in trace
-					// log and let in-flight downloads complete.
-					// This will naturally trigger ErrIncomplete
-					// since received < expected bytes.
-					t.update(l, received.Load(), err)
+					// Note the chunksum stream
+					// interruption, but do not cancel
+					// in-flight downloads. We can still
+					// make progress on them. Once they are
+					// done, ErrIncomplete will be returned
+					// below.
+					update(0, err)
 					break
 				}
@@ -540,31 +549,17 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 				cacheKeyDigest := blob.DigestFromBytes(cacheKey)
 				_, err := c.Get(cacheKeyDigest)
 				if err == nil {
-					recv := received.Add(cs.Chunk.Size())
-					completed.Add(cs.Chunk.Size())
-					t.update(l, recv, ErrCached)
+					update(cs.Chunk.Size(), ErrCached)
 					continue
 				}
 
 				wg.Add(1)
 				g.Go(func() (err error) {
 					defer func() {
-						if err == nil {
-							// Ignore cache key write errors for now. We've already
-							// reported to trace that the chunk is complete.
-							//
-							// Ideally, we should only report completion to trace
-							// after successful cache commit. This current approach
-							// works but could trigger unnecessary redownloads if
-							// the checkpoint key is missing on next pull.
-							//
-							// Not incorrect, just suboptimal - fix this in a
-							// future update.
-							_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
-						} else {
-							t.update(l, received.Load(), err)
+						defer wg.Done()
+						if err != nil {
+							update(0, err)
 						}
-						wg.Done()
 					}()
 
 					req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
@@ -579,17 +574,19 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 					defer res.Body.Close()
 					tr := &trackingReader{
-						l:      l,
-						r:      res.Body,
-						update: func(n int64) {
-							completed.Add(n)
-							recv := received.Add(n)
-							t.update(l, recv, nil)
-						},
+						r:      res.Body,
+						update: update,
 					}
-					return chunked.Put(cs.Chunk, cs.Digest, tr)
+
+					if err := chunked.Put(cs.Chunk, cs.Digest, tr); err != nil {
+						return err
+					}
+
+					// Record the downloading of this chunk.
+					return blob.PutBytes(c, cacheKeyDigest, cacheKey)
 				})
 			}
+			return nil
 		}()
 	}
 
 	if err := g.Wait(); err != nil {