x/model: make equality checks case-insensitive
This commit is contained in:
parent
92b7e40fde
commit
bfe89d6fa0
@ -2,9 +2,12 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
"cmp"
|
||||||
|
"hash/maphash"
|
||||||
"iter"
|
"iter"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/types/structs"
|
||||||
)
|
)
|
||||||
|
|
||||||
const MaxNameLength = 255
|
const MaxNameLength = 255
|
||||||
@ -36,6 +39,8 @@ var kindNames = map[NamePart]string{
|
|||||||
//
|
//
|
||||||
// Users or Name must check Valid before using it.
|
// Users or Name must check Valid before using it.
|
||||||
type Name struct {
|
type Name struct {
|
||||||
|
_ structs.Incomparable
|
||||||
|
|
||||||
host string
|
host string
|
||||||
namespace string
|
namespace string
|
||||||
model string
|
model string
|
||||||
@ -43,6 +48,27 @@ type Name struct {
|
|||||||
build string
|
build string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var mapHashSeed = maphash.MakeSeed()
|
||||||
|
|
||||||
|
// MapHash returns a case insensitive hash for use in maps and equality
|
||||||
|
// checks. For a convienent way to compare names, use [EqualFold].
|
||||||
|
func (r Name) MapHash() uint64 {
|
||||||
|
// correctly hash the parts with case insensitive comparison
|
||||||
|
var h maphash.Hash
|
||||||
|
h.SetSeed(mapHashSeed)
|
||||||
|
for _, part := range r.Parts() {
|
||||||
|
// downcase the part for hashing
|
||||||
|
for i := range part {
|
||||||
|
c := part[i]
|
||||||
|
if c >= 'A' && c <= 'Z' {
|
||||||
|
c = c - 'A' + 'a'
|
||||||
|
}
|
||||||
|
h.WriteByte(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return h.Sum64()
|
||||||
|
}
|
||||||
|
|
||||||
// Format returns a string representation of the ref with the given
|
// Format returns a string representation of the ref with the given
|
||||||
// concreteness. If a part is missing, it is replaced with a loud
|
// concreteness. If a part is missing, it is replaced with a loud
|
||||||
// placeholder.
|
// placeholder.
|
||||||
@ -135,6 +161,10 @@ func (r Name) Model() string { return r.model }
|
|||||||
func (r Name) Tag() string { return r.tag }
|
func (r Name) Tag() string { return r.tag }
|
||||||
func (r Name) Build() string { return r.build }
|
func (r Name) Build() string { return r.build }
|
||||||
|
|
||||||
|
func (r Name) EqualFold(o Name) bool {
|
||||||
|
return r.MapHash() == o.MapHash()
|
||||||
|
}
|
||||||
|
|
||||||
// ParseName parses s into a Name. The input string must be a valid form of
|
// ParseName parses s into a Name. The input string must be a valid form of
|
||||||
// a model name in the form:
|
// a model name in the form:
|
||||||
//
|
//
|
||||||
|
@ -49,21 +49,21 @@ func TestNameParts(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestParseName(t *testing.T) {
|
func TestParseName(t *testing.T) {
|
||||||
for s, want := range testNames {
|
for baseName, want := range testNames {
|
||||||
for _, prefix := range []string{"", "https://", "http://"} {
|
for _, prefix := range []string{"", "https://", "http://"} {
|
||||||
// We should get the same results with or without the
|
// We should get the same results with or without the
|
||||||
// http(s) prefixes
|
// http(s) prefixes
|
||||||
s := prefix + s
|
s := prefix + baseName
|
||||||
|
|
||||||
t.Run(s, func(t *testing.T) {
|
t.Run(s, func(t *testing.T) {
|
||||||
got := ParseName(s)
|
got := ParseName(s)
|
||||||
if got != want {
|
if !got.EqualFold(want) {
|
||||||
t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
|
t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
|
||||||
}
|
}
|
||||||
|
|
||||||
// test round-trip
|
// test round-trip
|
||||||
if ParseName(got.String()) != got {
|
if !ParseName(got.String()).EqualFold(got) {
|
||||||
t.Errorf("String() = %s; want %s", got.String(), s)
|
t.Errorf("String() = %s; want %s", got.String(), baseName)
|
||||||
}
|
}
|
||||||
|
|
||||||
if got.Valid() && got.Model() == "" {
|
if got.Valid() && got.Model() == "" {
|
||||||
@ -190,7 +190,7 @@ func FuzzParseName(f *testing.F) {
|
|||||||
f.Fuzz(func(t *testing.T, s string) {
|
f.Fuzz(func(t *testing.T, s string) {
|
||||||
r0 := ParseName(s)
|
r0 := ParseName(s)
|
||||||
if !r0.Valid() {
|
if !r0.Valid() {
|
||||||
if r0 != (Name{}) {
|
if !r0.EqualFold(Name{}) {
|
||||||
t.Errorf("expected invalid path to be zero value; got %#v", r0)
|
t.Errorf("expected invalid path to be zero value; got %#v", r0)
|
||||||
}
|
}
|
||||||
t.Skipf("invalid path: %q", s)
|
t.Skipf("invalid path: %q", s)
|
||||||
@ -207,7 +207,7 @@ func FuzzParseName(f *testing.F) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r1 := ParseName(r0.String())
|
r1 := ParseName(r0.String())
|
||||||
if r0 != r1 {
|
if !r0.EqualFold(r1) {
|
||||||
t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
|
t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user