diff --git a/x/registry/server.go b/x/registry/server.go
index 91659767c..b8dcafca2 100644
--- a/x/registry/server.go
+++ b/x/registry/server.go
@@ -76,6 +76,7 @@ func (s *Server) uploadChunkSize() int64 {
 
 func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
 	const bucketTODO = "test"
+	const minimumMultipartSize = 5 * 1024 * 1024 // S3's minimum size for a non-final part
 
 	pr, err := oweb.DecodeUserJSON[apitype.PushRequest]("", r.Body)
 	if err != nil {
@@ -156,28 +157,43 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
 		}
 		if !pushed {
 			key := path.Join("blobs", l.Digest)
-			uploadID, err := mcc.NewMultipartUpload(r.Context(), bucketTODO, key, minio.PutObjectOptions{})
-			if err != nil {
-				return err
-			}
-			for partNumber, c := range upload.Chunks(l.Size, s.uploadChunkSize()) {
-				const timeToStartUpload = 15 * time.Minute
-
-				signedURL, err := s.mc().Presign(r.Context(), "PUT", bucketTODO, key, timeToStartUpload, url.Values{
-					"UploadId":      []string{uploadID},
-					"PartNumber":    []string{strconv.Itoa(partNumber)},
-					"ContentLength": []string{strconv.FormatInt(c.Size, 10)},
-				})
+			if l.Size < minimumMultipartSize {
+				// Small blob: a single presigned PUT is enough.
+				signedURL, err := s.mc().PresignedPutObject(r.Context(), bucketTODO, key, 15*time.Minute)
 				if err != nil {
 					return err
 				}
-
 				requirements = append(requirements, apitype.Requirement{
 					Digest: l.Digest,
-					Offset: c.Offset,
-					Size:   c.Size,
+					Offset: 0,
+					Size:   l.Size,
 					URL:    signedURL.String(),
 				})
+			} else {
+				uploadID, err := mcc.NewMultipartUpload(r.Context(), bucketTODO, key, minio.PutObjectOptions{})
+				if err != nil {
+					return err
+				}
+				for partNumber, c := range upload.Chunks(l.Size, s.uploadChunkSize()) {
+					// Chunks yields 1-based part numbers, as S3 requires.
+					const timeToStartUpload = 15 * time.Minute
+
+					signedURL, err := s.mc().Presign(r.Context(), "PUT", bucketTODO, key, timeToStartUpload, url.Values{
+						"UploadId":      []string{uploadID},
+						"PartNumber":    []string{strconv.Itoa(partNumber)},
+						"ContentLength": []string{strconv.FormatInt(c.N, 10)},
+					})
+					if err != nil {
+						return err
+					}
+
+					requirements = append(requirements, apitype.Requirement{
+						Digest: l.Digest,
+						Offset: c.Offset,
+						Size:   c.N,
+						URL:    signedURL.String(),
+					})
+				}
 			}
 		}
 	}
diff --git a/x/registry/server_test.go b/x/registry/server_test.go
index 8cb1ecc18..1952feb9a 100644
--- a/x/registry/server_test.go
+++ b/x/registry/server_test.go
@@ -1,22 +1,27 @@
 package registry
 
 import (
-	"bufio"
+	"bytes"
 	"context"
+	"crypto/sha256"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
 	"net"
+	"net/http"
 	"net/http/httptest"
+	"net/url"
 	"os"
 	"os/exec"
+	"strconv"
 	"strings"
 	"testing"
 	"time"
 
 	"bllamo.com/registry/apitype"
 	"bllamo.com/utils/backoff"
+	"bllamo.com/utils/upload"
 	"github.com/minio/minio-go/v7"
 	"github.com/minio/minio-go/v7/pkg/credentials"
 	"kr.dev/diff"
@@ -149,6 +154,135 @@ func TestPush(t *testing.T) {
 	testPush(t, 1)
 }
 
+func pushLayer(body io.ReaderAt, url string, off, n int64) (apitype.CompletePart, error) {
+	var zero apitype.CompletePart
+	if off < 0 {
+		return zero, errors.New("off must be >= 0")
+	}
+
+	file := io.NewSectionReader(body, off, n)
+	req, err := http.NewRequest("PUT", url, file)
+	if err != nil {
+		return zero, err
+	}
+	req.ContentLength = n
+
+	// TODO(bmizerany): take content type param
+	req.Header.Set("Content-Type", "text/plain")
+
+	if n >= 0 {
+		req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", off, off+n-1))
+	}
+
+	res, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return zero, err
+	}
+	defer res.Body.Close()
+	if res.StatusCode != 200 {
+		e := parseS3Error(res)
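+		// Wrap the decoded S3 error so the caller sees why the part
+		// upload failed, not just the status code.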
+		return zero, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e)
+	}
+	etag := strings.Trim(res.Header.Get("ETag"), `"`)
+	cp := apitype.CompletePart{
+		URL:  url,
+		ETag: etag,
+		// TODO(bmizerany): checksum
+	}
+	return cp, nil
+}
+
+// TestBasicPresignS3MultipartReferenceDoNotDelete tests the basic flow of
+// presigning a multipart upload, uploading the parts, and completing the
+// upload. It is for future reference and should not be deleted. This flow
+// is tricky, and if we get it wrong in our server we can refer back to this
+// as a "back to basics" test/reference.
+func TestBasicPresignS3MultipartReferenceDoNotDelete(t *testing.T) {
+	mc := startMinio(t, false)
+	mcc := &minio.Core{Client: mc}
+
+	uploadID, err := mcc.NewMultipartUpload(context.Background(), "test", "theKey", minio.PutObjectOptions{})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var completed []minio.CompletePart
+	const size int64 = 10 * 1024 * 1024
+	const chunkSize = 5 * 1024 * 1024
+
+	for partNumber, c := range upload.Chunks(size, chunkSize) {
+		// NOTE: S3 is case-sensitive about these query parameters;
+		// it expects exactly "partNumber" and "uploadId".
+		u, err := mcc.Presign(context.Background(), "PUT", "test", "theKey", 15*time.Minute, url.Values{
+			"partNumber": {strconv.Itoa(partNumber)},
+			"uploadId":   {uploadID},
+		})
+		if err != nil {
+			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
+		}
+		t.Logf("[partNumber=%d]: %v", partNumber, u)
+
+		var body abcReader
+		cp, err := pushLayer(&body, u.String(), c.Offset, c.N)
+		if err != nil {
+			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
+		}
+		t.Logf("completed part: %v", cp)
+
+		// Behave like the server here: don't cheat and reuse partNumber;
+		// instead, recover it from the presigned URL's query string.
+		retPartNumber, err := strconv.Atoi(u.Query().Get("partNumber"))
+		if err != nil {
+			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
+		}
+
+		completed = append(completed, minio.CompletePart{
+			PartNumber: retPartNumber,
+			ETag:       cp.ETag,
+		})
+	}
+
+	defer func() {
+		// fail if there are any incomplete uploads
+		for x := range mcc.ListIncompleteUploads(context.Background(), "test", "theKey", true) {
+			t.Errorf("incomplete: %v", x)
+		}
+	}()
+
+	info, err := mcc.CompleteMultipartUpload(context.Background(), "test", "theKey", uploadID, completed, minio.PutObjectOptions{})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	t.Logf("completed: %v", info)
+
+	// Verify the object assembled in the bucket matches what we uploaded.
+	obj, err := mc.GetObject(context.Background(), "test", "theKey", minio.GetObjectOptions{})
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer obj.Close()
+
+	h := sha256.New()
+	if _, err := io.Copy(h, obj); err != nil {
+		t.Fatal(err)
+	}
+	gotSum := h.Sum(nil)
+
+	h.Reset()
+	var body abcReader
+	if _, err := io.CopyN(h, &body, size); err != nil {
+		t.Fatal(err)
+	}
+	wantSum := h.Sum(nil)
+
+	if !bytes.Equal(gotSum, wantSum) {
+		t.Errorf("got sum = %x; want %x", gotSum, wantSum)
+	}
+}
+
 func availableAddr() string {
 	l, err := net.Listen("tcp", "localhost:0")
 	if err != nil {
@@ -161,30 +295,20 @@ func availableAddr() string {
 func startMinio(t *testing.T, debug bool) *minio.Client {
 	t.Helper()
 
-	dir := t.TempDir()
-	t.Logf(">> minio data dir: %s", dir)
+	dir := t.TempDir() + "-keep" // prevent the temp dir from being auto-deleted
+
+	t.Cleanup(func() {
+		// TODO(bmizerany): trim temp dirs based on dates so that
+		// future runs may be able to inspect results for some time.
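+		// For now the "-keep" suffix above only opts the data dir out
+		// of t.TempDir's automatic removal so it can be inspected.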
+	})
+
+	t.Logf(">> minio: minio server %s", dir)
 	addr := availableAddr()
 	cmd := exec.Command("minio", "server", "--address", addr, dir)
 	cmd.Env = os.Environ()
-	if debug {
-		stdout, err := cmd.StdoutPipe()
-		if err != nil {
-			t.Fatal(err)
-		}
-		doneLogging := make(chan struct{})
-		t.Cleanup(func() {
-			<-doneLogging
-		})
-		go func() {
-			defer close(doneLogging)
-			sc := bufio.NewScanner(stdout)
-			for sc.Scan() {
-				t.Logf("minio: %s", sc.Text())
-			}
-		}()
-	}
-
 	// TODO(bmizerany): wait delay etc...
 	if err := cmd.Start(); err != nil {
 		t.Fatal(err)
@@ -227,6 +351,14 @@ func startMinio(t *testing.T, debug bool) *minio.Client {
 	}
 
+	if debug {
+		// I was using mc.TraceOn here, but it wasn't giving any
+		// meaningful output. I really want the server logs, not
+		// client HTTP logs, and we have places where we cannot or
+		// do not want to use a minio client.
+		panic("TODO")
+	}
+
 	if err := mc.MakeBucket(context.Background(), "test", minio.MakeBucketOptions{}); err != nil {
 		t.Fatal(err)
 	}
 
@@ -251,3 +383,28 @@ func contextForTest(t *testing.T) (_ context.Context, doneLogging func()) {
 	})
 	return ctx, func() { close(done) }
 }
+
+// abcReader repeats the lowercase alphabet (theABCs) infinitely.
+type abcReader struct {
+	pos int
+}
+
+const theABCs = "abcdefghijklmnopqrstuvwxyz"
+
+func (r *abcReader) Read(p []byte) (n int, err error) {
+	for i := range p {
+		p[i] = theABCs[r.pos]
+		r.pos++
+		if r.pos == len(theABCs) {
+			r.pos = 0
+		}
+	}
+	return len(p), nil
+}
+
+func (r *abcReader) ReadAt(p []byte, off int64) (n int, err error) {
+	for i := range p {
+		p[i] = theABCs[(off+int64(i))%int64(len(theABCs))]
+	}
+	return len(p), nil
+}
diff --git a/x/utils/upload/upload.go b/x/utils/upload/upload.go
index c7447b54d..d70833c7b 100644
--- a/x/utils/upload/upload.go
+++ b/x/utils/upload/upload.go
@@ -8,7 +8,7 @@ import (
 
 type Chunk[I constraints.Integer] struct {
 	Offset I
-	Size   I
+	N      I
 }
 
 // Chunks yields a sequence of a part number and a Chunk. The Chunk is the offset
@@ -21,7 +21,9 @@ func Chunks[I constraints.Integer](size, chunkSize I) iter.Seq2[int, Chunk[I]] {
 		var n int
 		for off := I(0); off < size; off += chunkSize {
 			n++
-			yield(n, Chunk[I]{off, min(chunkSize, size-off)})
+			if !yield(n, Chunk[I]{off, min(chunkSize, size-off)}) {
+				return
+			}
 		}
 	}
 }
diff --git a/x/utils/upload/upload_test.go b/x/utils/upload/upload_test.go
index 44ad7f211..09e659111 100644
--- a/x/utils/upload/upload_test.go
+++ b/x/utils/upload/upload_test.go
@@ -35,3 +35,26 @@ func TestChunks(t *testing.T) {
 
 	diff.Test(t, t.Errorf, got, want)
 }
+
+// TestChunksBreak ensures Chunks stops yielding when the consumer exits the loop early.
+func TestChunksBreak(t *testing.T) {
+	for range Chunks(1, 1) {
+		return
+	}
+	t.Fatal("expected break")
+}
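+
+// TestChunksPartNumberStartsAtOne documents the contract the registry
+// server relies on when presigning part uploads: part numbers from
+// Chunks are 1-based, matching S3's PartNumber. (A small pin-down
+// test; the values 10 and 4 are arbitrary.)
+func TestChunksPartNumberStartsAtOne(t *testing.T) {
+	first := -1
+	for partNumber := range Chunks(10, 4) {
+		first = partNumber
+		break
+	}
+	if first != 1 {
+		t.Fatalf("first part number = %d; want 1", first)
+	}
+}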