This commit is contained in:
Blake Mizerany 2024-04-03 20:52:27 -07:00
parent f7cfe946dc
commit 76a202c04e
5 changed files with 99 additions and 97 deletions

View File

@ -99,7 +99,7 @@ func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error {
// commit the manifest to the registry // commit the manifest to the registry
requirements, err = c.Push(r.Context(), params.Name, man, &registry.PushParams{ requirements, err = c.Push(r.Context(), params.Name, man, &registry.PushParams{
Uploaded: uploads, CompleteParts: uploads,
}) })
if err != nil { if err != nil {
return err return err

View File

@ -23,7 +23,7 @@ type PushRequest struct {
// Parts is a list of upload parts that the client upload in the previous // Parts is a list of upload parts that the client upload in the previous
// push. // push.
Uploaded []CompletePart `json:"part_uploads"` CompleteParts []CompletePart `json:"part_uploads"`
} }
type Requirement struct { type Requirement struct {

View File

@ -24,7 +24,7 @@ func (c *Client) oclient() *ollama.Client {
} }
type PushParams struct { type PushParams struct {
Uploaded []apitype.CompletePart CompleteParts []apitype.CompletePart
} }
// Push pushes a manifest to the server. // Push pushes a manifest to the server.
@ -32,9 +32,9 @@ func (c *Client) Push(ctx context.Context, ref string, manifest []byte, p *PushP
p = cmp.Or(p, &PushParams{}) p = cmp.Or(p, &PushParams{})
// TODO(bmizerany): backoff // TODO(bmizerany): backoff
v, err := ollama.Do[apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{ v, err := ollama.Do[apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{
Ref: ref, Ref: ref,
Manifest: manifest, Manifest: manifest,
Uploaded: p.Uploaded, CompleteParts: p.CompleteParts,
}) })
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -101,9 +101,9 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
} }
completePartsByUploadID := make(map[string]completeParts) completePartsByUploadID := make(map[string]completeParts)
for _, pu := range pr.Uploaded { for _, mcp := range pr.CompleteParts {
// parse the URL // parse the URL
u, err := url.Parse(pu.URL) u, err := url.Parse(mcp.URL)
if err != nil { if err != nil {
return err return err
} }
@ -117,8 +117,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
if err != nil { if err != nil {
return oweb.Mistake("invalid", "url", "invalid or missing PartNumber") return oweb.Mistake("invalid", "url", "invalid or missing PartNumber")
} }
etag := pu.ETag if mcp.ETag == "" {
if etag == "" {
return oweb.Mistake("invalid", "etag", "missing") return oweb.Mistake("invalid", "etag", "missing")
} }
cp, ok := completePartsByUploadID[uploadID] cp, ok := completePartsByUploadID[uploadID]
@ -128,7 +127,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
} }
cp.parts = append(cp.parts, minio.CompletePart{ cp.parts = append(cp.parts, minio.CompletePart{
PartNumber: partNumber, PartNumber: partNumber,
ETag: etag, ETag: mcp.ETag,
}) })
completePartsByUploadID[uploadID] = cp completePartsByUploadID[uploadID] = cp
} }

View File

