diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index e4c36d7d8..82a8bbca4 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -24,6 +24,7 @@ import ( "os" "path/filepath" "runtime" + "slices" "strconv" "strings" "sync/atomic" @@ -53,7 +54,7 @@ var ( // ErrMissingModel is returned when the model part of a name is missing // or invalid. - ErrNameInvalid = errors.New("invalid name; must be in the form {scheme://}{host/}{namespace/}[model]{:tag}{@digest}") + ErrNameInvalid = errors.New("invalid or missing name") // ErrCached is passed to [Trace.PushUpdate] when a layer already // exists. It is a non-fatal error and is never returned by [Registry.Push]. @@ -205,10 +206,18 @@ type Registry struct { // It is only used when a layer is larger than [MaxChunkingThreshold]. MaxChunkSize int64 - // NameMask, if set, is the name used to convert non-fully qualified + // Mask, if set, is the name used to convert non-fully qualified // names to fully qualified names. If empty, the default mask // ("registry.ollama.ai/library/_:latest") is used. - NameMask string + Mask string +} + +func (r *Registry) completeName(name string) names.Name { + mask := defaultMask + if r.Mask != "" { + mask = names.Parse(r.Mask) + } + return names.Merge(names.Parse(name), mask) } // DefaultRegistry returns a new Registry configured from the environment. The @@ -243,52 +252,6 @@ func DefaultRegistry() (*Registry, error) { return &rc, nil } -type PushParams struct { - // From is an optional destination name for the model. If empty, the - // destination name is the same as the source name. - From string -} - -// parseName parses name using [names.ParseExtended] and then merges the name with the -// default name, and checks that the name is fully qualified. If a digest is -// present, it parse and returns it with the other fields as their zero values. -// -// It returns an error if the name is not fully qualified, or if the digest, if -// any, is invalid. -// -// The scheme is returned as provided by [names.ParseExtended]. -func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) { - maskName := defaultMask - if mask != "" { - maskName = names.Parse(mask) - if !maskName.IsFullyQualified() { - return "", names.Name{}, blob.Digest{}, fmt.Errorf("invalid name mask: %s", mask) - } - } - scheme, n, ds := names.ParseExtended(s) - if !n.IsValid() { - return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s) - } - n = names.Merge(n, maskName) - if ds != "" { - // Digest is present. Validate it. - d, err = blob.ParseDigest(ds) - if err != nil { - return "", names.Name{}, blob.Digest{}, err - } - } - - // The name check is deferred until after the digest check because we - // say that digests take precedence over names, and so should there - // errors when being parsed. - if !n.IsFullyQualified() { - return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s) - } - - scheme = cmp.Or(scheme, "https") - return scheme, n, d, nil -} - func (r *Registry) maxStreams() int { n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0)) @@ -308,6 +271,12 @@ func (r *Registry) maxChunkSize() int64 { return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize) } +type PushParams struct { + // From is an optional destination name for the model. If empty, the + // destination name is the same as the source name. + From string +} + // Push pushes the model with the name in the cache to the remote registry. func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error { if p == nil { @@ -337,7 +306,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p * t := traceFromContext(ctx) - scheme, n, _, err := parseName(name, r.NameMask) + scheme, n, _, err := parseName(name, r.Mask) if err != nil { // This should never happen since ResolveLocal should have // already validated the name. @@ -431,7 +400,7 @@ func canRetry(err error) bool { // typically slower than splitting the model up across layers, and is mostly // utilized for layers of type equal to "application/vnd.ollama.image". func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error { - scheme, n, _, err := parseName(name, r.NameMask) + scheme, n, _, err := parseName(name, r.Mask) if err != nil { return err } @@ -582,9 +551,9 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err // Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified // before attempting to unlink the model. func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) { - _, n, _, err := parseName(name, r.NameMask) - if err != nil { - return false, err + n := r.completeName(name) + if !n.IsFullyQualified() { + return false, fmt.Errorf("%w: %q", ErrNameInvalid, name) } return c.Unlink(n.String()) } @@ -658,9 +627,9 @@ type Layer struct { } // ResolveLocal resolves a name to a Manifest in the local cache. The name is -// parsed using [names.ParseExtended] but the scheme is ignored. +// parsed using [names.Split] but the scheme is ignored. func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) { - _, n, d, err := parseName(name, r.NameMask) + _, n, d, err := parseName(name, r.Mask) if err != nil { return nil, err } @@ -686,7 +655,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro // Resolve resolves a name to a Manifest in the remote registry. func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) { - scheme, n, d, err := parseName(name, r.NameMask) + scheme, n, d, err := parseName(name, r.Mask) if err != nil { return nil, err } @@ -869,3 +838,69 @@ func maybeUnexpectedEOF(err error) error { } return err } + +type publicError struct { + wrapped error + message string +} + +func withPublicMessagef(err error, message string, args ...any) error { + return publicError{wrapped: err, message: fmt.Sprintf(message, args...)} +} + +func (e publicError) Error() string { return e.message } +func (e publicError) Unwrap() error { return e.wrapped } + +var supportedSchemes = []string{ + "http", + "https", + "https+insecure", +} + +var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", ")) + +// parseName parses and validates an extended name, returning the scheme, name, +// and digest. +// +// If the scheme is empty, scheme will be "https". If an unsupported scheme is +// given, [ErrNameInvalid] wrapped with a display friendly message is returned. +// +// If the digest is invalid, [ErrNameInvalid] wrapped with a display friendly +// message is returned. +// +// If the name is not, once merged with the mask, fully qualified, +// [ErrNameInvalid] wrapped with a display friendly message is returned. +func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Digest, _ error) { + scheme, name, digest := names.Split(s) + scheme = cmp.Or(scheme, "https") + if !slices.Contains(supportedSchemes, scheme) { + err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage) + return "", names.Name{}, blob.Digest{}, err + } + + var d blob.Digest + if digest != "" { + var err error + d, err = blob.ParseDigest(digest) + if err != nil { + err = withPublicMessagef(ErrNameInvalid, "invalid digest: %q", digest) + return "", names.Name{}, blob.Digest{}, err + } + if name == "" { + // We have can resolve a manifest from a digest only, + // so skip name validation and return the scheme and + // digest. + return scheme, names.Name{}, d, nil + } + } + + maskName := defaultMask + if mask != "" { + maskName = names.Parse(mask) + } + n := names.Merge(names.Parse(name), maskName) + if !n.IsFullyQualified() { + return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s) + } + return scheme, n, d, nil +} diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go index af898c268..20a1f1593 100644 --- a/server/internal/client/ollama/registry_test.go +++ b/server/internal/client/ollama/registry_test.go @@ -84,14 +84,14 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { } } - rc := &Registry{ + r := &Registry{ HTTPClient: &http.Client{ Transport: recordRoundTripper(h), }, } link := func(name string, manifest string) { - _, n, _, err := parseName(name, rc.NameMask) + _, n, _, err := parseName(name, r.Mask) if err != nil { panic(err) } @@ -122,7 +122,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499}) link("invalid", "!!!!!") - return rc, c + return r, c } func okHandler(w http.ResponseWriter, r *http.Request) { @@ -145,29 +145,6 @@ func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest { return d } -func TestRegistryPushInvalidNames(t *testing.T) { - rc, c := newClient(t, nil) - - cases := []struct { - name string - err error - }{ - {"", ErrNameInvalid}, - {"@", ErrNameInvalid}, - {"@x", blob.ErrInvalidDigest}, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - // Create a new registry and push a new image. - err := rc.Push(t.Context(), c, tt.name, nil) - if !errors.Is(err, tt.err) { - t.Errorf("err = %v; want %v", err, tt.err) - } - }) - } -} - func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) { t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }} return WithTrace(ctx, t), t @@ -622,7 +599,7 @@ func TestInsecureSkipVerify(t *testing.T) { })) defer s.Close() - const name = "ollama.com/library/insecure" + const name = "library/insecure" var rc Registry url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name) @@ -724,3 +701,38 @@ func TestErrorUnmarshal(t *testing.T) { }) } } + +// TestParseNameErrors tests that parseName returns errors messages with enough +// detail for users to debug naming issues they may encounter. Previous to this +// test, the error messages were not very helpful and each problem was reported +// as the same message. +// +// It is only for testing error messages, not that all invalids and valids are +// covered. Those are in other tests for names.Name and blob.Digest. +func TestParseNameErrors(t *testing.T) { + cases := []struct { + name string + err error + want string + }{ + {"x", nil, ""}, + {"x@", nil, ""}, + + {"", ErrNameInvalid, `invalid or missing name: ""`}, + {"://", ErrNameInvalid, `invalid or missing name: "://"`}, + {"x://", ErrNameInvalid, `unsupported scheme: "x": supported schemes are http, https, https+insecure`}, + + {"@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`}, + {"x@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`}, + } + + for _, tt := range cases { + _, _, _, err := parseName(tt.name, DefaultMask) + if !errors.Is(err, tt.err) { + t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err) + } + if err != nil && !strings.Contains(err.Error(), tt.want) { + t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want) + } + } +} diff --git a/server/internal/internal/names/name.go b/server/internal/internal/names/name.go index 361cce76f..f0a1185dc 100644 --- a/server/internal/internal/names/name.go +++ b/server/internal/internal/names/name.go @@ -8,7 +8,7 @@ import ( "github.com/ollama/ollama/server/internal/internal/stringsx" ) -const MaxNameLength = 50 + 1 + 50 + 1 + 50 // /: +const MaxNameLength = 350 + 1 + 80 + 1 + 80 + 1 + 80 // //: type Name struct { // Make incomparable to enfoce use of Compare / Equal for @@ -25,19 +25,12 @@ type Name struct { // format of a valid name string is: // // s: -// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest } // { host } "/" { namespace } "/" { model } ":" { tag } -// { host } "/" { namespace } "/" { model } "@" { digest } // { host } "/" { namespace } "/" { model } -// { namespace } "/" { model } ":" { tag } "@" { digest } // { namespace } "/" { model } ":" { tag } -// { namespace } "/" { model } "@" { digest } // { namespace } "/" { model } -// { model } ":" { tag } "@" { digest } // { model } ":" { tag } -// { model } "@" { digest } // { model } -// "@" { digest } // host: // pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }* // length: [1, 350] @@ -50,9 +43,6 @@ type Name struct { // tag: // pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }* // length: [1, 80] -// digest: -// pattern: { alphanum | "_" } { alphanum | "-" | ":" }* -// length: [1, 80] // // The name returned is not guaranteed to be valid. If it is not valid, the // field values are left in an undefined state. Use [Name.IsValid] to check @@ -82,23 +72,17 @@ func Parse(s string) Name { } } -// ParseExtended parses and returns any scheme, Name, and digest from from s in -// the the form [scheme://][name][@digest]. All parts are optional. -// -// If the scheme is present, it must be followed by "://". The digest is -// prefixed by "@" and comes after the name. The name is parsed using [Parse]. -// -// The scheme and digest are stripped before the name is parsed by [Parse]. -// -// For convience, the scheme is never empty. If the scheme is not present, the -// returned scheme is "https". +// Split splits an extended name string into its scheme, name, and digest +// parts. // // Examples: // // http://ollama.com/bmizerany/smol:latest@digest // https://ollama.com/bmizerany/smol:latest // ollama.com/bmizerany/smol:latest@digest // returns "https" scheme. -func ParseExtended(s string) (scheme string, _ Name, digest string) { +// model@digest +// @digest +func Split(s string) (scheme, name, digest string) { i := strings.Index(s, "://") if i >= 0 { scheme = s[:i] @@ -109,21 +93,7 @@ func ParseExtended(s string) (scheme string, _ Name, digest string) { digest = s[i+1:] s = s[:i] } - return scheme, Parse(s), digest -} - -func FormatExtended(scheme string, n Name, digest string) string { - var b strings.Builder - if scheme != "" { - b.WriteString(scheme) - b.WriteString("://") - } - b.WriteString(n.String()) - if digest != "" { - b.WriteByte('@') - b.WriteString(digest) - } - return b.String() + return scheme, s, digest } // Merge merges two names into a single name. Non-empty host, namespace, and @@ -141,39 +111,68 @@ func Merge(a, b Name) Name { // IsValid returns true if the name is valid. func (n Name) IsValid() bool { - if n.h != "" && !isValidHost(n.h) { + if n.h != "" && !isValidPart(partHost, n.h) { return false } - if n.n != "" && !isValidNamespace(n.n) { + if n.n != "" && !isValidPart(partNamespace, n.n) { return false } - if n.m != "" && !isValidModel(n.m) { + if n.t != "" && !isValidPart(partTag, n.t) { return false } - if n.t != "" && !isValidTag(n.t) { - return false - } - return true + + // at bare minimum, model must be present and valid + return n.m != "" && isValidPart(partModel, n.m) } func (n Name) IsFullyQualified() bool { return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != "" } -func isValidHost(_ string) bool { - return true // TODO: implement +const ( + partHost = iota + partNamespace + partModel + partTag +) + +func isValidPart(kind int, s string) bool { + maxlen := 80 + if kind == partHost { + maxlen = 350 + } + if len(s) > maxlen { + return false + } + + for i := range s { + if i == 0 { + if !isAlphanumericOrUnderscore(s[i]) { + return false + } + continue + } + switch s[i] { + case '_', '-': + case '.': + if kind == partNamespace { + return false + } + case ':': + if kind != partHost { + return false + } + default: + if !isAlphanumericOrUnderscore(s[i]) { + return false + } + } + } + return true } -func isValidNamespace(_ string) bool { - return true // TODO: implement -} - -func isValidModel(_ string) bool { - return true // TODO: implement -} - -func isValidTag(_ string) bool { - return true // TODO: implement +func isAlphanumericOrUnderscore(c byte) bool { + return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_' } func (n Name) Host() string { return n.h } diff --git a/server/internal/internal/names/name_test.go b/server/internal/internal/names/name_test.go index 760fec5fa..e3dc5fe3c 100644 --- a/server/internal/internal/names/name_test.go +++ b/server/internal/internal/names/name_test.go @@ -81,15 +81,11 @@ func TestParseExtended(t *testing.T) { } for _, tt := range cases { t.Run(tt.in, func(t *testing.T) { - scheme, name, digest := ParseExtended(tt.in) - if scheme != tt.wantScheme || name.Compare(tt.wantName) != 0 || digest != tt.wantDigest { + scheme, name, digest := Split(tt.in) + n := Parse(name) + if scheme != tt.wantScheme || n.Compare(tt.wantName) != 0 || digest != tt.wantDigest { t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest) } - - // Round trip - if got := FormatExtended(scheme, name, digest); got != tt.in { - t.Errorf("FormatExtended(%q, %q, %q) = %q", scheme, name, digest, got) - } }) } } @@ -150,3 +146,75 @@ func BenchmarkParseName(b *testing.B) { junkName = Parse("h/n/m:t") } } + +const ( + part80 = "88888888888888888888888888888888888888888888888888888888888888888888888888888888" + part350 = "33333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333" +) + +var testCases = map[string]bool{ // name -> valid + "": false, + + "_why/_the/_lucky:_stiff": true, + + // minimal + "h/n/m:t": true, + + "host/namespace/model:tag": true, + "host/namespace/model": true, + "namespace/model": true, + "model": true, + + // long (but valid) + part80 + "/" + part80 + "/" + part80 + ":" + part80: true, + part350 + "/" + part80 + "/" + part80 + ":" + part80: true, + + // too long + part80 + "/" + part80 + "/" + part80 + ":" + part350: false, + "x" + part350 + "/" + part80 + "/" + part80 + ":" + part80: false, + + "h/nn/mm:t": true, // bare minimum part sizes + + // unqualified + "m": true, + "n/m:": true, + "h/n/m": true, + "@t": false, + "m@d": false, + + // invalids + "^": false, + "mm:": true, + "/nn/mm": true, + "//": false, // empty model + "//mm": true, + "hh//": false, // empty model + "//mm:@": false, + "00@": false, + "@": false, + + // not starting with alphanum + "-hh/nn/mm:tt": false, + "hh/-nn/mm:tt": false, + "hh/nn/-mm:tt": false, + "hh/nn/mm:-tt": false, + + // smells like a flag + "-h": false, + + // hosts + "host:https/namespace/model:tag": true, + + // colon in non-host part before tag + "host/name:space/model:tag": false, +} + +func TestParseNameValidation(t *testing.T) { + for s, valid := range testCases { + got := Parse(s) + if got.IsValid() != valid { + t.Logf("got: %v", got) + t.Errorf("Parse(%q).IsValid() = %v; want !%[2]v", s, got.IsValid()) + } + } +} diff --git a/server/internal/registry/server.go b/server/internal/registry/server.go index 8eb6daf89..6ea590a70 100644 --- a/server/internal/registry/server.go +++ b/server/internal/registry/server.go @@ -204,7 +204,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error { return err } if !ok { - return &serverError{404, "manifest_not_found", "manifest not found"} + return &serverError{404, "not_found", "model not found"} } return nil } diff --git a/server/internal/registry/server_test.go b/server/internal/registry/server_test.go index 22267ba7d..7ba13d501 100644 --- a/server/internal/registry/server_test.go +++ b/server/internal/registry/server_test.go @@ -109,11 +109,8 @@ func TestServerDelete(t *testing.T) { got = s.send(t, "DELETE", "/api/delete", ``) checkErrorResponse(t, got, 400, "bad_request", "empty request body") - got = s.send(t, "DELETE", "/api/delete", `{"model": "!"}`) - checkErrorResponse(t, got, 404, "manifest_not_found", "not found") - got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`) - checkErrorResponse(t, got, 400, "bad_request", "invalid name") + checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name") got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body checkErrorResponse(t, got, 404, "not_found", "not found")