server/internal/client/ollama: persist through chunk download errors (#9923)

Blake Mizerany, 2025-03-21 13:03:43 -07:00 (committed by GitHub)
parent 00ebda8cc4
commit c794fef2f2
3 changed files with 89 additions and 36 deletions

server/internal/client/ollama/registry.go

@@ -59,6 +59,11 @@ var (
 	// ErrCached is passed to [Trace.PushUpdate] when a layer already
 	// exists. It is a non-fatal error and is never returned by [Registry.Push].
 	ErrCached = errors.New("cached")
+
+	// ErrIncomplete is returned by [Registry.Pull] when a model pull was
+	// incomplete due to one or more layer download failures. Users that
+	// want specific errors should use [WithTrace].
+	ErrIncomplete = errors.New("incomplete")
 )
 
 // Defaults
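A quick caller sketch of the new sentinel (the registry/model name is hypothetical, and r/ctx are assumed to be a Registry and a context): because finished chunks are persisted in the cache, a retry after ErrIncomplete only fetches what is still missing.

	// Sketch, not from this commit: retrying a pull that ended incomplete.
	err := r.Pull(ctx, "registry.ollama.ai/library/smollm")
	if errors.Is(err, ErrIncomplete) {
		// Completed chunks were cached despite the failures, so the
		// retry resumes roughly where the first attempt left off.
		err = r.Pull(ctx, "registry.ollama.ai/library/smollm")
	}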
@@ -271,8 +276,19 @@ func DefaultRegistry() (*Registry, error) {
 func UserAgent() string {
 	buildinfo, _ := debug.ReadBuildInfo()
+
+	version := buildinfo.Main.Version
+	if version == "(devel)" {
+		// When using `go run .` the version is "(devel)". This is seen
+		// as an invalid version by ollama.com and so it defaults to
+		// "needs upgrade" for some requests, such as pulls. These
+		// checks can be skipped by using the special version "v0.0.0",
+		// so we set it to that here.
+		version = "v0.0.0"
+	}
+
 	return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
-		buildinfo.Main.Version,
+		version,
 		runtime.GOARCH,
 		runtime.GOOS,
 		runtime.Version(),
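As a standalone illustration of the same guard (the printed values are assumptions, not output from this commit):

	package main

	import (
		"fmt"
		"runtime"
		"runtime/debug"
	)

	func main() {
		version := "v0.0.0" // fallback, also used for "(devel)" builds
		if bi, ok := debug.ReadBuildInfo(); ok && bi.Main.Version != "(devel)" {
			version = bi.Main.Version
		}
		fmt.Printf("ollama/%s (%s %s) Go/%s\n",
			version, runtime.GOARCH, runtime.GOOS, runtime.Version())
		// Under `go run .` this prints something like:
		// ollama/v0.0.0 (arm64 darwin) Go/go1.24.1
	}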
@@ -418,13 +434,14 @@ func canRetry(err error) bool {
 //
 // It always calls update with a nil error.
 type trackingReader struct {
-	r io.Reader
-	n *atomic.Int64
+	l      *Layer
+	r      io.Reader
+	update func(l *Layer, n int64, err error)
 }
 
 func (r *trackingReader) Read(p []byte) (n int, err error) {
 	n, err = r.r.Read(p)
-	r.n.Add(int64(n))
+	r.update(r.l, int64(n), nil)
 	return
 }
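The shape here is a general one: wrap an io.Reader and forward every Read's byte count to a callback instead of an atomic counter. A self-contained sketch with hypothetical names (the real type is unexported in this package):

	package main

	import (
		"fmt"
		"io"
		"strings"
	)

	// progressReader mirrors the trackingReader idea: every Read,
	// including the final one that returns io.EOF, reports its count.
	type progressReader struct {
		r      io.Reader
		report func(n int64)
	}

	func (pr *progressReader) Read(p []byte) (n int, err error) {
		n, err = pr.r.Read(p)
		pr.report(int64(n)) // report even when err != nil; n may be > 0
		return
	}

	func main() {
		var total int64
		pr := &progressReader{
			r:      strings.NewReader("hello, world"),
			report: func(n int64) { total += n },
		}
		io.Copy(io.Discard, pr)
		fmt.Println(total) // 12
	}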
@@ -462,16 +479,20 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 	// Send initial layer trace events to allow clients to have an
 	// understanding of work to be done before work starts.
+	var expected int64
 	t := traceFromContext(ctx)
 	for _, l := range layers {
 		t.update(l, 0, nil)
+		expected += l.Size
 	}
 
+	var total atomic.Int64
 	var g errgroup.Group
 	g.SetLimit(r.maxStreams())
 	for _, l := range layers {
 		info, err := c.Get(l.Digest)
 		if err == nil && info.Size == l.Size {
+			total.Add(l.Size)
 			t.update(l, l.Size, ErrCached)
 			continue
 		}
@@ -484,21 +505,25 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		// TODO(bmizerany): fix this unbounded use of defer
 		defer chunked.Close()
 
-		var progress atomic.Int64
 		for cs, err := range r.chunksums(ctx, name, l) {
 			if err != nil {
-				// Bad chunksums response, update tracing
-				// clients and then bail.
-				t.update(l, progress.Load(), err)
-				return err
+				// Chunksum stream was interrupted, so tell
+				// trace about it, and let in-flight chunk
+				// downloads finish. Once they finish, return
+				// ErrIncomplete, which is triggered by the
+				// fact that the total bytes received is less
+				// than the expected bytes.
+				t.update(l, 0, err)
+				break
 			}
 
 			g.Go(func() (err error) {
				defer func() {
-					if err != nil {
+					if err == nil || errors.Is(err, ErrCached) {
+						total.Add(cs.Chunk.Size())
+					} else {
 						err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
 					}
-					t.update(l, progress.Load(), err)
 				}()
 
 				req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
@@ -522,7 +547,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 				// download rate since it knows better than a
 				// client that is measuring rate based on
 				// wall-clock time-since-last-update.
-				body := &trackingReader{r: res.Body, n: &progress}
+				body := &trackingReader{l: l, r: res.Body, update: t.update}
 				return chunked.Put(cs.Chunk, cs.Digest, body)
 			})
@@ -531,6 +556,9 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 	if err := g.Wait(); err != nil {
 		return err
 	}
+	if total.Load() != expected {
+		return fmt.Errorf("%w: received %d/%d", ErrIncomplete, total.Load(), expected)
+	}
 
 	md := blob.DigestFromBytes(m.Data)
 	if err := blob.PutBytes(c, md, m.Data); err != nil {
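The accounting pattern, isolated: workers add to an atomic counter only on success, and the caller compares against the precomputed expectation after Wait. A runnable sketch with made-up job sizes and a simulated failure (none of these names come from the commit):

	package main

	import (
		"errors"
		"fmt"
		"sync/atomic"

		"golang.org/x/sync/errgroup"
	)

	var errIncomplete = errors.New("incomplete")

	func main() {
		sizes := []int64{10, 20, 30}
		var expected int64
		for _, s := range sizes {
			expected += s
		}

		var total atomic.Int64
		var g errgroup.Group
		for i, size := range sizes {
			i, size := i, size // pin loop vars for pre-1.22 Go
			g.Go(func() error {
				if i == 1 {
					return nil // simulate a failed chunk: traced, not fatal
				}
				total.Add(size) // count only bytes that actually landed
				return nil
			})
		}
		_ = g.Wait()
		if total.Load() != expected {
			fmt.Printf("%v: received %d/%d\n", errIncomplete, total.Load(), expected)
		}
	}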
@@ -757,15 +785,12 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
 		}
 		blobURL := res.Header.Get("Content-Location")
 
-		var size int64
 		s := bufio.NewScanner(res.Body)
 		s.Split(bufio.ScanWords)
 		for {
 			if !s.Scan() {
 				if s.Err() != nil {
 					yield(chunksum{}, s.Err())
-				} else if size != l.Size {
-					yield(chunksum{}, fmt.Errorf("size mismatch: layer size %d != sum of chunks %d", size, l.Size))
 				}
 				return
 			}
@@ -789,12 +814,6 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
 				return
 			}
 
-			size += chunk.Size()
-			if size > l.Size {
-				yield(chunksum{}, fmt.Errorf("chunk size %d exceeds layer size %d", size, l.Size))
-				return
-			}
-
 			cs := chunksum{
 				URL:   blobURL,
 				Chunk: chunk,

server/internal/client/ollama/registry_test.go

@@ -25,6 +25,28 @@ import (
 	"github.com/ollama/ollama/server/internal/testutil"
 )
 
+func ExampleRegistry_cancelOnFirstError() {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ctx = WithTrace(ctx, &Trace{
+		Update: func(l *Layer, n int64, err error) {
+			if err != nil {
+				// Discontinue pulling layers if there is an
+				// error instead of continuing to pull more
+				// data.
+				cancel()
+			}
+		},
+	})
+
+	var r Registry
+	if err := r.Pull(ctx, "model"); err != nil {
+		// panic for demo purposes
+		panic(err)
+	}
+}
+
 func TestManifestMarshalJSON(t *testing.T) {
 	// All manifests should contain an "empty" config object.
 	var m Manifest
@@ -813,8 +835,13 @@ func TestPullChunksums(t *testing.T) {
 	)
 	err := rc.Pull(ctx, "test")
 	check(err)
-	if !slices.Equal(reads, []int64{0, 3, 5}) {
-		t.Errorf("reads = %v; want %v", reads, []int64{0, 3, 5})
+	wantReads := []int64{
+		0, // initial signaling of layer pull starting
+		3, // first chunk read
+		2, // second chunk read
+	}
+	if !slices.Equal(reads, wantReads) {
+		t.Errorf("reads = %v; want %v", reads, wantReads)
 	}
 
 	mw, err := rc.Resolve(t.Context(), "test")

server/internal/registry/server.go

@@ -200,7 +200,7 @@ type params struct {
 	//
 	// Unfortunately, this API was designed to be a bit awkward. Stream is
 	// defined to default to true if not present, so we need a way to check
-	// if the client decisively it to false. So, we use a pointer to a
+	// if the client decisively set it to false. So, we use a pointer to a
 	// bool. Gross.
 	//
 	// Use [stream()] to get the correct value for this field.
@@ -280,17 +280,17 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
 	progress := make(map[*ollama.Layer]int64)
 	progressCopy := make(map[*ollama.Layer]int64, len(progress))
-	pushUpdate := func() {
+	flushProgress := func() {
 		defer maybeFlush()
 
-		// TODO(bmizerany): This scales poorly with more layers due to
-		// needing to flush out them all in one big update. We _could_
-		// just flush on the changed ones, or just track the whole
-		// download. Needs more thought. This is fine for now.
+		// TODO(bmizerany): Flushing every layer in one update doesn't
+		// scale well. We could flush only the modified layers or track
+		// the full download. Needs further consideration, though it's
+		// fine for now.
 		mu.Lock()
 		maps.Copy(progressCopy, progress)
 		mu.Unlock()
-		for l, n := range progress {
+		for l, n := range progressCopy {
 			enc.Encode(progressUpdateJSON{
 				Digest: l.Digest,
 				Total:  l.Size,
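Part of the fix here is that the loop now iterates the snapshot (progressCopy) rather than the live map: copy under the mutex, encode outside it. A minimal runnable sketch of the pattern with simplified, made-up types:

	package main

	import (
		"encoding/json"
		"fmt"
		"maps"
		"sync"
	)

	func main() {
		var mu sync.Mutex
		progress := map[string]int64{"layer-a": 3, "layer-b": 5}
		snapshot := make(map[string]int64, len(progress))

		// Hold the lock only while copying, so slow JSON encoding or a
		// slow client never blocks the goroutines updating progress.
		mu.Lock()
		maps.Copy(snapshot, progress)
		mu.Unlock()

		for layer, n := range snapshot { // the snapshot, not the live map
			b, _ := json.Marshal(map[string]any{"digest": layer, "completed": n})
			fmt.Println(string(b))
		}
	}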
@@ -298,19 +298,26 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
 			})
 		}
 	}
+	defer flushProgress()
 
-	t := time.NewTicker(time.Hour) // "unstarted" timer
+	t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
 	start := sync.OnceFunc(func() {
-		pushUpdate()
+		flushProgress() // flush initial state
 		t.Reset(100 * time.Millisecond)
 	})
 	ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
 		Update: func(l *ollama.Layer, n int64, err error) {
 			if n > 0 {
-				start() // flush initial state
+				// Block flushing progress updates until every
+				// layer is accounted for. Clients depend on a
+				// complete model size to calculate progress
+				// correctly; if they use an incomplete total,
+				// progress indicators would erratically jump
+				// as new layers are registered.
+				start()
 			}
 			mu.Lock()
-			progress[l] = n
+			progress[l] += n
 			mu.Unlock()
 		},
 	})
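The switch from progress[l] = n to progress[l] += n follows from the trackingReader change above: Update now receives per-Read deltas rather than cumulative totals, so the handler owns the running sum. A compact sketch with hypothetical layer keys (the input mirrors the test's reads of 0, 3, 2):

	package main

	import (
		"fmt"
		"sync"
	)

	func main() {
		var mu sync.Mutex
		progress := map[string]int64{}

		// update mimics Trace.Update receiving per-Read deltas.
		update := func(layer string, n int64) {
			mu.Lock()
			progress[layer] += n // accumulate deltas into a total
			mu.Unlock()
		}

		for _, n := range []int64{0, 3, 2} {
			update("layer-a", n)
		}
		fmt.Println(progress["layer-a"]) // 5
	}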
@@ -323,9 +330,9 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
 	for {
 		select {
 		case <-t.C:
-			pushUpdate()
+			flushProgress()
 		case err := <-done:
-			pushUpdate()
+			flushProgress()
 			if err != nil {
 				var status string
 				if errors.Is(err, ollama.ErrModelNotFound) {