x/model: replace part fields with array of parts

This makes building strings and reasoning about parts easier.
This commit is contained in:
Blake Mizerany 2024-04-06 13:37:33 -07:00
parent 45d8d22785
commit 14a6f85e9e
2 changed files with 136 additions and 146 deletions

View File

@ -1,9 +1,11 @@
package model package model
import ( import (
"bytes"
"cmp" "cmp"
"errors" "errors"
"hash/maphash" "hash/maphash"
"io"
"iter" "iter"
"log/slog" "log/slog"
"slices" "slices"
@ -41,12 +43,15 @@ func (k NamePart) String() string {
// Levels of concreteness // Levels of concreteness
const ( const (
Invalid NamePart = iota Host NamePart = iota
Host
Namespace Namespace
Model Model
Tag Tag
Build Build
NumParts = Build + 1
Invalid = NamePart(-1)
) )
// Name is an opaque reference to a model. It holds the parts of a model // Name is an opaque reference to a model. It holds the parts of a model
@ -85,12 +90,7 @@ const (
// To update parts of a Name with defaults, use [Fill]. // To update parts of a Name with defaults, use [Fill].
type Name struct { type Name struct {
_ structs.Incomparable _ structs.Incomparable
parts [NumParts]string
host string
namespace string
model string
tag string
build string
} }
// ParseName parses s into a Name. The input string must be a valid string // ParseName parses s into a Name. The input string must be a valid string
@ -127,20 +127,10 @@ type Name struct {
func ParseName(s string) Name { func ParseName(s string) Name {
var r Name var r Name
for kind, part := range NameParts(s) { for kind, part := range NameParts(s) {
switch kind { if kind == Invalid {
case Host:
r.host = part
case Namespace:
r.namespace = part
case Model:
r.model = part
case Tag:
r.tag = part
case Build:
r.build = part
case Invalid:
return Name{} return Name{}
} }
r.parts[kind] = part
} }
if !r.Valid() { if !r.Valid() {
return Name{} return Name{}
@ -152,18 +142,16 @@ func ParseName(s string) Name {
// //
// The returned Name will only be valid if dst is valid. // The returned Name will only be valid if dst is valid.
func Fill(dst, src Name) Name { func Fill(dst, src Name) Name {
return Name{ var r Name
model: cmp.Or(dst.model, src.model), for i := range r.parts {
host: cmp.Or(dst.host, src.host), r.parts[i] = cmp.Or(dst.parts[i], src.parts[i])
namespace: cmp.Or(dst.namespace, src.namespace),
tag: cmp.Or(dst.tag, src.tag),
build: cmp.Or(dst.build, src.build),
} }
return r
} }
// WithBuild returns a copy of r with the build set to the given string. // WithBuild returns a copy of r with the build set to the given string.
func (r Name) WithBuild(build string) Name { func (r Name) WithBuild(build string) Name {
r.build = build r.parts[Build] = build
return r return r
} }
@ -188,9 +176,15 @@ func (r Name) MapHash() uint64 {
return h.Sum64() return h.Sum64()
} }
func (r Name) slice(from, to NamePart) Name {
var v Name
copy(v.parts[from:to+1], r.parts[from:to+1])
return v
}
// DisplayModel returns the a display string composed of the model only. // DisplayModel returns the a display string composed of the model only.
func (r Name) DisplayModel() string { func (r Name) DisplayModel() string {
return r.model return r.parts[Model]
} }
// DisplayFullest returns the fullest possible display string in form: // DisplayFullest returns the fullest possible display string in form:
@ -202,12 +196,7 @@ func (r Name) DisplayModel() string {
// It does not include the build part. For the fullest possible display // It does not include the build part. For the fullest possible display
// string with the build, use [Name.String]. // string with the build, use [Name.String].
func (r Name) DisplayFullest() string { func (r Name) DisplayFullest() string {
return (Name{ return r.slice(Host, Tag).String()
host: r.host,
namespace: r.namespace,
model: r.model,
tag: r.tag,
}).String()
} }
// DisplayShort returns the fullest possible display string in form: // DisplayShort returns the fullest possible display string in form:
@ -216,10 +205,7 @@ func (r Name) DisplayFullest() string {
// //
// If any part is missing, it is omitted from the display string. // If any part is missing, it is omitted from the display string.
func (r Name) DisplayShort() string { func (r Name) DisplayShort() string {
return (Name{ return r.slice(Model, Tag).String()
model: r.model,
tag: r.tag,
}).String()
} }
// DisplayLong returns the fullest possible display string in form: // DisplayLong returns the fullest possible display string in form:
@ -228,11 +214,36 @@ func (r Name) DisplayShort() string {
// //
// If any part is missing, it is omitted from the display string. // If any part is missing, it is omitted from the display string.
func (r Name) DisplayLong() string { func (r Name) DisplayLong() string {
return (Name{ return r.slice(Namespace, Tag).String()
namespace: r.namespace, }
model: r.model,
tag: r.tag, var seps = [...]string{
}).String() Host: "/",
Namespace: "/",
Model: ":",
Tag: "+",
Build: "",
}
func (r Name) WriteTo(w io.Writer) (n int64, err error) {
for i := range r.parts {
if r.parts[i] == "" {
continue
}
if n > 0 {
n1, err := io.WriteString(w, seps[i-1])
n += int64(n1)
if err != nil {
return n, err
}
}
n1, err := io.WriteString(w, r.parts[i])
n += int64(n1)
if err != nil {
return n, err
}
}
return n, nil
} }
var builderPool = sync.Pool{ var builderPool = sync.Pool{
@ -241,6 +252,9 @@ var builderPool = sync.Pool{
}, },
} }
// TODO(bmizerany): Add WriteTo and use in String and MarshalText with
// strings.Builder and bytes.Buffer, respectively.
// String returns the fullest possible display string in form: // String returns the fullest possible display string in form:
// //
// <host>/<namespace>/<model>:<tag>+<build> // <host>/<namespace>/<model>:<tag>+<build>
@ -251,33 +265,10 @@ var builderPool = sync.Pool{
// [Name.DisplayFullest]. // [Name.DisplayFullest].
func (r Name) String() string { func (r Name) String() string {
b := builderPool.Get().(*strings.Builder) b := builderPool.Get().(*strings.Builder)
b.Reset()
defer builderPool.Put(b) defer builderPool.Put(b)
b.Grow(0 + b.Reset()
len(r.host) + b.Grow(50) // arbitrarily long enough for most names
len(r.namespace) + _, _ = r.WriteTo(b)
len(r.model) +
len(r.tag) +
len(r.build) +
4, // 4 possible separators
)
if r.host != "" {
b.WriteString(r.host)
b.WriteString("/")
}
if r.namespace != "" {
b.WriteString(r.namespace)
b.WriteString("/")
}
b.WriteString(r.model)
if r.tag != "" {
b.WriteString(":")
b.WriteString(r.tag)
}
if r.build != "" {
b.WriteString("+")
b.WriteString(r.build)
}
return b.String() return b.String()
} }
@ -286,13 +277,11 @@ func (r Name) String() string {
// returns a string that includes all parts of the Name, with missing parts // returns a string that includes all parts of the Name, with missing parts
// replaced with a ("?"). // replaced with a ("?").
func (r Name) GoString() string { func (r Name) GoString() string {
return (Name{ var v Name
host: cmp.Or(r.host, "?"), for i := range r.parts {
namespace: cmp.Or(r.namespace, "?"), v.parts[i] = cmp.Or(r.parts[i], "?")
model: cmp.Or(r.model, "?"), }
tag: cmp.Or(r.tag, "?"), return v.String()
build: cmp.Or(r.build, "?"),
}).String()
} }
// LogValue implements slog.Valuer. // LogValue implements slog.Valuer.
@ -300,18 +289,25 @@ func (r Name) LogValue() slog.Value {
return slog.StringValue(r.GoString()) return slog.StringValue(r.GoString())
} }
// MarshalText implements encoding.TextMarshaler. var bufPool = sync.Pool{
func (r Name) MarshalText() ([]byte, error) { New: func() interface{} {
// unsafeBytes is safe here because we gurantee that the string is return new(bytes.Buffer)
// never used after this function returns. },
//
// TODO: We can remove this if https://github.com/golang/go/issues/62384
// lands.
return unsafeBytes(r.String()), nil
} }
func unsafeBytes(s string) []byte { // MarshalText implements encoding.TextMarshaler.
return *(*[]byte)(unsafe.Pointer(&s)) func (r Name) MarshalText() ([]byte, error) {
b := bufPool.Get().(*bytes.Buffer)
b.Reset()
b.Grow(50) // arbitrarily long enough for most names
defer bufPool.Put(b)
_, err := r.WriteTo(b)
if err != nil {
return nil, err
}
// TODO: We can remove this alloc if/when
// https://github.com/golang/go/issues/62384 lands.
return b.Bytes(), nil
} }
// UnmarshalText implements encoding.TextUnmarshaler. // UnmarshalText implements encoding.TextUnmarshaler.
@ -329,13 +325,13 @@ func unsafeString(b []byte) string {
// Complete reports whether the Name is fully qualified. That is it has a // Complete reports whether the Name is fully qualified. That is it has a
// domain, namespace, name, tag, and build. // domain, namespace, name, tag, and build.
func (r Name) Complete() bool { func (r Name) Complete() bool {
return !slices.Contains(r.Parts(), "") return !slices.Contains(r.parts[:], "")
} }
// CompleteNoBuild is like [Name.Complete] but it does not require the // CompleteNoBuild is like [Name.Complete] but it does not require the
// build part to be present. // build part to be present.
func (r Name) CompleteNoBuild() bool { func (r Name) CompleteNoBuild() bool {
return !slices.Contains(r.Parts()[:4], "") return !slices.Contains(r.parts[:Tag], "")
} }
// EqualFold reports whether r and o are equivalent model names, ignoring // EqualFold reports whether r and o are equivalent model names, ignoring
@ -350,27 +346,23 @@ func (r Name) EqualFold(o Name) bool {
// //
// For simple equality checks, use [Name.EqualFold]. // For simple equality checks, use [Name.EqualFold].
func (r Name) CompareFold(o Name) int { func (r Name) CompareFold(o Name) int {
return cmp.Or( for i := range r.parts {
compareFold(r.host, o.host), if n := compareFold(r.parts[i], o.parts[i]); n != 0 {
compareFold(r.namespace, o.namespace), return n
compareFold(r.model, o.model), }
compareFold(r.tag, o.tag), }
compareFold(r.build, o.build), return 0
)
} }
func compareFold(a, b string) int { func compareFold(a, b string) int {
// fast-path for unequal lengths // fast-path for unequal lengths
if n := cmp.Compare(len(a), len(b)); n != 0 {
return n
}
for i := 0; i < len(a) && i < len(b); i++ { for i := 0; i < len(a) && i < len(b); i++ {
ca, cb := downcase(a[i]), downcase(b[i]) ca, cb := downcase(a[i]), downcase(b[i])
if n := cmp.Compare(ca, cb); n != 0 { if n := cmp.Compare(ca, cb); n != 0 {
return n return n
} }
} }
return 0 return cmp.Compare(len(a), len(b))
} }
func downcase(c byte) byte { func downcase(c byte) byte {
@ -387,13 +379,7 @@ func downcase(c byte) byte {
// //
// The length of the returned slice is always 5. // The length of the returned slice is always 5.
func (r Name) Parts() []string { func (r Name) Parts() []string {
return []string{ return slices.Clone(r.parts[:])
r.host,
r.namespace,
r.model,
r.tag,
r.build,
}
} }
// Parts returns a sequence of the parts of a Name string from most specific // Parts returns a sequence of the parts of a Name string from most specific
@ -492,7 +478,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
func (r Name) Valid() bool { func (r Name) Valid() bool {
// Parts ensures we only have valid parts, so no need to validate // Parts ensures we only have valid parts, so no need to validate
// them here, only check if we have a name or not. // them here, only check if we have a name or not.
return r.model != "" return r.parts[Model] != ""
} }
// isValidPart returns true if given part is valid ascii [a-zA-Z0-9_\.-] // isValidPart returns true if given part is valid ascii [a-zA-Z0-9_\.-]
@ -520,3 +506,10 @@ func isValidByte(kind NamePart, c byte) bool {
} }
return false return false
} }
func sumLens(a []string) (sum int) {
for _, n := range a {
sum += len(n)
}
return
}

View File

@ -11,7 +11,21 @@ import (
"testing" "testing"
) )
var testNames = map[string]Name{ type fields struct {
host, namespace, model, tag, build string
}
func fieldsFromName(p Name) fields {
return fields{
host: p.parts[Host],
namespace: p.parts[Namespace],
model: p.parts[Model],
tag: p.parts[Tag],
build: p.parts[Build],
}
}
var testNames = map[string]fields{
"mistral:latest": {model: "mistral", tag: "latest"}, "mistral:latest": {model: "mistral", tag: "latest"},
"mistral": {model: "mistral"}, "mistral": {model: "mistral"},
"mistral:30B": {model: "mistral", tag: "30B"}, "mistral:30B": {model: "mistral", tag: "30B"},
@ -23,7 +37,7 @@ var testNames = map[string]Name{
"llama2": {model: "llama2"}, "llama2": {model: "llama2"},
"user/model": {namespace: "user", model: "model"}, "user/model": {namespace: "user", model: "model"},
"example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"}, "example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"},
"example.com/ns/mistral:7b+x": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"}, "example.com/ns/mistral:7b+X": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
// preserves case for build // preserves case for build
"x+b": {model: "x", build: "b"}, "x+b": {model: "x", build: "b"},
@ -73,7 +87,7 @@ func TestNameParts(t *testing.T) {
} }
func TestNamePartString(t *testing.T) { func TestNamePartString(t *testing.T) {
if g := NamePart(-1).String(); g != "Unknown" { if g := NamePart(-2).String(); g != "Unknown" {
t.Errorf("Unknown part = %q; want %q", g, "Unknown") t.Errorf("Unknown part = %q; want %q", g, "Unknown")
} }
for kind, name := range kindNames { for kind, name := range kindNames {
@ -83,34 +97,6 @@ func TestNamePartString(t *testing.T) {
} }
} }
func TestPartTooLong(t *testing.T) {
for i := Host; i <= Build; i++ {
t.Run(i.String(), func(t *testing.T) {
var p Name
switch i {
case Host:
p.host = strings.Repeat("a", MaxNamePartLen+1)
case Namespace:
p.namespace = strings.Repeat("a", MaxNamePartLen+1)
case Model:
p.model = strings.Repeat("a", MaxNamePartLen+1)
case Tag:
p.tag = strings.Repeat("a", MaxNamePartLen+1)
case Build:
p.build = strings.Repeat("a", MaxNamePartLen+1)
}
s := strings.Trim(p.String(), "+:/")
if len(s) != MaxNamePartLen+1 {
t.Errorf("len(String()) = %d; want %d", len(s), MaxNamePartLen+1)
t.Logf("String() = %q", s)
}
if ParseName(p.String()).Valid() {
t.Errorf("Valid(%q) = true; want false", p)
}
})
}
}
func TestParseName(t *testing.T) { func TestParseName(t *testing.T) {
for baseName, want := range testNames { for baseName, want := range testNames {
for _, prefix := range []string{"", "https://", "http://"} { for _, prefix := range []string{"", "https://", "http://"} {
@ -119,19 +105,20 @@ func TestParseName(t *testing.T) {
s := prefix + baseName s := prefix + baseName
t.Run(s, func(t *testing.T) { t.Run(s, func(t *testing.T) {
got := ParseName(s) name := ParseName(s)
if !got.EqualFold(want) { got := fieldsFromName(name)
if got != want {
t.Errorf("ParseName(%q) = %q; want %q", s, got, want) t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
} }
// test round-trip // test round-trip
if !ParseName(got.String()).EqualFold(got) { if !ParseName(name.String()).EqualFold(name) {
t.Errorf("String() = %s; want %s", got.String(), baseName) t.Errorf("String() = %s; want %s", name.String(), baseName)
} }
if got.Valid() && got.model == "" { if name.Valid() && name.DisplayModel() == "" {
t.Errorf("Valid() = true; Model() = %q; want non-empty name", got.model) t.Errorf("Valid() = true; Model() = %q; want non-empty name", got.model)
} else if !got.Valid() && got.DisplayModel() != "" { } else if !name.Valid() && name.DisplayModel() != "" {
t.Errorf("Valid() = false; Model() = %q; want empty name", got.model) t.Errorf("Valid() = false; Model() = %q; want empty name", got.model)
} }
}) })
@ -405,7 +392,7 @@ func TestNameTextMarshal(t *testing.T) {
t.Fatal("MarshalText() = 0; want non-zero") t.Fatal("MarshalText() = 0; want non-zero")
} }
}) })
if allocs > 1 { if allocs > 0 {
// TODO: Update when/if this lands: // TODO: Update when/if this lands:
// https://github.com/golang/go/issues/62384 // https://github.com/golang/go/issues/62384
// //
@ -414,6 +401,16 @@ func TestNameTextMarshal(t *testing.T) {
} }
} }
func TestNameStringAllocs(t *testing.T) {
name := ParseName("example.com/ns/mistral:latest+Q4_0")
allocs := testing.AllocsPerRun(1000, func() {
keep(name.String())
})
if allocs > 1 {
t.Errorf("String allocs = %v; want 0", allocs)
}
}
func ExampleFill() { func ExampleFill() {
defaults := ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0") defaults := ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0")
r := Fill(ParseName("mistral"), defaults) r := Fill(ParseName("mistral"), defaults)