diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go
index d1d01ba46..fdac71bbf 100644
--- a/server/internal/client/ollama/registry.go
+++ b/server/internal/client/ollama/registry.go
@@ -37,7 +37,6 @@ import (
 	"golang.org/x/sync/errgroup"
 
 	"github.com/ollama/ollama/server/internal/cache/blob"
-	"github.com/ollama/ollama/server/internal/internal/backoff"
 	"github.com/ollama/ollama/server/internal/internal/names"
 
 	_ "embed"
@@ -213,12 +212,6 @@ type Registry struct {
 	// request. If zero, [DefaultChunkingThreshold] is used.
 	ChunkingThreshold int64
 
-	// MaxChunkSize is the maximum size of a chunk to download. If zero,
-	// the default is [DefaultMaxChunkSize].
-	//
-	// It is only used when a layer is larger than [MaxChunkingThreshold].
-	MaxChunkSize int64
-
 	// Mask, if set, is the name used to convert non-fully qualified names
 	// to fully qualified names. If empty, [DefaultMask] is used.
 	Mask string
@@ -447,6 +440,11 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 	if err != nil {
 		return err
 	}
+
+	// TODO(bmizerany): decide if this should be considered valid. Maybe
+	// server-side we special case '{}' to have some special meaning? Maybe
+	// "archiving" a tag (which is how we reason about it in the registry
+	// already, just with a different twist).
 	if len(m.Layers) == 0 {
 		return fmt.Errorf("%w: no layers", ErrManifestInvalid)
 	}
@@ -456,11 +454,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		return err
 	}
 
-	exists := func(l *Layer) bool {
-		info, err := c.Get(l.Digest)
-		return err == nil && info.Size == l.Size
-	}
-
+	// TODO(bmizerany): work to remove the need to do this
 	layers := m.Layers
 	if m.Config != nil && m.Config.Digest.IsValid() {
 		layers = append(layers, m.Config)
@@ -469,19 +463,16 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 	// Send initial layer trace events to allow clients to have an
 	// understanding of work to be done before work starts.
 	t := traceFromContext(ctx)
-	skip := make([]bool, len(layers))
-	for i, l := range layers {
+	for _, l := range layers {
 		t.update(l, 0, nil)
-		if exists(l) {
-			skip[i] = true
-			t.update(l, l.Size, ErrCached)
-		}
 	}
 
-	g, ctx := errgroup.WithContext(ctx)
+	var g errgroup.Group
 	g.SetLimit(r.maxStreams())
-	for i, l := range layers {
-		if skip[i] {
+	for _, l := range layers {
+		info, err := c.Get(l.Digest)
+		if err == nil && info.Size == l.Size {
+			t.update(l, l.Size, ErrCached)
 			continue
 		}
 
@@ -490,63 +481,50 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		if err != nil {
 			t.update(l, 0, err)
 			continue
 		}
+		// TODO(bmizerany): fix this unbounded use of defer
 		defer chunked.Close()
 
		var progress atomic.Int64
 		for cs, err := range r.chunksums(ctx, name, l) {
 			if err != nil {
+				// Bad chunksums response, update tracing
+				// clients and then bail.
 				t.update(l, progress.Load(), err)
-				break
+				return err
 			}
 			g.Go(func() (err error) {
-				defer func() { t.update(l, progress.Load(), err) }()
-
-				for _, err := range backoff.Loop(ctx, 3*time.Second) {
+				defer func() {
 					if err != nil {
-						return err
+						err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
 					}
-					err := func() error {
-						req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
-						if err != nil {
-							return err
-						}
-						req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
-						res, err := sendRequest(r.client(), req)
-						if err != nil {
-							return err
-						}
-						defer res.Body.Close()
+					t.update(l, progress.Load(), err)
+				}()
 
-						// Count bytes towards
-						// progress, as they arrive, so
-						// that our bytes piggyback
-						// other chunk updates on
-						// completion.
-						//
-						// This tactic is enough to
-						// show "smooth" progress given
-						// the current CLI client. In
-						// the near future, the server
-						// should report download rate
-						// since it knows better than
-						// a client that is measuring
-						// rate based on wall-clock
-						// time-since-last-update.
-						body := &trackingReader{r: res.Body, n: &progress}
-
-						err = chunked.Put(cs.Chunk, cs.Digest, body)
-						if err != nil {
-							return err
-						}
-
-						return nil
-					}()
-					if !canRetry(err) {
-						return err
-					}
+				req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
+				if err != nil {
+					return err
 				}
-				return nil
+				req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
+				res, err := sendRequest(r.client(), req)
+				if err != nil {
+					return err
+				}
+				defer res.Body.Close()
+
+				// Count bytes towards progress, as they
+				// arrive, so that our bytes piggyback other
+				// chunk updates on completion.
+				//
+				// This tactic is enough to show "smooth"
+				// progress given the current CLI client. In
+				// the near future, the server should report
+				// download rate since it knows better than a
+				// client that is measuring rate based on
+				// wall-clock time-since-last-update.
+				body := &trackingReader{r: res.Body, n: &progress}
+
+				return chunked.Put(cs.Chunk, cs.Digest, body)
 			})
 		}
 	}
@@ -554,13 +532,10 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		return err
 	}
 
-	// store the manifest blob
 	md := blob.DigestFromBytes(m.Data)
 	if err := blob.PutBytes(c, md, m.Data); err != nil {
 		return err
 	}
-
-	// commit the manifest with a link
 	return c.Link(m.Name, md)
 }
 
@@ -782,12 +757,15 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
 		}
 
 		blobURL := res.Header.Get("Content-Location")
+		var size int64
 		s := bufio.NewScanner(res.Body)
 		s.Split(bufio.ScanWords)
 		for {
 			if !s.Scan() {
 				if s.Err() != nil {
 					yield(chunksum{}, s.Err())
+				} else if size != l.Size {
+					yield(chunksum{}, fmt.Errorf("size mismatch: layer size %d != sum of chunks %d", l.Size, size))
 				}
 				return
 			}
@@ -811,6 +789,12 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
 				return
 			}
 
+			size += chunk.Size()
+			if size > l.Size {
+				yield(chunksum{}, fmt.Errorf("sum of chunks %d exceeds layer size %d", size, l.Size))
+				return
+			}
+
 			cs := chunksum{
 				URL:   blobURL,
 				Chunk: chunk,
diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go
index 30fb58ab7..305295435 100644
--- a/server/internal/client/ollama/registry_test.go
+++ b/server/internal/client/ollama/registry_test.go
@@ -17,6 +17,7 @@ import (
 	"reflect"
 	"slices"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
@@ -56,21 +57,21 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
 
 // newClient constructs a cache with predefined manifests for testing. The manifests are:
 //
-//	empty:         no data
-//	zero:          no layers
-//	single:        one layer with the contents "exists"
-//	multiple:      two layers with the contents "exists" and "here"
-//	notfound:      a layer that does not exist in the cache
-//	null:          one null layer (e.g. [null])
-//	sizemismatch:  one valid layer, and one with a size mismatch (file size is less than the reported size)
-//	invalid:       a layer with invalid JSON data
+//	empty:        no data
+//	zero:         no layers
+//	single:       one layer with the contents "exists"
+//	multiple:     two layers with the contents "exists" and "here"
+//	notfound:     a layer that does not exist in the cache
+//	null:         one null layer (e.g. [null])
+//	sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
+//	invalid:      a layer with invalid JSON data
 //
 // Tests that want to ensure the client does not communicate with the upstream
 // registry should pass a nil handler, which will cause a panic if
 // communication is attempted.
 //
 // To simulate a network error, pass a handler that returns a 499 status code.
-func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
+func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
 	t.Helper()
 
 	c, err := blob.Open(t.TempDir())
@@ -88,7 +89,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
 	r := &Registry{
 		Cache: c,
 		HTTPClient: &http.Client{
-			Transport: recordRoundTripper(h),
+			Transport: recordRoundTripper(upstreamRegistry),
 		},
 	}
 
@@ -767,3 +768,74 @@ func TestUnlink(t *testing.T) {
 		}
 	})
 }
+
+func TestPullChunksums(t *testing.T) {
+	check := testutil.Checker(t)
+
+	content := "hello"
+	var chunksums string
+	contentDigest := func() blob.Digest {
+		return blob.DigestFromBytes(content)
+	}
+	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+		switch {
+		case strings.Contains(r.URL.Path, "/manifests/latest"):
+			fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
+		case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
+			loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
+			w.Header().Set("Content-Location", loc)
+			io.WriteString(w, chunksums)
+		case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
+			http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
+		default:
+			t.Errorf("unexpected request: %v", r)
+			http.NotFound(w, r)
+		}
+	})
+
+	rc.MaxStreams = 1        // prevent concurrent chunk downloads
+	rc.ChunkingThreshold = 1 // force all blobs to be chunked
+
+	var mu sync.Mutex
+	var reads []int64
+	ctx := WithTrace(t.Context(), &Trace{
+		Update: func(l *Layer, n int64, err error) {
+			t.Logf("Update: %v %d %v", l, n, err)
+			mu.Lock()
+			reads = append(reads, n)
+			mu.Unlock()
+		},
+	})
+
+	chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
+		blob.DigestFromBytes("hel"),
+		blob.DigestFromBytes("lo"),
+	)
+	err := rc.Pull(ctx, "test")
+	check(err)
+	if !slices.Equal(reads, []int64{0, 3, 5}) {
+		t.Errorf("reads = %v; want %v", reads, []int64{0, 3, 5})
+	}
+
+	mw, err := rc.Resolve(t.Context(), "test")
+	check(err)
+	mg, err := rc.ResolveLocal("test")
+	check(err)
+	if !reflect.DeepEqual(mw, mg) {
+		t.Errorf("mw = %v; mg = %v", mw, mg)
+	}
+	for i := range mg.Layers {
+		_, err = c.Get(mg.Layers[i].Digest)
+		if err != nil {
+			t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
+		}
+	}
+
+	// missing chunks
+	content = "llama"
+	chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
+	err = rc.Pull(ctx, "missingchunks")
+	if err == nil {
+		t.Error("expected error because of missing chunks")
+	}
+}
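
Note on the progress-tracking pattern used above: the new download path routes
each chunk body through a trackingReader so bytes count toward progress as
they stream, rather than only on chunk completion. Below is a minimal,
self-contained sketch of that pattern, assuming only the standard library;
the trackingReader here is re-declared for illustration, since the real type
is unexported in this package and is not itself changed by this diff.

	package main

	import (
		"fmt"
		"io"
		"strings"
		"sync/atomic"
	)

	// trackingReader adds every byte count returned by Read to a shared
	// atomic counter, so other goroutines can sample download progress
	// while the body is still streaming.
	type trackingReader struct {
		r io.Reader
		n *atomic.Int64
	}

	func (t *trackingReader) Read(p []byte) (int, error) {
		n, err := t.r.Read(p)
		t.n.Add(int64(n)) // count bytes as they arrive, even on short reads
		return n, err
	}

	func main() {
		var progress atomic.Int64
		body := &trackingReader{r: strings.NewReader("hello"), n: &progress}
		_, _ = io.Copy(io.Discard, body)
		fmt.Println(progress.Load()) // prints 5
	}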
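The size accounting added to chunksums enforces a simple invariant: the chunk
ranges returned by the server must cover the layer exactly, failing fast when
the running sum overshoots and failing at EOF when it falls short (which is
what the "missingchunks" test exercises). A standalone sketch of that
invariant, assuming inclusive start-end byte ranges and omitting the digest
column that the real chunksums response carries:

	package main

	import (
		"bufio"
		"fmt"
		"strings"
	)

	// checkChunks scans whitespace-separated "start-end" ranges and
	// verifies that they sum to exactly layerSize.
	func checkChunks(body string, layerSize int64) error {
		var size int64
		s := bufio.NewScanner(strings.NewReader(body))
		s.Split(bufio.ScanWords)
		for s.Scan() {
			var start, end int64
			if _, err := fmt.Sscanf(s.Text(), "%d-%d", &start, &end); err != nil {
				return err
			}
			size += end - start + 1 // ranges are inclusive
			if size > layerSize {
				return fmt.Errorf("sum of chunks %d exceeds layer size %d", size, layerSize)
			}
		}
		if err := s.Err(); err != nil {
			return err
		}
		if size != layerSize {
			return fmt.Errorf("size mismatch: layer size %d != sum of chunks %d", layerSize, size)
		}
		return nil
	}

	func main() {
		fmt.Println(checkChunks("0-2 3-4", 5)) // <nil>
		fmt.Println(checkChunks("0-1", 5))     // size mismatch: chunks cover only 2 of 5 bytes
	}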