x/registry: fixing tests wip

This commit is contained in:
Blake Mizerany 2024-04-03 16:37:27 -07:00
parent 005b6373e2
commit f7cfe946dc
3 changed files with 68 additions and 77 deletions

View File

@ -4,9 +4,11 @@ import (
"cmp" "cmp"
"context" "context"
"encoding/xml" "encoding/xml"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"strings"
"bllamo.com/client/ollama" "bllamo.com/client/ollama"
"bllamo.com/registry/apitype" "bllamo.com/registry/apitype"
@ -40,23 +42,42 @@ func (c *Client) Push(ctx context.Context, ref string, manifest []byte, p *PushP
return v.Requirements, nil return v.Requirements, nil
} }
func PushLayer(ctx context.Context, dstURL string, off, size int64, file io.ReaderAt) (etag string, err error) { func PushLayer(ctx context.Context, body io.ReaderAt, url string, off, n int64) (apitype.CompletePart, error) {
sr := io.NewSectionReader(file, off, size) var zero apitype.CompletePart
req, err := http.NewRequestWithContext(ctx, "PUT", dstURL, sr) if off < 0 {
if err != nil { return zero, errors.New("off must be >0")
return "", err }
file := io.NewSectionReader(body, off, n)
req, err := http.NewRequest("PUT", url, file)
if err != nil {
return zero, err
}
req.ContentLength = n
// TODO(bmizerany): take content type param
req.Header.Set("Content-Type", "text/plain")
if n >= 0 {
req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", off, off+n-1))
} }
req.ContentLength = size
res, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return "", err return zero, err
} }
defer res.Body.Close() defer res.Body.Close()
if res.StatusCode != 200 { if res.StatusCode != 200 {
return "", parseS3Error(res) e := parseS3Error(res)
return zero, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e)
} }
return res.Header.Get("ETag"), nil etag := strings.Trim(res.Header.Get("ETag"), `"`)
cp := apitype.CompletePart{
URL: url,
ETag: etag,
// TODO(bmizerany): checksum
}
return cp, nil
} }
type s3Error struct { type s3Error struct {

View File

@ -6,7 +6,6 @@ import (
"cmp" "cmp"
"context" "context"
"errors" "errors"
"fmt"
"log" "log"
"net/http" "net/http"
"net/url" "net/url"
@ -131,7 +130,6 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
PartNumber: partNumber, PartNumber: partNumber,
ETag: etag, ETag: etag,
}) })
fmt.Println("uploadID", uploadID, "partNumber", partNumber, "etag", etag)
completePartsByUploadID[uploadID] = cp completePartsByUploadID[uploadID] = cp
} }

View File