@ -28,15 +28,22 @@ import (
"kr.dev/diff" "kr.dev/diff"
) )
func testPush(t *testing.T, chunkSize int64) { func TestPushBasic(t *testing.T) {
t.Run(fmt.Sprintf("chunkSize=%d", chunkSize), func(t *testing.T) { const MB = 1024 * 1024
mc := startMinio(t, true)
const MB = 1024 * 1024 mc := startMinio(t, true)
// Upload two small layers and one large layer that will defer func() {
// trigger a multipart upload. mcc := &minio.Core{Client: mc}
manifest := []byte(`{ // fail if there are any incomplete uploads
for x := range mcc.ListIncompleteUploads(context.Background(), "test", "theKey", true) {
t.Errorf("incomplete: %v", x)
}
}()
// Upload two small layers and one large layer that will
// trigger a multipart upload.
manifest := []byte(`{
"layers": [ "layers": [
{"digest": "sha256-1", "size": 1}, {"digest": "sha256-1", "size": 1},
{"digest": "sha256-2", "size": 2}, {"digest": "sha256-2", "size": 2},
@ -44,106 +51,100 @@ func testPush(t *testing.T, chunkSize int64) {
] ]
}`) }`)
const ref = "registry.ollama.ai/x/y:latest+Z" const ref = "registry.ollama.ai/x/y:latest+Z"
hs := httptest.NewServer(&Server{ hs := httptest.NewServer(&Server{
minioClient: mc, minioClient: mc,
UploadChunkSize: 5 * MB, UploadChunkSize: 5 * MB,
}) })
t.Cleanup(hs.Close) t.Cleanup(hs.Close)
c := &Client{BaseURL: hs.URL} c := &Client{BaseURL: hs.URL}
requirements, err := c.Push(context.Background(), ref, manifest, nil) requirements, err := c.Push(context.Background(), ref, manifest, nil)
if err != nil {
t.Fatal(err)
}
if len(requirements) < 3 {
t.Errorf("expected at least 3 requirements; got %d", len(requirements))
t.Logf("requirements: %v", requirements)
}
var uploaded []apitype.CompletePart
for i, r := range requirements {
t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size)
cp, err := PushLayer(context.Background(), &abcReader{}, r.URL, r.Offset, r.Size)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
uploaded = append(uploaded, cp)
}
if len(requirements) < 3 { requirements, err = c.Push(context.Background(), ref, manifest, &PushParams{
t.Fatalf("expected at least 3 requirements; got %d", len(requirements)) CompleteParts: uploaded,
t.Logf("requirements: %v", requirements) })
} if err != nil {
t.Fatal(err)
}
if len(requirements) != 0 {
t.Errorf("unexpected requirements: %v", requirements)
}
var uploaded []apitype.CompletePart var paths []string
for i, r := range requirements { keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{
t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size) Recursive: true,
})
for k := range keys {
paths = append(paths, k.Key)
}
cp, err := PushLayer(context.Background(), &abcReader{}, r.URL, r.Offset, r.Size) t.Logf("paths: %v", paths)
if err != nil {
t.Fatal(err)
}
uploaded = append(uploaded, cp)
}
requirements, err = c.Push(context.Background(), ref, manifest, &PushParams{ diff.Test(t, t.Errorf, paths, []string{
Uploaded: uploaded, "blobs/sha256-1",
}) "blobs/sha256-2",
if err != nil { "blobs/sha256-3",
t.Fatal(err) "manifests/registry.ollama.ai/x/y/latest/Z",
} })
if len(requirements) != 0 {
t.Fatalf("unexpected requirements: %v", requirements)
}
var paths []string obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{})
keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{ if err != nil {
Recursive: true, t.Fatal(err)
}) }
for k := range keys { defer obj.Close()
paths = append(paths, k.Key)
}
t.Logf("paths: %v", paths) var gotM apitype.Manifest
if err := json.NewDecoder(obj).Decode(&gotM); err != nil {
t.Fatal(err)
}
diff.Test(t, t.Errorf, paths, []string{ diff.Test(t, t.Errorf, gotM, apitype.Manifest{
"blobs/sha256-1", Layers: []apitype.Layer{
"blobs/sha256-2", {Digest: "sha256-1", Size: 1},
"blobs/sha256-3", {Digest: "sha256-2", Size: 2},
"manifests/registry.ollama.ai/x/y/latest/Z", {Digest: "sha256-3", Size: 3},
}) },
})
obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{}) // checksum the blobs
for i, l := range gotM.Layers {
obj, err := mc.GetObject(context.Background(), "test", "blobs/"+l.Digest, minio.GetObjectOptions{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer obj.Close() defer obj.Close()
var gotM apitype.Manifest info, err := obj.Stat()
if err := json.NewDecoder(obj).Decode(&gotM); err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size)
diff.Test(t, t.Errorf, gotM, apitype.Manifest{ if msg := checkABCs(obj, int(l.Size)); msg != "" {
Layers: []apitype.Layer{ t.Errorf("[%d] %s", i, msg)
{Digest: "sha256-1", Size: 1},
{Digest: "sha256-2", Size: 2},
{Digest: "sha256-3", Size: 3},
},
})
// checksum the blobs
for i, l := range gotM.Layers {
obj, err := mc.GetObject(context.Background(), "test", "blobs/"+l.Digest, minio.GetObjectOptions{})
if err != nil {
t.Fatal(err)
}
defer obj.Close()
info, err := obj.Stat()
if err != nil {
t.Fatal(err)
}
t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size)
if msg := checkABCs(obj, int(l.Size)); msg != "" {
t.Errorf("[%d] %s", i, msg)
}
} }
}) }
}
func TestPush(t *testing.T) {
testPush(t, 0)
testPush(t, 1)
} }
// TestBasicPresignS3MultipartReferenceDoNotDelete tests the basic flow of // TestBasicPresignS3MultipartReferenceDoNotDelete tests the basic flow of
@ -318,9 +319,11 @@ func startMinio(t *testing.T, trace bool) *minio.Client {
if err != nil { if err != nil {
t.Fatalf("startMinio: %v", err) t.Fatalf("startMinio: %v", err)
} }
if mc.IsOnline() { // try list buckets to see if server is up
if _, err := mc.ListBuckets(ctx); err == nil {
break break
} }
t.Logf("startMinio: server is offline; retrying")
} }
if trace { if trace {