x/model: limit part len, not entire len

Limiting the whole name length comes naturally with part name length
restrictions. This aligns with Docker's registry behavior.
This commit is contained in:
Blake Mizerany 2024-04-05 18:24:50 -07:00
parent bf8e0c09c9
commit 7c7f56a7fb
3 changed files with 123 additions and 71 deletions

View File

@ -1,14 +1,8 @@
// Package model implements the File and Name types for working with and // Package model implements the File and Name types for working with and
// representing Modelfiles and model Names. // representing Modelfiles and model Names.
// //
// The Name type is designed for safety and correctness. It is an opaque // The Name type should be used when working with model names, and the File
// reference to a model, and holds the parts of a model, casing preserved, // type should be used when working with Modelfiles.
// but is not directly comparable with other Names since model names can be
// represented with different casing depending on the use case.
//
// Names should never be compared or parsed manually. Instead, use the
// [Name.EqualFold] method to compare two names in a case-insensitive
// manner, and [ParseName] to create a Name from a string, safely.
package model package model
import ( import (

View File

@ -2,6 +2,7 @@ package model
import ( import (
"cmp" "cmp"
"errors"
"hash/maphash" "hash/maphash"
"iter" "iter"
"slices" "slices"
@ -10,21 +11,19 @@ import (
"github.com/ollama/ollama/x/types/structs" "github.com/ollama/ollama/x/types/structs"
) )
const MaxNameLength = 255 // Errors
var (
type NamePart int // ErrIncompleteName is not used by this package, but is exported so that
// other packages do not need to invent their own error type when they
// Levels of concreteness // need to return an error for an incomplete name.
const ( ErrIncompleteName = errors.New("incomplete model name")
Invalid NamePart = iota
Host
Namespace
Model
Tag
Build
) )
var kindNames = map[NamePart]string{ const MaxNamePartLen = 128
type NamePartKind int
var kindNames = map[NamePartKind]string{
Invalid: "Invalid", Invalid: "Invalid",
Host: "Host", Host: "Host",
Namespace: "Namespace", Namespace: "Namespace",
@ -33,12 +32,36 @@ var kindNames = map[NamePart]string{
Build: "Build", Build: "Build",
} }
// Name is an opaque reference to a model. It holds the parts of a model, func (k NamePartKind) String() string {
// casing preserved, and provides methods for comparing and manipulating return cmp.Or(kindNames[k], "!(UNKNOWN PART KIND)")
// them in a case-insensitive manner. }
// Levels of concreteness.
//
// Each constant is a NamePartKind identifying one part of a model name.
// Invalid is the zero value (iota starts at 0); Host through Build take
// consecutive values in the order declared here.
const (
Invalid NamePartKind = iota
Host
Namespace
Model
Tag
Build
)
// Name is an opaque reference to a model. It holds the parts of a model
// with the case preserved, but is not directly comparable with other Names
// since model names can be represented with different casing depending on
// the use case. For instance, "Mistral" and "mistral" are the same model
// but each version may have come from different sources (e.g. copied from a
// Web page, or from a file path).
// //
// To create a Name, use [ParseName]. To compare two names, use // Valid Names can ONLY be constructed by calling [ParseName].
// [Name.EqualFold]. To use a name as a key in a map, use [Name.MapHash]. //
// A Name is valid if and only if it has a valid Model part. The other parts
// are optional.
//
// A Name is considered "complete" if it has all parts present. To check if a
// Name is complete, use [Name.Complete].
//
// To compare two names in a case-insensitive manner, use [Name.EqualFold].
// //
// The parts of a Name are: // The parts of a Name are:
// //
@ -124,7 +147,7 @@ func ParseName(s string) Name {
// Fill fills in the missing parts of dst with the parts of src. // Fill fills in the missing parts of dst with the parts of src.
// //
// Use this for merging a fully qualified ref with a partial ref, such as // Use this for merging a fully qualified Name with a partial Name, such as
// when filling in missing parts with defaults. // when filling in missing parts with defaults.
// //
// The returned Name will only be valid if dst is valid. // The returned Name will only be valid if dst is valid.
@ -144,6 +167,23 @@ func (r Name) WithBuild(build string) Name {
return r return r
} }
// Has reports whether the part of r selected by kind is non-empty.
//
// Any kind other than Host, Namespace, Model, Tag, or Build (including
// Invalid) reports false.
func (r Name) Has(kind NamePartKind) bool {
	var part string
	switch kind {
	case Host:
		part = r.host
	case Namespace:
		part = r.namespace
	case Model:
		part = r.model
	case Tag:
		part = r.tag
	case Build:
		part = r.build
	default:
		// Unknown kinds never match a stored part.
		return false
	}
	return part != ""
}
var mapHashSeed = maphash.MakeSeed() var mapHashSeed = maphash.MakeSeed()
// MapHash returns a case insensitive hash for use in maps and equality // MapHash returns a case insensitive hash for use in maps and equality
@ -165,9 +205,10 @@ func (r Name) MapHash() uint64 {
return h.Sum64() return h.Sum64()
} }
// Format returns a string representation of the ref with the given func (r Name) DisplayModel() string {
// concreteness. If a part is missing, it is replaced with a loud return r.model
// placeholder. }
func (r Name) DisplayFull() string { func (r Name) DisplayFull() string {
return (Name{ return (Name{
host: cmp.Or(r.host, "!(MISSING DOMAIN)"), host: cmp.Or(r.host, "!(MISSING DOMAIN)"),
@ -178,27 +219,7 @@ func (r Name) DisplayFull() string {
}).String() }).String()
} }
func (r Name) DisplayModel() string { // DisplayCompact returns a compact display string of the Name with only the
return r.model
}
func (r Name) Has(kind NamePart) bool {
switch kind {
case Host:
return r.host != ""
case Namespace:
return r.namespace != ""
case Model:
return r.model != ""
case Tag:
return r.tag != ""
case Build:
return r.build != ""
}
return false
}
// DisplayCompact returns a compact display string of the ref with only the
// model and tag parts. // model and tag parts.
func (r Name) DisplayCompact() string { func (r Name) DisplayCompact() string {
return (Name{ return (Name{
@ -207,7 +228,7 @@ func (r Name) DisplayCompact() string {
}).String() }).String()
} }
// DisplayShort returns a short display string of the ref with only the // DisplayShort returns a short display string of the Name with only the
// model, tag, and build parts. // model, tag, and build parts.
func (r Name) DisplayShort() string { func (r Name) DisplayShort() string {
return (Name{ return (Name{
@ -217,7 +238,7 @@ func (r Name) DisplayShort() string {
}).String() }).String()
} }
// DisplayLong returns a long display string of the ref including namespace, // DisplayLong returns a long display string of the Name including namespace,
// model, tag, and build parts. // model, tag, and build parts.
func (r Name) DisplayLong() string { func (r Name) DisplayLong() string {
return (Name{ return (Name{
@ -228,7 +249,7 @@ func (r Name) DisplayLong() string {
}).String() }).String()
} }
// String returns the fully qualified ref string. // String returns the fully qualified Name string.
func (r Name) String() string { func (r Name) String() string {
var b strings.Builder var b strings.Builder
if r.host != "" { if r.host != "" {
@ -251,7 +272,7 @@ func (r Name) String() string {
return b.String() return b.String()
} }
// Complete reports whether the ref is fully qualified. That is it has a // Complete reports whether the Name is fully qualified. That is it has a
// domain, namespace, name, tag, and build. // domain, namespace, name, tag, and build.
func (r Name) Complete() bool { func (r Name) Complete() bool {
return r.Valid() && !slices.Contains(r.Parts(), "") return r.Valid() && !slices.Contains(r.Parts(), "")
@ -262,7 +283,7 @@ func (r Name) Complete() bool {
// TODO(bmizerany): LogValue // TODO(bmizerany): LogValue
// TODO(bmizerany): driver.Value? (MarshalText etc should be enough) // TODO(bmizerany): driver.Value? (MarshalText etc should be enough)
// Parts returns the parts of the ref in order of concreteness. // Parts returns the parts of the Name in order of concreteness.
// //
// The length of the returned slice is always 5. // The length of the returned slice is always 5.
func (r Name) Parts() []string { func (r Name) Parts() []string {
@ -287,7 +308,7 @@ func (r Name) EqualFold(o Name) bool {
return r.MapHash() == o.MapHash() return r.MapHash() == o.MapHash()
} }
// NameParts returns a sequence of the parts of a ref string from most specific // NameParts returns a sequence of the parts of a Name string from most specific
// to least specific. // to least specific.
// //
// It normalizes the input string by removing "http://" and "https://" only. // It normalizes the input string by removing "http://" and "https://" only.
@ -295,8 +316,8 @@ func (r Name) EqualFold(o Name) bool {
// //
// As a special case, question marks are ignored so they may be used as // As a special case, question marks are ignored so they may be used as
// placeholders for missing parts in string literals. // placeholders for missing parts in string literals.
func NameParts(s string) iter.Seq2[NamePart, string] { func NameParts(s string) iter.Seq2[NamePartKind, string] {
return func(yield func(NamePart, string) bool) { return func(yield func(NamePartKind, string) bool) {
if strings.HasPrefix(s, "http://") { if strings.HasPrefix(s, "http://") {
s = s[len("http://"):] s = s[len("http://"):]
} }
@ -304,11 +325,11 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
s = s[len("https://"):] s = s[len("https://"):]
} }
if len(s) > MaxNameLength || len(s) == 0 { if len(s) > MaxNamePartLen || len(s) == 0 {
return return
} }
yieldValid := func(kind NamePart, part string) bool { yieldValid := func(kind NamePartKind, part string) bool {
if !isValidPart(kind, part) { if !isValidPart(kind, part) {
yield(Invalid, "") yield(Invalid, "")
return false return false
@ -316,8 +337,13 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
return yield(kind, part) return yield(kind, part)
} }
partLen := 0
state, j := Build, len(s) state, j := Build, len(s)
for i := len(s) - 1; i >= 0; i-- { for i := len(s) - 1; i >= 0; i-- {
if partLen++; partLen > MaxNamePartLen {
yield(Invalid, "")
return
}
switch s[i] { switch s[i] {
case '+': case '+':
switch state { switch state {
@ -325,7 +351,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
if !yieldValid(Build, s[i+1:j]) { if !yieldValid(Build, s[i+1:j]) {
return return
} }
state, j = Tag, i state, j, partLen = Tag, i, 0
default: default:
yield(Invalid, "") yield(Invalid, "")
return return
@ -336,7 +362,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
if !yieldValid(Tag, s[i+1:j]) { if !yieldValid(Tag, s[i+1:j]) {
return return
} }
state, j = Model, i state, j, partLen = Model, i, 0
default: default:
yield(Invalid, "") yield(Invalid, "")
return return
@ -352,7 +378,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
if !yieldValid(Namespace, s[i+1:j]) { if !yieldValid(Namespace, s[i+1:j]) {
return return
} }
state, j = Host, i state, j, partLen = Host, i, 0
default: default:
yield(Invalid, "") yield(Invalid, "")
return return
@ -373,7 +399,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
} }
} }
// Valid returns true if the ref has a valid nick. To know if a ref is // Valid returns true if the Name has a valid nick. To know if a Name is
// "complete", use Complete. // "complete", use Complete.
func (r Name) Valid() bool { func (r Name) Valid() bool {
// Parts ensures we only have valid parts, so no need to validate // Parts ensures we only have valid parts, so no need to validate
@ -382,7 +408,7 @@ func (r Name) Valid() bool {
} }
// isValidPart returns true if given part is valid ascii [a-zA-Z0-9_\.-] // isValidPart returns true if given part is valid ascii [a-zA-Z0-9_\.-]
func isValidPart(kind NamePart, s string) bool { func isValidPart(kind NamePartKind, s string) bool {
if s == "" { if s == "" {
return false return false
} }
@ -394,7 +420,7 @@ func isValidPart(kind NamePart, s string) bool {
return true return true
} }
func isValidByte(kind NamePart, c byte) bool { func isValidByte(kind NamePartKind, c byte) bool {
if kind == Namespace && c == '.' { if kind == Namespace && c == '.' {
return false return false
} }

View File

@ -52,8 +52,8 @@ var testNames = map[string]Name{
"file:///etc/passwd:latest": {}, "file:///etc/passwd:latest": {},
"file:///etc/passwd:latest+u": {}, "file:///etc/passwd:latest+u": {},
strings.Repeat("a", MaxNameLength): {model: strings.Repeat("a", MaxNameLength)}, strings.Repeat("a", MaxNamePartLen): {model: strings.Repeat("a", MaxNamePartLen)},
strings.Repeat("a", MaxNameLength+1): {}, strings.Repeat("a", MaxNamePartLen+1): {},
} }
func TestNameParts(t *testing.T) { func TestNameParts(t *testing.T) {
@ -64,6 +64,34 @@ func TestNameParts(t *testing.T) {
} }
} }
func TestPartTooLong(t *testing.T) {
for i := Host; i <= Build; i++ {
t.Run(i.String(), func(t *testing.T) {
var p Name
switch i {
case Host:
p.host = strings.Repeat("a", MaxNamePartLen+1)
case Namespace:
p.namespace = strings.Repeat("a", MaxNamePartLen+1)
case Model:
p.model = strings.Repeat("a", MaxNamePartLen+1)
case Tag:
p.tag = strings.Repeat("a", MaxNamePartLen+1)
case Build:
p.build = strings.Repeat("a", MaxNamePartLen+1)
}
s := strings.Trim(p.String(), "+:/")
if len(s) != MaxNamePartLen+1 {
t.Errorf("len(String()) = %d; want %d", len(s), MaxNamePartLen+1)
t.Logf("String() = %q", s)
}
if ParseName(p.String()).Valid() {
t.Errorf("Valid(%q) = true; want false", p)
}
})
}
}
func TestParseName(t *testing.T) { func TestParseName(t *testing.T) {
for baseName, want := range testNames { for baseName, want := range testNames {
for _, prefix := range []string{"", "https://", "http://"} { for _, prefix := range []string{"", "https://", "http://"} {
@ -210,7 +238,7 @@ func FuzzParseName(f *testing.F) {
} }
for _, p := range r0.Parts() { for _, p := range r0.Parts() {
if len(p) > MaxNameLength { if len(p) > MaxNamePartLen {
t.Errorf("part too long: %q", p) t.Errorf("part too long: %q", p)
} }
} }
@ -261,11 +289,15 @@ func ExampleFill() {
func ExampleName_MapHash() { func ExampleName_MapHash() {
m := map[uint64]bool{} m := map[uint64]bool{}
// key 1
m[ParseName("mistral:latest+q4").MapHash()] = true m[ParseName("mistral:latest+q4").MapHash()] = true
m[ParseName("miSTRal:latest+Q4").MapHash()] = true m[ParseName("miSTRal:latest+Q4").MapHash()] = true
m[ParseName("mistral:LATest+Q4").MapHash()] = true m[ParseName("mistral:LATest+Q4").MapHash()] = true
// key 2
m[ParseName("mistral:LATest").MapHash()] = true
fmt.Println(len(m)) fmt.Println(len(m))
// Output: // Output:
// 1 // 2
} }