diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index 665defd58..409932bfd 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -421,14 +421,6 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error { return err } -func canRetry(err error) bool { - var re *Error - if !errors.As(err, &re) { - return false - } - return re.Status >= 500 -} - // trackingReader is an io.Reader that tracks the number of bytes read and // calls the update function with the layer, the number of bytes read. // @@ -514,13 +506,40 @@ func (r *Registry) Pull(ctx context.Context, name string) error { break } + cacheKey := fmt.Sprintf( + "v1 pull chunksum %s %s %d-%d", + l.Digest, + cs.Digest, + cs.Chunk.Start, + cs.Chunk.End, + ) + cacheKeyDigest := blob.DigestFromBytes(cacheKey) + _, err := c.Get(cacheKeyDigest) + if err == nil { + received.Add(cs.Chunk.Size()) + t.update(l, cs.Chunk.Size(), ErrCached) + continue + } + wg.Add(1) g.Go(func() (err error) { defer func() { if err == nil { + // Ignore cache key write errors for now. We've already + // reported to trace that the chunk is complete. + // + // Ideally, we should only report completion to trace + // after successful cache commit. This current approach + // works but could trigger unnecessary redownloads if + // the checkpoint key is missing on next pull. + // + // Not incorrect, just suboptimal - fix this in a + // future update. 
+ _ = blob.PutBytes(c, cacheKeyDigest, cacheKey) + received.Add(cs.Chunk.Size()) } else { - err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err) + t.update(l, 0, err) } wg.Done() }() @@ -563,7 +582,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error { return err } if received.Load() != expected { - return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected) + return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, received.Load(), expected) } md := blob.DigestFromBytes(m.Data) @@ -608,6 +627,30 @@ func (m *Manifest) Layer(d blob.Digest) *Layer { return nil } +func (m *Manifest) All() iter.Seq[*Layer] { + return func(yield func(*Layer) bool) { + if !yield(m.Config) { + return + } + for _, l := range m.Layers { + if !yield(l) { + return + } + } + } +} + +func (m *Manifest) Size() int64 { + var size int64 + if m.Config != nil { + size += m.Config.Size + } + for _, l := range m.Layers { + size += l.Size + } + return size +} + // MarshalJSON implements json.Marshaler. // // NOTE: It adds an empty config object to the manifest, which is required by @@ -750,20 +793,32 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se return } - // A chunksums response is a sequence of chunksums in a - // simple, easy to parse line-oriented format. + // The response is a sequence of chunksums. // - // Example: + // Chunksums are chunks of a larger blob that can be + // downloaded and verified independently. // - // >> GET /v2/<namespace>/<model>/chunksums/<digest> + // The chunksums endpoint is a GET request that returns a + // sequence of chunksums in the following format: // - // << HTTP/1.1 200 OK - // << Content-Location: <blobURL> - // << - // << <digest> <start>-<end> - // << ... + // > GET /v2/<namespace>/<model>/chunksums/<digest> // - // The blobURL is the URL to download the chunks from. + // < HTTP/1.1 200 OK + // < Content-Location: <blobURL> + // < + // < <digest> <start>-<end> + // < ...
+ // + // The <blobURL> is the URL to download the chunks from and + // each <digest> is the digest of the chunk, and <start>-<end> + // is the range of the chunk in the blob. + // + // Ranges may be used directly in Range headers like + // "bytes=<start>-<end>". + // + // The chunksums returned are guaranteed to be contiguous and + // include all bytes of the layer. If the stream is cut short, + // clients should retry. chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s", scheme, diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go index f8136c06f..80d39b765 100644 --- a/server/internal/client/ollama/registry_test.go +++ b/server/internal/client/ollama/registry_test.go @@ -9,17 +9,14 @@ import ( "fmt" "io" "io/fs" - "math/rand/v2" + "net" "net/http" "net/http/httptest" "os" - "path" "reflect" - "slices" "strings" - "sync" + "sync/atomic" "testing" - "time" "github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/testutil" @@ -338,15 +335,8 @@ func TestPushCommitRoundtripError(t *testing.T) { } } -func checkNotExist(t *testing.T, err error) { - t.Helper() - if !errors.Is(err, fs.ErrNotExist) { - t.Fatalf("err = %v; want fs.ErrNotExist", err) - } -} - func TestRegistryPullInvalidName(t *testing.T) { - rc, _ := newClient(t, nil) + rc, _ := newRegistryClient(t, nil) err := rc.Pull(t.Context(), "://") if !errors.Is(err, ErrNameInvalid) { t.Errorf("err = %v; want %v", err, ErrNameInvalid) @@ -362,197 +352,16 @@ func TestRegistryPullInvalidManifest(t *testing.T) { } for _, resp := range cases { - rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, resp) }) - err := rc.Pull(t.Context(), "x") + err := rc.Pull(t.Context(), "http://example.com/a/b") if !errors.Is(err, ErrManifestInvalid) { t.Errorf("err = %v; want invalid manifest", err) } } } -func TestRegistryPullNotCached(t *testing.T) { - check :=
testutil.Checker(t) - - var c *blob.DiskCache - var rc *Registry - - d := blob.DigestFromBytes("some data") - rc, c = newClient(t, func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, "/blobs/") { - io.WriteString(w, "some data") - return - } - fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":9}]}`, d) - }) - - // Confirm that the layer does not exist locally - _, err := rc.ResolveLocal("model") - checkNotExist(t, err) - - _, err = c.Get(d) - checkNotExist(t, err) - - err = rc.Pull(t.Context(), "model") - check(err) - - mw, err := rc.Resolve(t.Context(), "model") - check(err) - mg, err := rc.ResolveLocal("model") - check(err) - if !reflect.DeepEqual(mw, mg) { - t.Errorf("mw = %v; mg = %v", mw, mg) - } - - // Confirm successful download - info, err := c.Get(d) - check(err) - if info.Digest != d { - t.Errorf("info.Digest = %v; want %v", info.Digest, d) - } - if info.Size != 9 { - t.Errorf("info.Size = %v; want %v", info.Size, 9) - } - - data, err := os.ReadFile(c.GetFile(d)) - check(err) - if string(data) != "some data" { - t.Errorf("data = %q; want %q", data, "exists") - } -} - -func TestRegistryPullCached(t *testing.T) { - cached := blob.DigestFromBytes("exists") - rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, "/blobs/") { - w.WriteHeader(499) // should not be called - return - } - if strings.Contains(r.URL.Path, "/manifests/") { - fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, cached) - } - }) - - var errs []error - var reads []int64 - ctx := WithTrace(t.Context(), &Trace{ - Update: func(d *Layer, n int64, err error) { - t.Logf("update %v %d %v", d, n, err) - reads = append(reads, n) - errs = append(errs, err) - }, - }) - - ctx, cancel := context.WithTimeout(ctx, 3*time.Second) - defer cancel() - - err := rc.Pull(ctx, "single") - testutil.Check(t, err) - - want := []int64{0, 6} - if !errors.Is(errors.Join(errs...), ErrCached) { - t.Errorf("errs = %v; want %v", errs, 
ErrCached) - } - if !slices.Equal(reads, want) { - t.Errorf("pairs = %v; want %v", reads, want) - } -} - -func TestRegistryPullManifestNotFound(t *testing.T) { - rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) - }) - err := rc.Pull(t.Context(), "notfound") - checkErrCode(t, err, 404, "") -} - -func TestRegistryPullResolveRemoteError(t *testing.T) { - rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - io.WriteString(w, `{"errors":[{"code":"an_error"}]}`) - }) - err := rc.Pull(t.Context(), "single") - checkErrCode(t, err, 500, "an_error") -} - -func TestRegistryPullResolveRoundtripError(t *testing.T) { - rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, "/manifests/") { - w.WriteHeader(499) // force RoundTrip error - return - } - }) - err := rc.Pull(t.Context(), "single") - if !errors.Is(err, errRoundTrip) { - t.Errorf("err = %v; want %v", err, errRoundTrip) - } -} - -// TestRegistryPullMixedCachedNotCached tests that cached layers do not -// interfere with pulling layers that are not cached -func TestRegistryPullMixedCachedNotCached(t *testing.T) { - x := blob.DigestFromBytes("xxxxxx") - e := blob.DigestFromBytes("exists") - y := blob.DigestFromBytes("yyyyyy") - - for i := range 10 { - t.Logf("iteration %d", i) - - digests := []blob.Digest{x, e, y} - - rand.Shuffle(len(digests), func(i, j int) { - digests[i], digests[j] = digests[j], digests[i] - }) - - manifest := fmt.Sprintf(`{ - "layers": [ - {"digest":"%s","size":6}, - {"digest":"%s","size":6}, - {"digest":"%s","size":6} - ] - }`, digests[0], digests[1], digests[2]) - - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { - switch path.Base(r.URL.Path) { - case "latest": - io.WriteString(w, manifest) - case x.String(): - io.WriteString(w, "xxxxxx") - case e.String(): - io.WriteString(w, "exists") - case y.String(): - 
io.WriteString(w, "yyyyyy") - default: - panic(fmt.Sprintf("unexpected request: %v", r)) - } - }) - - ctx := WithTrace(t.Context(), &Trace{ - Update: func(l *Layer, n int64, err error) { - t.Logf("update %v %d %v", l, n, err) - }, - }) - - // Check that we pull all layers that we can. - - err := rc.Pull(ctx, "mixed") - if err != nil { - t.Fatal(err) - } - - for _, d := range digests { - info, err := c.Get(d) - if err != nil { - t.Fatalf("Get(%v): %v", d, err) - } - if info.Size != 6 { - t.Errorf("info.Size = %v; want %v", info.Size, 6) - } - } - } -} - func TestRegistryResolveByDigest(t *testing.T) { check := testutil.Checker(t) @@ -590,26 +399,6 @@ func TestInsecureSkipVerify(t *testing.T) { testutil.Check(t, err) } -func TestCanRetry(t *testing.T) { - cases := []struct { - err error - want bool - }{ - {nil, false}, - {errors.New("x"), false}, - {ErrCached, false}, - {ErrManifestInvalid, false}, - {ErrNameInvalid, false}, - {&Error{Status: 100}, false}, - {&Error{Status: 500}, true}, - } - for _, tt := range cases { - if got := canRetry(tt.err); got != tt.want { - t.Errorf("CanRetry(%v) = %v; want %v", tt.err, got, tt.want) - } - } -} - func TestErrorUnmarshal(t *testing.T) { cases := []struct { name string @@ -761,17 +550,23 @@ func TestParseNameExtended(t *testing.T) { func TestUnlink(t *testing.T) { t.Run("found by name", func(t *testing.T) { - rc, _ := newClient(t, nil) + check := testutil.Checker(t) + + rc, _ := newRegistryClient(t, nil) + // make a blob and link it + d := blob.DigestFromBytes("{}") + err := blob.PutBytes(rc.Cache, d, "{}") + check(err) + err = rc.Cache.Link("registry.ollama.ai/library/single:latest", d) + check(err) // confirm linked - _, err := rc.ResolveLocal("single") - if err != nil { - t.Errorf("unexpected error: %v", err) - } + _, err = rc.ResolveLocal("single") + check(err) // unlink _, err = rc.Unlink("single") - testutil.Check(t, err) + check(err) // confirm unlinked _, err = rc.ResolveLocal("single") @@ -780,7 +575,7 @@ func 
TestUnlink(t *testing.T) { } }) t.Run("not found by name", func(t *testing.T) { - rc, _ := newClient(t, nil) + rc, _ := newRegistryClient(t, nil) ok, err := rc.Unlink("manifestNotFound") if err != nil { t.Fatal(err) @@ -791,78 +586,368 @@ func TestUnlink(t *testing.T) { }) } -func TestPullChunksums(t *testing.T) { - check := testutil.Checker(t) +// Many tests from here out, in this file are based on a single blob, "abc", +// with the checksum of its sha256 hash. The checksum is: +// +// "abc" -> sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad +// +// Using the literal value instead of a constant with fmt.Xprintf calls proved +// to be the most readable and maintainable approach. The sum is consistently +// used in the tests and unique so searches do not yield false positives. - content := "hello" - var chunksums string - contentDigest := func() blob.Digest { - return blob.DigestFromBytes(content) +func checkRequest(t *testing.T, req *http.Request, method, path string) { + t.Helper() + if got := req.URL.Path; got != path { + t.Errorf("URL = %q, want %q", got, path) } - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { - switch { - case strings.Contains(r.URL.Path, "/manifests/latest"): - fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content)) - case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()): - loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest()) - w.Header().Set("Content-Location", loc) - io.WriteString(w, chunksums) - case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()): - http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content)) - default: - t.Errorf("unexpected request: %v", r) - http.NotFound(w, r) - } - }) + if req.Method != method { + t.Errorf("Method = %q, want %q", req.Method, method) + } +} - rc.MaxStreams = 1 // prevent concurrent chunk downloads - rc.ChunkingThreshold = 1 // for 
all blobs to be chunked +func newRegistryClient(t *testing.T, h http.HandlerFunc) (*Registry, context.Context) { + s := httptest.NewServer(h) + t.Cleanup(s.Close) + cache, err := blob.Open(t.TempDir()) + if err != nil { + t.Fatal(err) + } - var mu sync.Mutex - var reads []int64 ctx := WithTrace(t.Context(), &Trace{ Update: func(l *Layer, n int64, err error) { - t.Logf("Update: %v %d %v", l, n, err) - mu.Lock() - reads = append(reads, n) - mu.Unlock() + t.Log("trace:", l.Digest.Short(), n, err) }, }) - chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n", - blob.DigestFromBytes("hel"), - blob.DigestFromBytes("lo"), - ) - err := rc.Pull(ctx, "test") - check(err) - wantReads := []int64{ - 0, // initial signaling of layer pull starting - 3, // first chunk read - 2, // second chunk read - } - if !slices.Equal(reads, wantReads) { - t.Errorf("reads = %v; want %v", reads, wantReads) + rc := &Registry{ + Cache: cache, + HTTPClient: &http.Client{Transport: &http.Transport{ + Dial: func(network, addr string) (net.Conn, error) { + return net.Dial(network, s.Listener.Addr().String()) + }, + }}, } + return rc, ctx +} - mw, err := rc.Resolve(t.Context(), "test") - check(err) - mg, err := rc.ResolveLocal("test") - check(err) - if !reflect.DeepEqual(mw, mg) { - t.Errorf("mw = %v; mg = %v", mw, mg) - } - for i := range mg.Layers { - _, err = c.Get(mg.Layers[i].Digest) - if err != nil { - t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err) +func TestPullChunked(t *testing.T) { + var steps atomic.Int64 + c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { + switch steps.Add(1) { + case 1: + checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest") + io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`) + case 2: + checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + w.Header().Set("Content-Location", 
"http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) + fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c")) + case 3, 4: + checkRequest(t, r, "GET", "/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + switch rng := r.Header.Get("Range"); rng { + case "bytes=0-1": + io.WriteString(w, "ab") + case "bytes=2-2": + t.Logf("writing c") + io.WriteString(w, "c") + default: + t.Errorf("unexpected range %q", rng) + } + default: + t.Errorf("unexpected steps %d: %v", steps.Load(), r) + http.Error(w, "unexpected steps", http.StatusInternalServerError) } - } + }) - // missing chunks - content = "llama" - chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll")) - err = rc.Pull(ctx, "missingchunks") - if err == nil { - t.Error("expected error because of missing chunks") + c.ChunkingThreshold = 1 // force chunking + + err := c.Pull(ctx, "http://o.com/library/abc") + testutil.Check(t, err) + + _, err = c.Cache.Resolve("o.com/library/abc:latest") + testutil.Check(t, err) + + if g := steps.Load(); g != 4 { + t.Fatalf("got %d steps, want 4", g) + } +} + +func TestPullCached(t *testing.T) { + c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { + checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest") + io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`) + }) + + check := testutil.Checker(t) + + // Preemptively cache the blob + d, err := blob.ParseDigest("sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + check(err) + err = blob.PutBytes(c.Cache, d, []byte("abc")) + check(err) + + // Pull only the manifest, which should be enough to resolve the cached blob + err = c.Pull(ctx, "http://o.com/library/abc") + check(err) +} + +func TestPullManifestError(t *testing.T) { + c, ctx :=
newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { + checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest") + w.WriteHeader(http.StatusNotFound) + io.WriteString(w, `{"errors":[{"code":"MANIFEST_UNKNOWN"}]}`) + }) + + err := c.Pull(ctx, "http://o.com/library/abc") + if err == nil { + t.Fatalf("expected error") + } + var got *Error + if !errors.Is(err, ErrModelNotFound) { + t.Fatalf("err = %v, want %v", got, ErrModelNotFound) + } +} + +func TestPullLayerError(t *testing.T) { + c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { + checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest") + io.WriteString(w, `!`) + }) + + err := c.Pull(ctx, "http://o.com/library/abc") + if err == nil { + t.Fatalf("expected error") + } + var want *json.SyntaxError + if !errors.As(err, &want) { + t.Fatalf("err = %T, want %T", err, want) + } +} + +func TestPullLayerChecksumError(t *testing.T) { + var step atomic.Int64 + c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { + switch step.Add(1) { + case 1: + checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest") + io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`) + case 2: + checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) + fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c")) + case 3: + w.WriteHeader(http.StatusNotFound) + io.WriteString(w, `{"errors":[{"code":"BLOB_UNKNOWN"}]}`) + case 4: + io.WriteString(w, "c") + default: + t.Errorf("unexpected steps %d: %v", step.Load(), r) + http.Error(w, "unexpected steps", http.StatusInternalServerError) + } + }) + + c.MaxStreams = 1 + c.ChunkingThreshold = 
1 // force chunking + + var written atomic.Int64 + ctx := WithTrace(t.Context(), &Trace{ + Update: func(l *Layer, n int64, err error) { + t.Log("trace:", l.Digest.Short(), n, err) + written.Add(n) + }, + }) + + err := c.Pull(ctx, "http://o.com/library/abc") + var got *Error + if !errors.As(err, &got) || got.Code != "BLOB_UNKNOWN" { + t.Fatalf("err = %v, want %v", err, got) + } + + if g := written.Load(); g != 1 { + t.Fatalf("wrote %d bytes, want 1", g) + } +} + +func TestPullChunksumStreamError(t *testing.T) { + var step atomic.Int64 + c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { + switch step.Add(1) { + case 1: + checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest") + io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`) + case 2: + w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + + // Write one valid chunksum and one invalid chunksum + fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) // valid + fmt.Fprint(w, "sha256:!") // invalid + case 3: + io.WriteString(w, "ab") + default: + t.Errorf("unexpected steps %d: %v", step.Load(), r) + http.Error(w, "unexpected steps", http.StatusInternalServerError) + } + }) + + c.ChunkingThreshold = 1 // force chunking + + got := c.Pull(ctx, "http://o.com/library/abc") + if !errors.Is(got, ErrIncomplete) { + t.Fatalf("err = %v, want %v", got, ErrIncomplete) + } +} + +type flushAfterWriter struct { + w io.Writer +} + +func (f *flushAfterWriter) Write(p []byte) (n int, err error) { + n, err = f.w.Write(p) + f.w.(http.Flusher).Flush() // panic if not a flusher + return +} + +func TestPullChunksumStreaming(t *testing.T) { + csr, csw := io.Pipe() + defer csw.Close() + + var step atomic.Int64 + c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { + switch step.Add(1) { + case 1: + checkRequest(t, 
r, "GET", "/v2/library/abc/manifests/latest") + io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`) + case 2: + w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + fw := &flushAfterWriter{w} // ensure client gets data as it arrives by aggressively flushing + _, err := io.Copy(fw, csr) + if err != nil { + t.Errorf("copy: %v", err) + } + case 3: + io.WriteString(w, "ab") + case 4: + io.WriteString(w, "c") + default: + t.Errorf("unexpected steps %d: %v", step.Load(), r) + http.Error(w, "unexpected steps", http.StatusInternalServerError) + } + }) + + c.ChunkingThreshold = 1 // force chunking + + update := make(chan int64, 1) + ctx := WithTrace(t.Context(), &Trace{ + Update: func(l *Layer, n int64, err error) { + t.Log("trace:", l.Digest.Short(), n, err) + if n > 0 { + update <- n + } + }, + }) + + errc := make(chan error, 1) + go func() { + errc <- c.Pull(ctx, "http://o.com/library/abc") + }() + + // Send first chunksum and ensure it kicks off work immediately + fmt.Fprintf(csw, "%s 0-1\n", blob.DigestFromBytes("ab")) + if g := <-update; g != 2 { + t.Fatalf("got %d, want 2", g) + } + + // now send the second chunksum and ensure it kicks off work immediately + fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c")) + if g := <-update; g != 1 { + t.Fatalf("got %d, want 1", g) + } + csw.Close() + testutil.Check(t, <-errc) +} + +func TestPullChunksumsCached(t *testing.T) { + var step atomic.Int64 + c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { + switch step.Add(1) { + case 1: + checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest") + io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`) + case 2: + w.Header().Set("Content-Location", 
"http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) + fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c")) + case 3, 4: + switch rng := r.Header.Get("Range"); rng { + case "bytes=0-1": + io.WriteString(w, "ab") + case "bytes=2-2": + io.WriteString(w, "c") + default: + t.Errorf("unexpected range %q", rng) + } + default: + t.Errorf("unexpected steps %d: %v", step.Load(), r) + http.Error(w, "unexpected steps", http.StatusInternalServerError) + } + }) + + c.MaxStreams = 1 // force serial processing of chunksums + c.ChunkingThreshold = 1 // force chunking + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + // Cancel the pull after the first chunksum is processed, but before + // the second chunksum is processed (which is waiting because + // MaxStreams=1). This should cause the second chunksum to error out + // leaving the blob incomplete. + ctx = WithTrace(ctx, &Trace{ + Update: func(l *Layer, n int64, err error) { + if n > 0 { + cancel() + } + }, + }) + err := c.Pull(ctx, "http://o.com/library/abc") + if !errors.Is(err, context.Canceled) { + t.Fatalf("err = %v, want %v", err, context.Canceled) + } + + _, err = c.Cache.Resolve("o.com/library/abc:latest") + if !errors.Is(err, fs.ErrNotExist) { + t.Fatalf("err = %v, want nil", err) + } + + // Reset state and pull again to ensure the blob chunks that should + // have been cached are, and the remaining chunk was downloaded, making + // the blob complete. 
+ step.Store(0) + var written atomic.Int64 + var cached atomic.Int64 + ctx = WithTrace(t.Context(), &Trace{ + Update: func(l *Layer, n int64, err error) { + t.Log("trace:", l.Digest.Short(), n, err) + if errors.Is(err, ErrCached) { + cached.Add(n) + } + written.Add(n) + }, + }) + + check := testutil.Checker(t) + + err = c.Pull(ctx, "http://o.com/library/abc") + check(err) + + _, err = c.Cache.Resolve("o.com/library/abc:latest") + check(err) + + if g := written.Load(); g != 3 { + t.Fatalf("wrote %d bytes, want 3", g) + } + if g := cached.Load(); g != 2 { // "ab" should have been cached + t.Fatalf("cached %d bytes, want 3", g) } }