x/model: add MarshalText and UnmarshalText to Name
This commit is contained in:
parent
e201627c63
commit
45d8d22785
@ -8,6 +8,8 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/types/structs"
|
"github.com/ollama/ollama/x/types/structs"
|
||||||
)
|
)
|
||||||
@ -233,6 +235,12 @@ func (r Name) DisplayLong() string {
|
|||||||
}).String()
|
}).String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var builderPool = sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return &strings.Builder{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// String returns the fullest possible display string in form:
|
// String returns the fullest possible display string in form:
|
||||||
//
|
//
|
||||||
// <host>/<namespace>/<model>:<tag>+<build>
|
// <host>/<namespace>/<model>:<tag>+<build>
|
||||||
@ -242,7 +250,17 @@ func (r Name) DisplayLong() string {
|
|||||||
// For the fullest possible display string without the build, use
|
// For the fullest possible display string without the build, use
|
||||||
// [Name.DisplayFullest].
|
// [Name.DisplayFullest].
|
||||||
func (r Name) String() string {
|
func (r Name) String() string {
|
||||||
var b strings.Builder
|
b := builderPool.Get().(*strings.Builder)
|
||||||
|
b.Reset()
|
||||||
|
defer builderPool.Put(b)
|
||||||
|
b.Grow(0 +
|
||||||
|
len(r.host) +
|
||||||
|
len(r.namespace) +
|
||||||
|
len(r.model) +
|
||||||
|
len(r.tag) +
|
||||||
|
len(r.build) +
|
||||||
|
4, // 4 possible separators
|
||||||
|
)
|
||||||
if r.host != "" {
|
if r.host != "" {
|
||||||
b.WriteString(r.host)
|
b.WriteString(r.host)
|
||||||
b.WriteString("/")
|
b.WriteString("/")
|
||||||
@ -282,6 +300,32 @@ func (r Name) LogValue() slog.Value {
|
|||||||
return slog.StringValue(r.GoString())
|
return slog.StringValue(r.GoString())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalText implements encoding.TextMarshaler.
|
||||||
|
func (r Name) MarshalText() ([]byte, error) {
|
||||||
|
// unsafeBytes is safe here because we gurantee that the string is
|
||||||
|
// never used after this function returns.
|
||||||
|
//
|
||||||
|
// TODO: We can remove this if https://github.com/golang/go/issues/62384
|
||||||
|
// lands.
|
||||||
|
return unsafeBytes(r.String()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func unsafeBytes(s string) []byte {
|
||||||
|
return *(*[]byte)(unsafe.Pointer(&s))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalText implements encoding.TextUnmarshaler.
|
||||||
|
func (r *Name) UnmarshalText(text []byte) error {
|
||||||
|
// unsafeString is safe here because the contract of UnmarshalText
|
||||||
|
// that text belongs to us for the duration of the call.
|
||||||
|
*r = ParseName(unsafeString(text))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func unsafeString(b []byte) string {
|
||||||
|
return *(*string)(unsafe.Pointer(&b))
|
||||||
|
}
|
||||||
|
|
||||||
// Complete reports whether the Name is fully qualified. That is it has a
|
// Complete reports whether the Name is fully qualified. That is it has a
|
||||||
// domain, namespace, name, tag, and build.
|
// domain, namespace, name, tag, and build.
|
||||||
func (r Name) Complete() bool {
|
func (r Name) Complete() bool {
|
||||||
|
@ -3,6 +3,7 @@ package model
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"cmp"
|
"cmp"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"slices"
|
"slices"
|
||||||
@ -352,6 +353,67 @@ func TestFill(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNameTextMarshal(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
in string
|
||||||
|
want string
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{"example.com/mistral:latest+Q4_0", "", nil},
|
||||||
|
{"mistral:latest+Q4_0", "mistral:latest+Q4_0", nil},
|
||||||
|
{"mistral:latest", "mistral:latest", nil},
|
||||||
|
{"mistral", "mistral", nil},
|
||||||
|
{"mistral:7b", "mistral:7b", nil},
|
||||||
|
{"example.com/library/mistral:latest+Q4_0", "example.com/library/mistral:latest+Q4_0", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.in, func(t *testing.T) {
|
||||||
|
p := ParseName(tt.in)
|
||||||
|
got, err := p.MarshalText()
|
||||||
|
if !errors.Is(err, tt.wantErr) {
|
||||||
|
t.Fatalf("MarshalText() error = %v; want %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if string(got) != tt.want {
|
||||||
|
t.Errorf("MarshalText() = %q; want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
|
||||||
|
var r Name
|
||||||
|
if err := r.UnmarshalText(got); err != nil {
|
||||||
|
t.Fatalf("UnmarshalText() error = %v; want nil", err)
|
||||||
|
}
|
||||||
|
if !r.EqualFold(p) {
|
||||||
|
t.Errorf("UnmarshalText() = %q; want %q", r, p)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
var data []byte
|
||||||
|
name := ParseName("example.com/ns/mistral:latest+Q4_0")
|
||||||
|
if !name.Complete() {
|
||||||
|
// sanity check
|
||||||
|
t.Fatal("name is not complete")
|
||||||
|
}
|
||||||
|
|
||||||
|
allocs := testing.AllocsPerRun(1000, func() {
|
||||||
|
var err error
|
||||||
|
data, err = name.MarshalText()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(data) == 0 {
|
||||||
|
t.Fatal("MarshalText() = 0; want non-zero")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if allocs > 1 {
|
||||||
|
// TODO: Update when/if this lands:
|
||||||
|
// https://github.com/golang/go/issues/62384
|
||||||
|
//
|
||||||
|
// Currently, the best we can do is 1 alloc.
|
||||||
|
t.Errorf("MarshalText allocs = %v; want <= 1", allocs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func ExampleFill() {
|
func ExampleFill() {
|
||||||
defaults := ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0")
|
defaults := ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0")
|
||||||
r := Fill(ParseName("mistral"), defaults)
|
r := Fill(ParseName("mistral"), defaults)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user