From 14a6f85e9edb1c57c482543fffb6e605aaa23c28 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sat, 6 Apr 2024 13:37:33 -0700 Subject: [PATCH] x/model: replace part fields with array of parts This makes building strings and reasoning about parts easier. --- x/model/name.go | 209 +++++++++++++++++++++---------------------- x/model/name_test.go | 73 ++++++++------- 2 files changed, 136 insertions(+), 146 deletions(-) diff --git a/x/model/name.go b/x/model/name.go index 5ef1392d2..626c98e92 100644 --- a/x/model/name.go +++ b/x/model/name.go @@ -1,9 +1,11 @@ package model import ( + "bytes" "cmp" "errors" "hash/maphash" + "io" "iter" "log/slog" "slices" @@ -41,12 +43,15 @@ func (k NamePart) String() string { // Levels of concreteness const ( - Invalid NamePart = iota - Host + Host NamePart = iota Namespace Model Tag Build + + NumParts = Build + 1 + + Invalid = NamePart(-1) ) // Name is an opaque reference to a model. It holds the parts of a model @@ -84,13 +89,8 @@ const ( // // To update parts of a Name with defaults, use [Fill]. type Name struct { - _ structs.Incomparable - - host string - namespace string - model string - tag string - build string + _ structs.Incomparable + parts [NumParts]string } // ParseName parses s into a Name. The input string must be a valid string @@ -127,20 +127,10 @@ type Name struct { func ParseName(s string) Name { var r Name for kind, part := range NameParts(s) { - switch kind { - case Host: - r.host = part - case Namespace: - r.namespace = part - case Model: - r.model = part - case Tag: - r.tag = part - case Build: - r.build = part - case Invalid: + if kind == Invalid { return Name{} } + r.parts[kind] = part } if !r.Valid() { return Name{} @@ -152,18 +142,16 @@ func ParseName(s string) Name { // // The returned Name will only be valid if dst is valid. func Fill(dst, src Name) Name { - return Name{ - model: cmp.Or(dst.model, src.model), - host: cmp.Or(dst.host, src.host), - namespace: cmp.Or(dst.namespace, src.namespace), - tag: cmp.Or(dst.tag, src.tag), - build: cmp.Or(dst.build, src.build), + var r Name + for i := range r.parts { + r.parts[i] = cmp.Or(dst.parts[i], src.parts[i]) } + return r } // WithBuild returns a copy of r with the build set to the given string. func (r Name) WithBuild(build string) Name { - r.build = build + r.parts[Build] = build return r } @@ -188,9 +176,15 @@ func (r Name) MapHash() uint64 { return h.Sum64() } +func (r Name) slice(from, to NamePart) Name { + var v Name + copy(v.parts[from:to+1], r.parts[from:to+1]) + return v +} + // DisplayModel returns the a display string composed of the model only. func (r Name) DisplayModel() string { - return r.model + return r.parts[Model] } // DisplayFullest returns the fullest possible display string in form: @@ -202,12 +196,7 @@ func (r Name) DisplayModel() string { // It does not include the build part. For the fullest possible display // string with the build, use [Name.String]. func (r Name) DisplayFullest() string { - return (Name{ - host: r.host, - namespace: r.namespace, - model: r.model, - tag: r.tag, - }).String() + return r.slice(Host, Tag).String() } // DisplayShort returns the fullest possible display string in form: @@ -216,10 +205,7 @@ func (r Name) DisplayFullest() string { // // If any part is missing, it is omitted from the display string. func (r Name) DisplayShort() string { - return (Name{ - model: r.model, - tag: r.tag, - }).String() + return r.slice(Model, Tag).String() } // DisplayLong returns the fullest possible display string in form: @@ -228,11 +214,36 @@ func (r Name) DisplayShort() string { // // If any part is missing, it is omitted from the display string. func (r Name) DisplayLong() string { - return (Name{ - namespace: r.namespace, - model: r.model, - tag: r.tag, - }).String() + return r.slice(Namespace, Tag).String() +} + +var seps = [...]string{ + Host: "/", + Namespace: "/", + Model: ":", + Tag: "+", + Build: "", +} + +func (r Name) WriteTo(w io.Writer) (n int64, err error) { + for i := range r.parts { + if r.parts[i] == "" { + continue + } + if n > 0 { + n1, err := io.WriteString(w, seps[i-1]) + n += int64(n1) + if err != nil { + return n, err + } + } + n1, err := io.WriteString(w, r.parts[i]) + n += int64(n1) + if err != nil { + return n, err + } + } + return n, nil } var builderPool = sync.Pool{ @@ -241,6 +252,9 @@ var builderPool = sync.Pool{ }, } +// TODO(bmizerany): Add WriteTo and use in String and MarshalText with +// strings.Builder and bytes.Buffer, respectively. + // String returns the fullest possible display string in form: // // //:+ @@ -251,33 +265,10 @@ var builderPool = sync.Pool{ // [Name.DisplayFullest]. func (r Name) String() string { b := builderPool.Get().(*strings.Builder) - b.Reset() defer builderPool.Put(b) - b.Grow(0 + - len(r.host) + - len(r.namespace) + - len(r.model) + - len(r.tag) + - len(r.build) + - 4, // 4 possible separators - ) - if r.host != "" { - b.WriteString(r.host) - b.WriteString("/") - } - if r.namespace != "" { - b.WriteString(r.namespace) - b.WriteString("/") - } - b.WriteString(r.model) - if r.tag != "" { - b.WriteString(":") - b.WriteString(r.tag) - } - if r.build != "" { - b.WriteString("+") - b.WriteString(r.build) - } + b.Reset() + b.Grow(50) // arbitrarily long enough for most names + _, _ = r.WriteTo(b) return b.String() } @@ -286,13 +277,11 @@ func (r Name) String() string { // returns a string that includes all parts of the Name, with missing parts // replaced with a ("?"). func (r Name) GoString() string { - return (Name{ - host: cmp.Or(r.host, "?"), - namespace: cmp.Or(r.namespace, "?"), - model: cmp.Or(r.model, "?"), - tag: cmp.Or(r.tag, "?"), - build: cmp.Or(r.build, "?"), - }).String() + var v Name + for i := range r.parts { + v.parts[i] = cmp.Or(r.parts[i], "?") + } + return v.String() } // LogValue implements slog.Valuer. @@ -300,18 +289,25 @@ func (r Name) LogValue() slog.Value { return slog.StringValue(r.GoString()) } -// MarshalText implements encoding.TextMarshaler. -func (r Name) MarshalText() ([]byte, error) { - // unsafeBytes is safe here because we gurantee that the string is - // never used after this function returns. - // - // TODO: We can remove this if https://github.com/golang/go/issues/62384 - // lands. - return unsafeBytes(r.String()), nil +var bufPool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, } -func unsafeBytes(s string) []byte { - return *(*[]byte)(unsafe.Pointer(&s)) +// MarshalText implements encoding.TextMarshaler. +func (r Name) MarshalText() ([]byte, error) { + b := bufPool.Get().(*bytes.Buffer) + b.Reset() + b.Grow(50) // arbitrarily long enough for most names + defer bufPool.Put(b) + _, err := r.WriteTo(b) + if err != nil { + return nil, err + } + // TODO: We can remove this alloc if/when + // https://github.com/golang/go/issues/62384 lands. + return b.Bytes(), nil } // UnmarshalText implements encoding.TextUnmarshaler. @@ -329,13 +325,13 @@ func unsafeString(b []byte) string { // Complete reports whether the Name is fully qualified. That is it has a // domain, namespace, name, tag, and build. func (r Name) Complete() bool { - return !slices.Contains(r.Parts(), "") + return !slices.Contains(r.parts[:], "") } // CompleteNoBuild is like [Name.Complete] but it does not require the // build part to be present. func (r Name) CompleteNoBuild() bool { - return !slices.Contains(r.Parts()[:4], "") + return !slices.Contains(r.parts[:Tag], "") } // EqualFold reports whether r and o are equivalent model names, ignoring @@ -350,27 +346,23 @@ func (r Name) EqualFold(o Name) bool { // // For simple equality checks, use [Name.EqualFold]. func (r Name) CompareFold(o Name) int { - return cmp.Or( - compareFold(r.host, o.host), - compareFold(r.namespace, o.namespace), - compareFold(r.model, o.model), - compareFold(r.tag, o.tag), - compareFold(r.build, o.build), - ) + for i := range r.parts { + if n := compareFold(r.parts[i], o.parts[i]); n != 0 { + return n + } + } + return 0 } func compareFold(a, b string) int { // fast-path for unequal lengths - if n := cmp.Compare(len(a), len(b)); n != 0 { - return n - } for i := 0; i < len(a) && i < len(b); i++ { ca, cb := downcase(a[i]), downcase(b[i]) if n := cmp.Compare(ca, cb); n != 0 { return n } } - return 0 + return cmp.Compare(len(a), len(b)) } func downcase(c byte) byte { @@ -387,13 +379,7 @@ func downcase(c byte) byte { // // The length of the returned slice is always 5. func (r Name) Parts() []string { - return []string{ - r.host, - r.namespace, - r.model, - r.tag, - r.build, - } + return slices.Clone(r.parts[:]) } // Parts returns a sequence of the parts of a Name string from most specific @@ -492,7 +478,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] { func (r Name) Valid() bool { // Parts ensures we only have valid parts, so no need to validate // them here, only check if we have a name or not. - return r.model != "" + return r.parts[Model] != "" } // isValidPart returns true if given part is valid ascii [a-zA-Z0-9_\.-] @@ -520,3 +506,10 @@ func isValidByte(kind NamePart, c byte) bool { } return false } + +func sumLens(a []string) (sum int) { + for _, n := range a { + sum += len(n) + } + return +} diff --git a/x/model/name_test.go b/x/model/name_test.go index 4dcaa0527..55bdce8dc 100644 --- a/x/model/name_test.go +++ b/x/model/name_test.go @@ -11,7 +11,21 @@ import ( "testing" ) -var testNames = map[string]Name{ +type fields struct { + host, namespace, model, tag, build string +} + +func fieldsFromName(p Name) fields { + return fields{ + host: p.parts[Host], + namespace: p.parts[Namespace], + model: p.parts[Model], + tag: p.parts[Tag], + build: p.parts[Build], + } +} + +var testNames = map[string]fields{ "mistral:latest": {model: "mistral", tag: "latest"}, "mistral": {model: "mistral"}, "mistral:30B": {model: "mistral", tag: "30B"}, @@ -23,7 +37,7 @@ var testNames = map[string]Name{ "llama2": {model: "llama2"}, "user/model": {namespace: "user", model: "model"}, "example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"}, - "example.com/ns/mistral:7b+x": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"}, + "example.com/ns/mistral:7b+X": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"}, // preserves case for build "x+b": {model: "x", build: "b"}, @@ -73,7 +87,7 @@ func TestNameParts(t *testing.T) { } func TestNamePartString(t *testing.T) { - if g := NamePart(-1).String(); g != "Unknown" { + if g := NamePart(-2).String(); g != "Unknown" { t.Errorf("Unknown part = %q; want %q", g, "Unknown") } for kind, name := range kindNames { @@ -83,34 +97,6 @@ func TestNamePartString(t *testing.T) { } } -func TestPartTooLong(t *testing.T) { - for i := Host; i <= Build; i++ { - t.Run(i.String(), func(t *testing.T) { - var p Name - switch i { - case Host: - p.host = strings.Repeat("a", MaxNamePartLen+1) - case Namespace: - p.namespace = strings.Repeat("a", MaxNamePartLen+1) - case Model: - p.model = strings.Repeat("a", MaxNamePartLen+1) - case Tag: - p.tag = strings.Repeat("a", MaxNamePartLen+1) - case Build: - p.build = strings.Repeat("a", MaxNamePartLen+1) - } - s := strings.Trim(p.String(), "+:/") - if len(s) != MaxNamePartLen+1 { - t.Errorf("len(String()) = %d; want %d", len(s), MaxNamePartLen+1) - t.Logf("String() = %q", s) - } - if ParseName(p.String()).Valid() { - t.Errorf("Valid(%q) = true; want false", p) - } - }) - } -} - func TestParseName(t *testing.T) { for baseName, want := range testNames { for _, prefix := range []string{"", "https://", "http://"} { @@ -119,19 +105,20 @@ func TestParseName(t *testing.T) { s := prefix + baseName t.Run(s, func(t *testing.T) { - got := ParseName(s) - if !got.EqualFold(want) { + name := ParseName(s) + got := fieldsFromName(name) + if got != want { t.Errorf("ParseName(%q) = %q; want %q", s, got, want) } // test round-trip - if !ParseName(got.String()).EqualFold(got) { - t.Errorf("String() = %s; want %s", got.String(), baseName) + if !ParseName(name.String()).EqualFold(name) { + t.Errorf("String() = %s; want %s", name.String(), baseName) } - if got.Valid() && got.model == "" { + if name.Valid() && name.DisplayModel() == "" { t.Errorf("Valid() = true; Model() = %q; want non-empty name", got.model) - } else if !got.Valid() && got.DisplayModel() != "" { + } else if !name.Valid() && name.DisplayModel() != "" { t.Errorf("Valid() = false; Model() = %q; want empty name", got.model) } }) @@ -405,7 +392,7 @@ func TestNameTextMarshal(t *testing.T) { t.Fatal("MarshalText() = 0; want non-zero") } }) - if allocs > 1 { + if allocs > 0 { // TODO: Update when/if this lands: // https://github.com/golang/go/issues/62384 // @@ -414,6 +401,16 @@ func TestNameTextMarshal(t *testing.T) { } } +func TestNameStringAllocs(t *testing.T) { + name := ParseName("example.com/ns/mistral:latest+Q4_0") + allocs := testing.AllocsPerRun(1000, func() { + keep(name.String()) + }) + if allocs > 1 { + t.Errorf("String allocs = %v; want 0", allocs) + } +} + func ExampleFill() { defaults := ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0") r := Fill(ParseName("mistral"), defaults)