... wip still broke

This commit is contained in:
Blake Mizerany 2024-04-03 22:15:58 -07:00
parent 76a202c04e
commit def4d902bf
6 changed files with 85 additions and 20 deletions

View File

@ -15,8 +15,8 @@ import (
// Common API Errors // Common API Errors
var ( var (
errUnqualifiedRef = oweb.Mistake("invalid", "name", "must be fully qualified") errUnqualifiedRef = oweb.Invalid("invalid", "name", "must be fully qualified")
errRefNotFound = oweb.Mistake("not_found", "name", "no such model") errRefNotFound = oweb.Invalid("not_found", "name", "no such model")
) )
type Server struct { type Server struct {

View File

@ -253,6 +253,17 @@ func ParseRef(s string) Ref {
return r return r
} }
// Complete is the same as ParseRef(s).Complete().
//
// Future versions may be faster than calling ParseRef(s).Complete(), so if
// need to know if a ref is complete and don't need the ref, use this
// function.
func Complete(s string) bool {
// TODO(bmizerany): fast-path this with a quick scan withput
// allocating strings
return ParseRef(s).Complete()
}
func (r Ref) Valid() bool { func (r Ref) Valid() bool {
// Name is required // Name is required
if !isValidPart(r.name) { if !isValidPart(r.name) {

View File

@ -92,12 +92,23 @@ type Error struct {
// Field is the field in the request that caused the error, if any. // Field is the field in the request that caused the error, if any.
Field string `json:"field,omitempty"` Field string `json:"field,omitempty"`
// Value is the value of the field that caused the error, if any.
Value string `json:"value,omitempty"`
} }
func (e *Error) Error() string { func (e *Error) Error() string {
var b strings.Builder var b strings.Builder
b.WriteString("ollama: ") b.WriteString("ollama: ")
b.WriteString(e.Code) b.WriteString(e.Code)
if e.Field != "" {
b.WriteString(" ")
b.WriteString(e.Field)
}
if e.Value != "" {
b.WriteString(": ")
b.WriteString(e.Value)
}
if e.Message != "" { if e.Message != "" {
b.WriteString(": ") b.WriteString(": ")
b.WriteString(e.Message) b.WriteString(e.Message)

View File

@ -21,12 +21,13 @@ func Missing(field string) error {
} }
} }
func Mistake(code, field, message string) error { func Invalid(field, value, format string, args ...any) error {
return &ollama.Error{ return &ollama.Error{
Status: 400, Status: 400,
Code: code, Code: "invalid",
Field: field, Field: field,
Message: fmt.Sprintf("%s: %s", field, message), Value: value,
Message: fmt.Sprintf(format, args...),
} }
} }
@ -69,7 +70,7 @@ func DecodeUserJSON[T any](field string, r io.Reader) (*T, error) {
if errors.As(err, &se) { if errors.As(err, &se) {
msg = fmt.Sprintf("%s (%q) is not a %s", se.Field, se.Value, se.Type) msg = fmt.Sprintf("%s (%q) is not a %s", se.Field, se.Value, se.Type)
} }
return nil, Mistake("invalid_json", field, msg) return nil, Invalid("invalid_json", field, "", msg)
} }
func DecodeJSON[T any](r io.Reader) (*T, error) { func DecodeJSON[T any](r io.Reader) (*T, error) {

View File

@ -84,7 +84,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
ref := blob.ParseRef(pr.Ref) ref := blob.ParseRef(pr.Ref)
if !ref.Complete() { if !ref.Complete() {
return oweb.Mistake("invalid", "name", "must be complete") return oweb.Invalid("name", pr.Ref, "must be complete")
} }
m, err := oweb.DecodeUserJSON[apitype.Manifest]("manifest", bytes.NewReader(pr.Manifest)) m, err := oweb.DecodeUserJSON[apitype.Manifest]("manifest", bytes.NewReader(pr.Manifest))
@ -107,24 +107,30 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
if err != nil { if err != nil {
return err return err
} }
q := u.Query() q := u.Query()
// Check if this is a part upload, if not, skip
uploadID := q.Get("uploadId") uploadID := q.Get("uploadId")
if uploadID == "" { if uploadID == "" {
// not a part upload // not a part upload
continue continue
} }
partNumber, err := strconv.Atoi(q.Get("partNumber"))
// PartNumber is required
queryPartNumber := q.Get("partNumber")
partNumber, err := strconv.Atoi(queryPartNumber)
if err != nil { if err != nil {
return oweb.Mistake("invalid", "url", "invalid or missing PartNumber") return oweb.Invalid("partNumber", queryPartNumber, "invalid or missing PartNumber")
} }
// ETag is required
if mcp.ETag == "" { if mcp.ETag == "" {
return oweb.Mistake("invalid", "etag", "missing") return oweb.Missing("etag")
}
cp, ok := completePartsByUploadID[uploadID]
if !ok {
cp = completeParts{key: u.Path}
completePartsByUploadID[uploadID] = cp
} }
cp := completePartsByUploadID[uploadID]
cp.key = u.Path
cp.parts = append(cp.parts, minio.CompletePart{ cp.parts = append(cp.parts, minio.CompletePart{
PartNumber: partNumber, PartNumber: partNumber,
ETag: mcp.ETag, ETag: mcp.ETag,
@ -136,8 +142,11 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
var zeroOpts minio.PutObjectOptions var zeroOpts minio.PutObjectOptions
_, err := mcc.CompleteMultipartUpload(r.Context(), bucketTODO, cp.key, uploadID, cp.parts, zeroOpts) _, err := mcc.CompleteMultipartUpload(r.Context(), bucketTODO, cp.key, uploadID, cp.parts, zeroOpts)
if err != nil { if err != nil {
// log and continue; put backpressure on the client var e minio.ErrorResponse
log.Printf("error completing upload: %v", err) if errors.As(err, &e) && e.Code == "NoSuchUpload" {
return oweb.Invalid("uploadId", uploadID, "unknown uploadId")
}
return err
} }
} }

View File

@ -28,6 +28,39 @@ import (
"kr.dev/diff" "kr.dev/diff"
) )
// const ref = "registry.ollama.ai/x/y:latest+Z"
// const manifest = `{
// "layers": [
// {"digest": "sha256-1", "size": 1},
// {"digest": "sha256-2", "size": 2},
// {"digest": "sha256-3", "size": 3}
// ]
// }`
// ts := newTestServer(t)
// ts.pushNotOK(ref, `{}`, &ollama.Error{
// Status: 400,
// Code: "invalid",
// Message: "name must be fully qualified",
// })
// ts.push(ref, `{
// "layers": [
// {"digest": "sha256-1", "size": 1},
// {"digest": "sha256-2", "size": 2},
// {"digest": "sha256-3", "size": 3}
// ]
// }`)
type tWriter struct {
t *testing.T
}
func (w tWriter) Write(p []byte) (n int, err error) {
w.t.Logf("%s", p)
return len(p), nil
}
func TestPushBasic(t *testing.T) { func TestPushBasic(t *testing.T) {
const MB = 1024 * 1024 const MB = 1024 * 1024
@ -41,6 +74,8 @@ func TestPushBasic(t *testing.T) {
} }
}() }()
const ref = "registry.ollama.ai/x/y:latest+Z"
// Upload two small layers and one large layer that will // Upload two small layers and one large layer that will
// trigger a multipart upload. // trigger a multipart upload.
manifest := []byte(`{ manifest := []byte(`{
@ -51,8 +86,6 @@ func TestPushBasic(t *testing.T) {
] ]
}`) }`)
const ref = "registry.ollama.ai/x/y:latest+Z"
hs := httptest.NewServer(&Server{ hs := httptest.NewServer(&Server{
minioClient: mc, minioClient: mc,
UploadChunkSize: 5 * MB, UploadChunkSize: 5 * MB,