server/.../safetensors: fix offsets and include all model parts (#9427)
Also, require the -as flag to be set when importing a model. This prevents the confusing error message "invalid name". Also, allow short names to be used when importing a model and auto-complete the name with the default mask.
This commit is contained in:
parent
b42aba40ed
commit
eed11ded30
@ -147,14 +147,23 @@ func (e *Error) UnmarshalJSON(b []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultName = func() names.Name {
|
const DefaultMask = "registry.ollama.ai/library/_:latest"
|
||||||
n := names.Parse("registry.ollama.ai/library/_:latest")
|
|
||||||
|
var defaultMask = func() names.Name {
|
||||||
|
n := names.Parse(DefaultMask)
|
||||||
if !n.IsFullyQualified() {
|
if !n.IsFullyQualified() {
|
||||||
panic("default name is not fully qualified")
|
panic("default mask is not fully qualified")
|
||||||
}
|
}
|
||||||
return n
|
return n
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// CompleteName returns a fully qualified name by merging the given name with
|
||||||
|
// the default mask. If the name is already fully qualified, it is returned
|
||||||
|
// unchanged.
|
||||||
|
func CompleteName(name string) string {
|
||||||
|
return names.Merge(names.Parse(name), defaultMask).String()
|
||||||
|
}
|
||||||
|
|
||||||
// Registry is a client for performing push and pull operations against an
|
// Registry is a client for performing push and pull operations against an
|
||||||
// Ollama registry.
|
// Ollama registry.
|
||||||
type Registry struct {
|
type Registry struct {
|
||||||
@ -249,7 +258,7 @@ type PushParams struct {
|
|||||||
//
|
//
|
||||||
// The scheme is returned as provided by [names.ParseExtended].
|
// The scheme is returned as provided by [names.ParseExtended].
|
||||||
func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) {
|
func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) {
|
||||||
maskName := defaultName
|
maskName := defaultMask
|
||||||
if mask != "" {
|
if mask != "" {
|
||||||
maskName = names.Parse(mask)
|
maskName = names.Parse(mask)
|
||||||
if !maskName.IsFullyQualified() {
|
if !maskName.IsFullyQualified() {
|
||||||
|
@ -86,6 +86,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
endOfHeader := 8 + headerSize // 8 bytes for header size plus the header itself
|
||||||
|
|
||||||
// TODO(bmizerany): do something with metadata? This could be another
|
// TODO(bmizerany): do something with metadata? This could be another
|
||||||
// header read if needed. We also need to figure out if the metadata is
|
// header read if needed. We also need to figure out if the metadata is
|
||||||
// present in only one .safetensors file or if each file may have their
|
// present in only one .safetensors file or if each file may have their
|
||||||
@ -95,7 +97,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
|
|||||||
|
|
||||||
tt := make([]*Tensor, 0, len(raws))
|
tt := make([]*Tensor, 0, len(raws))
|
||||||
for name, raw := range raws {
|
for name, raw := range raws {
|
||||||
if !strings.HasPrefix(name, "model.layer") {
|
if name == "__metadata__" {
|
||||||
|
// TODO(bmizerany): do something with metadata?
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var v struct {
|
var v struct {
|
||||||
@ -112,7 +115,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
|
|||||||
|
|
||||||
// TODO(bmizerany): after collecting, validate all offests make
|
// TODO(bmizerany): after collecting, validate all offests make
|
||||||
// tensors contiguous?
|
// tensors contiguous?
|
||||||
begin, end := v.Offsets[0], v.Offsets[1]
|
begin := endOfHeader + v.Offsets[0]
|
||||||
|
end := endOfHeader + v.Offsets[1]
|
||||||
if err := checkBeginEnd(finfo.Size(), begin, end); err != nil {
|
if err := checkBeginEnd(finfo.Size(), begin, end); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -228,6 +228,10 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
|
|||||||
flag.PrintDefaults()
|
flag.PrintDefaults()
|
||||||
}
|
}
|
||||||
flag.Parse(args)
|
flag.Parse(args)
|
||||||
|
if *flagAs == "" {
|
||||||
|
return fmt.Errorf("missing -as flag")
|
||||||
|
}
|
||||||
|
as := ollama.CompleteName(*flagAs)
|
||||||
|
|
||||||
dir := cmp.Or(flag.Arg(0), ".")
|
dir := cmp.Or(flag.Arg(0), ".")
|
||||||
fmt.Fprintf(os.Stderr, "Reading %s\n", dir)
|
fmt.Fprintf(os.Stderr, "Reading %s\n", dir)
|
||||||
@ -311,7 +315,7 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return c.Link(*flagAs, d)
|
return c.Link(as, d)
|
||||||
}()
|
}()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -340,6 +344,8 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
|
|||||||
writeProgress()
|
writeProgress()
|
||||||
case err := <-done:
|
case err := <-done:
|
||||||
writeProgress()
|
writeProgress()
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("Successfully imported", as)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user