@ -11,13 +11,11 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"os" "os"
"os/exec" "os/exec"
"strconv" "strconv"
"strings"
"syscall" "syscall"
"testing" "testing"
"time" "time"
@ -30,8 +28,6 @@ import (
"kr.dev/diff" "kr.dev/diff"
) )
const abc = "abcdefghijklmnopqrstuvwxyz"
func testPush(t *testing.T, chunkSize int64) { func testPush(t *testing.T, chunkSize int64) {
t.Run(fmt.Sprintf("chunkSize=%d", chunkSize), func(t *testing.T) { t.Run(fmt.Sprintf("chunkSize=%d", chunkSize), func(t *testing.T) {
mc := startMinio(t, true) mc := startMinio(t, true)
@ -71,15 +67,11 @@ func testPush(t *testing.T, chunkSize int64) {
for i, r := range requirements { for i, r := range requirements {
t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size) t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size)
body := strings.NewReader(abc) cp, err := PushLayer(context.Background(), &abcReader{}, r.URL, r.Offset, r.Size)
etag, err := PushLayer(context.Background(), r.URL, r.Offset, r.Size, body)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
uploaded = append(uploaded, apitype.CompletePart{ uploaded = append(uploaded, cp)
URL: r.URL,
ETag: etag,
})
} }
requirements, err = c.Push(context.Background(), ref, manifest, &PushParams{ requirements, err = c.Push(context.Background(), ref, manifest, &PushParams{
@ -142,15 +134,8 @@ func testPush(t *testing.T, chunkSize int64) {
} }
t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size) t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size)
data, err := io.ReadAll(obj) if msg := checkABCs(obj, int(l.Size)); msg != "" {
if err != nil { t.Errorf("[%d] %s", i, msg)
t.Fatal(err)
}
got := string(data)
want := abc[:l.Size]
if got != want {
t.Errorf("[%d] got layer data = %q; want %q", i, got, want)
} }
} }
}) })
@ -161,44 +146,6 @@ func TestPush(t *testing.T) {
testPush(t, 1) testPush(t, 1)
} }
func pushLayer(body io.ReaderAt, url string, off, n int64) (apitype.CompletePart, error) {
var zero apitype.CompletePart
if off < 0 {
return zero, errors.New("off must be >0")
}
file := io.NewSectionReader(body, off, n)
req, err := http.NewRequest("PUT", url, file)
if err != nil {
return zero, err
}
req.ContentLength = n
// TODO(bmizerany): take content type param
req.Header.Set("Content-Type", "text/plain")
if n >= 0 {
req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", off, off+n-1))
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return zero, err
}
defer res.Body.Close()
if res.StatusCode != 200 {
e := parseS3Error(res)
return zero, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e)
}
etag := strings.Trim(res.Header.Get("ETag"), `"`)
cp := apitype.CompletePart{
URL: url,
ETag: etag,
// TODO(bmizerany): checksum
}
return cp, nil
}
// TestBasicPresignS3MultipartReferenceDoNotDelete tests the basic flow of // TestBasicPresignS3MultipartReferenceDoNotDelete tests the basic flow of
// presigning a multipart upload, uploading the parts, and completing the // presigning a multipart upload, uploading the parts, and completing the
// upload. It is for future reference and should not be deleted. This flow // upload. It is for future reference and should not be deleted. This flow
@ -230,7 +177,7 @@ func TestBasicPresignS3MultipartReferenceDoNotDelete(t *testing.T) {
t.Logf("[partNumber=%d]: %v", partNumber, u) t.Logf("[partNumber=%d]: %v", partNumber, u)
var body abcReader var body abcReader
cp, err := pushLayer(&body, u.String(), c.Offset, c.N) cp, err := PushLayer(context.Background(), &body, u.String(), c.Offset, c.N)
if err != nil { if err != nil {
t.Fatalf("[partNumber=%d]: %v", partNumber, err) t.Fatalf("[partNumber=%d]: %v", partNumber, err)
} }
@ -306,7 +253,7 @@ func startMinio(t *testing.T, trace bool) *minio.Client {
// explicitly setting trace to true. // explicitly setting trace to true.
trace = cmp.Or(trace, os.Getenv("OLLAMA_MINIO_TRACE") != "") trace = cmp.Or(trace, os.Getenv("OLLAMA_MINIO_TRACE") != "")
dir := t.TempDir() + "-keep" // prevent tempdir from auto delete dir := t.TempDir()
t.Cleanup(func() { t.Cleanup(func() {
// TODO(bmizerany): trim temp dir based on dates so that // TODO(bmizerany): trim temp dir based on dates so that
@ -317,19 +264,18 @@ func startMinio(t *testing.T, trace bool) *minio.Client {
if err := cmd.Wait(); err != nil { if err := cmd.Wait(); err != nil {
var e *exec.ExitError var e *exec.ExitError
if errors.As(err, &e) { if errors.As(err, &e) {
if !e.Exited() { if e.Exited() {
// died due to our signal
return return
} }
t.Errorf("startMinio: %s stderr: %s", cmd.Path, e.Stderr) t.Logf("startMinio: %s stderr: %s", cmd.Path, e.Stderr)
t.Errorf("startMinio: %s exit status: %v", cmd.Path, e.ExitCode()) t.Logf("startMinio: %s exit status: %v", cmd.Path, e.ExitCode())
t.Errorf("startMinio: %s exited: %v", cmd.Path, e.Exited()) t.Logf("startMinio: %s exited: %v", cmd.Path, e.Exited())
t.Errorf("startMinio: %s stderr: %s", cmd.Path, e.Stderr) t.Logf("startMinio: %s stderr: %s", cmd.Path, e.Stderr)
} else { } else {
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
return return
} }
t.Errorf("startMinio: %s exit error: %v", cmd.Path, err) t.Logf("startMinio: %s exit error: %v", cmd.Path, err)
} }
} }
} }
@ -343,6 +289,7 @@ func startMinio(t *testing.T, trace bool) *minio.Client {
} }
t.Logf(">> minio: minio server %s", dir) t.Logf(">> minio: minio server %s", dir)
addr := availableAddr() addr := availableAddr()
cmd := exec.CommandContext(ctx, "minio", "server", "--address", addr, dir) cmd := exec.CommandContext(ctx, "minio", "server", "--address", addr, dir)
cmd.Env = os.Environ() cmd.Env = os.Environ()
@ -463,3 +410,28 @@ func (r *abcReader) ReadAt(p []byte, off int64) (n int, err error) {
} }
return len(p), nil return len(p), nil
} }
func checkABCs(r io.Reader, size int) (reason string) {
h := sha256.New()
n, err := io.CopyN(h, &abcReader{}, int64(size))
if err != nil {
return err.Error()
}
if n != int64(size) {
panic("short read; should not happen")
}
want := h.Sum(nil)
h = sha256.New()
n, err = io.Copy(h, r)
if err != nil {
return err.Error()
}
if n != int64(size) {
return fmt.Sprintf("got len(r) = %d; want %d", n, size)
}
got := h.Sum(nil)
if !bytes.Equal(got, want) {
return fmt.Sprintf("got sum = %x; want %x", got, want)
}
return ""
}