server/internal/client/ollama: cleanup use of multiple counters (#10304)
The completed and received counters must work in tandem and the code should better reflect that. Previously, the act of updating them was 2-3 lines of code duplicated in multiple places. This consolidates them into a single update closure for easy reading and maintenance. This also simplifies error handling in places where we can use a return parameter and defer to handle the error case for updates. Also, remove the old Layer field from the trackingReader struct.
This commit is contained in:
parent
ed4e139314
commit
3457a315b2
@ -431,14 +431,13 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
|
||||
//
|
||||
// It always calls update with a nil error.
|
||||
type trackingReader struct {
|
||||
l *Layer
|
||||
r io.Reader
|
||||
update func(n int64)
|
||||
update func(n int64, err error) // err is always nil
|
||||
}
|
||||
|
||||
func (r *trackingReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.r.Read(p)
|
||||
r.update(int64(n))
|
||||
r.update(int64(n), nil)
|
||||
return
|
||||
}
|
||||
|
||||
@ -483,26 +482,34 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
expected += l.Size
|
||||
}
|
||||
|
||||
var completed atomic.Int64
|
||||
var g errgroup.Group
|
||||
g.SetLimit(r.maxStreams())
|
||||
|
||||
var completed atomic.Int64
|
||||
for _, l := range layers {
|
||||
var received atomic.Int64
|
||||
update := func(n int64, err error) {
|
||||
completed.Add(n)
|
||||
t.update(l, received.Add(n), err)
|
||||
}
|
||||
|
||||
info, err := c.Get(l.Digest)
|
||||
if err == nil && info.Size == l.Size {
|
||||
received.Add(l.Size)
|
||||
completed.Add(l.Size)
|
||||
t.update(l, l.Size, ErrCached)
|
||||
update(l.Size, ErrCached)
|
||||
continue
|
||||
}
|
||||
|
||||
func() {
|
||||
func() (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
update(0, err)
|
||||
}
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
chunked, err := c.Chunked(l.Digest, l.Size)
|
||||
if err != nil {
|
||||
t.update(l, received.Load(), err)
|
||||
return
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
// Close the chunked writer when all chunks are
|
||||
@ -522,11 +529,13 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
|
||||
for cs, err := range r.chunksums(ctx, name, l) {
|
||||
if err != nil {
|
||||
// Chunksum stream interrupted. Note in trace
|
||||
// log and let in-flight downloads complete.
|
||||
// This will naturally trigger ErrIncomplete
|
||||
// since received < expected bytes.
|
||||
t.update(l, received.Load(), err)
|
||||
// Note the chunksum stream
|
||||
// interuption, but do not cancel
|
||||
// in-flight downloads. We can still
|
||||
// make progress on them. Once they are
|
||||
// done, ErrIncomplete will be returned
|
||||
// below.
|
||||
update(0, err)
|
||||
break
|
||||
}
|
||||
|
||||
@ -540,31 +549,17 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
|
||||
_, err := c.Get(cacheKeyDigest)
|
||||
if err == nil {
|
||||
recv := received.Add(cs.Chunk.Size())
|
||||
completed.Add(cs.Chunk.Size())
|
||||
t.update(l, recv, ErrCached)
|
||||
update(cs.Chunk.Size(), ErrCached)
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
g.Go(func() (err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
// Ignore cache key write errors for now. We've already
|
||||
// reported to trace that the chunk is complete.
|
||||
//
|
||||
// Ideally, we should only report completion to trace
|
||||
// after successful cache commit. This current approach
|
||||
// works but could trigger unnecessary redownloads if
|
||||
// the checkpoint key is missing on next pull.
|
||||
//
|
||||
// Not incorrect, just suboptimal - fix this in a
|
||||
// future update.
|
||||
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
|
||||
} else {
|
||||
t.update(l, received.Load(), err)
|
||||
defer wg.Done()
|
||||
if err != nil {
|
||||
update(0, err)
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
||||
@ -579,17 +574,19 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
defer res.Body.Close()
|
||||
|
||||
tr := &trackingReader{
|
||||
l: l,
|
||||
r: res.Body,
|
||||
update: func(n int64) {
|
||||
completed.Add(n)
|
||||
recv := received.Add(n)
|
||||
t.update(l, recv, nil)
|
||||
},
|
||||
r: res.Body,
|
||||
update: update,
|
||||
}
|
||||
return chunked.Put(cs.Chunk, cs.Digest, tr)
|
||||
if err := chunked.Put(cs.Chunk, cs.Digest, tr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Record the downloading of this chunk.
|
||||
return blob.PutBytes(c, cacheKeyDigest, cacheKey)
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
}
|
||||
if err := g.Wait(); err != nil {
|
||||
|
Loading…
x
Reference in New Issue
Block a user