From a4516117614467325bf5753ec2e0e2c44dc24d32 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Sat, 6 Jul 2024 18:19:56 -0700 Subject: [PATCH] add adapter conversion for modelfiles --- convert/convert.go | 11 +++++++++++ convert/convert_adapter.go | 5 +++++ convert/reader_npz.go | 14 ++++++++++++++ server/model.go | 33 +++++++++++++++++++++++++++------ 4 files changed, 57 insertions(+), 6 deletions(-) diff --git a/convert/convert.go b/convert/convert.go index 4ad64d721..194f4ee6f 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -63,6 +63,17 @@ type Converter interface { writeFile(io.WriteSeeker, llm.KV, []*llm.Tensor) error } +func ConvertAdapter(d string, ws io.WriteSeeker) error { + c := &adapter{} + + ts, err := parseNPZ(d) + if err != nil { + return err + } + + return c.writeFile(ws, c.KV(nil), c.Tensors(ts)) +} + func Convert(d string, ws io.WriteSeeker) error { f, err := os.Open(filepath.Join(d, "config.json")) if err != nil { diff --git a/convert/convert_adapter.go b/convert/convert_adapter.go index f74829991..77ff8bc51 100644 --- a/convert/convert_adapter.go +++ b/convert/convert_adapter.go @@ -1,6 +1,7 @@ package convert import ( + "io" "strings" "github.com/ollama/ollama/llm" @@ -12,6 +13,10 @@ type adapter struct { var _ Converter = (*adapter)(nil) +func (p *adapter) writeFile(ws io.WriteSeeker, kv llm.KV, ts []*llm.Tensor) error { + return llm.WriteGGLA(ws, kv, ts) +} + func (p *adapter) KV(t *Tokenizer) llm.KV { // todo - need a way to pass these in kv := llm.KV{ diff --git a/convert/reader_npz.go b/convert/reader_npz.go index f2f225ae2..b9cf83eaf 100644 --- a/convert/reader_npz.go +++ b/convert/reader_npz.go @@ -18,6 +18,20 @@ type adapterTensor struct { *tensorBase } +func DetectNPZ(fn string) (bool, error) { + f, err := npz.Open(fn) + if err != nil { + return false, err + } + defer f.Close() + + if len(f.Keys()) > 0 && strings.HasSuffix(f.Keys()[0], ".npy") { + return true, nil + } + + return false, nil +} + func parseNPZ(fn string) ([]Tensor, error) { var ts []Tensor diff --git a/server/model.go b/server/model.go index a10bdd5dd..66c55d60c 100644 --- a/server/model.go +++ b/server/model.go @@ -129,14 +129,24 @@ func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse)) } func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) { + layerType := "application/vnd.ollama.image.model" + convertAdapter, err := convert.DetectNPZ(file.Name()) + if err != nil { + return nil, err + } + tempDir, err := os.MkdirTemp(filepath.Dir(file.Name()), "") if err != nil { return nil, err } defer os.RemoveAll(tempDir) - if err := extractFromZipFile(tempDir, file, fn); err != nil { - return nil, err + if !convertAdapter { + if err := extractFromZipFile(tempDir, file, fn); err != nil { + return nil, err + } + } else { + layerType = "application/vnd.ollama.image.adapter" } fn(api.ProgressResponse{Status: "converting model"}) @@ -150,15 +160,22 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a defer temp.Close() defer os.Remove(temp.Name()) - if err := convert.Convert(tempDir, temp); err != nil { - return nil, err + if convertAdapter { + slog.Info("convert adapter") + if err := convert.ConvertAdapter(file.Name(), temp); err != nil { + return nil, err + } + } else { + if err := convert.Convert(tempDir, temp); err != nil { + return nil, err + } } if _, err := temp.Seek(0, io.SeekStart); err != nil { return nil, err } - layer, err := NewLayer(temp, "application/vnd.ollama.image.model") + layer, err := NewLayer(temp, layerType) if err != nil { return nil, err } @@ -177,7 +194,11 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a layers = append(layers, &layerGGML{layer, ggml}) intermediateBlobs[digest] = layer.Digest - return detectChatTemplate(layers) + if !convertAdapter { + return detectChatTemplate(layers) + } + + return layers, nil } func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {