From 3803ecb6a699492ec7b82de8e2714bff8fd65e5f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Sun, 30 Jun 2024 10:24:31 -0700 Subject: [PATCH] cmd build context --- cmd/cmd.go | 129 +++++++++++++++++++++++++++++------------------------ 1 file changed, 71 insertions(+), 58 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index c898c7db6..02320d2a4 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -3,6 +3,7 @@ package cmd import ( "archive/zip" "bytes" + "cmp" "context" "crypto/ed25519" "crypto/rand" @@ -11,6 +12,7 @@ import ( "errors" "fmt" "io" + "io/fs" "log" "math" "net" @@ -70,30 +72,24 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - home, err := os.UserHomeDir() - if err != nil { - return err - } - status := "transferring model data" spinner := progress.NewSpinner(status) p.Add(status, spinner) + createCtx, err := cmd.Flags().GetString("context") + if err != nil { + return err + } + + createCtx = cmp.Or(createCtx, filepath.Dir(filename)) + fsys := os.DirFS(createCtx) + for i := range modelfile.Commands { switch modelfile.Commands[i].Name { case "model", "adapter": - path := modelfile.Commands[i].Args - if path == "~" { - path = home - } else if strings.HasPrefix(path, "~/") { - path = filepath.Join(home, path[2:]) - } + p := filepath.Clean(modelfile.Commands[i].Args) - if !filepath.IsAbs(path) { - path = filepath.Join(filepath.Dir(filename), path) - } - - fi, err := os.Stat(path) + fi, err := fs.Stat(fsys, p) if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" { continue } else if err != nil { @@ -103,16 +99,29 @@ func CreateHandler(cmd *cobra.Command, args []string) error { if fi.IsDir() { // this is likely a safetensors or pytorch directory // TODO make this work w/ adapters - tempfile, err := tempZipFiles(path) + sub, err := fs.Sub(fsys, p) if err != nil { return err } - defer os.RemoveAll(tempfile) - path = tempfile + temp, err := os.CreateTemp(createCtx, "*.zip") + if err != nil { + return err + } + defer temp.Close() + defer os.RemoveAll(temp.Name()) + + if err := zipFiles(sub, temp); err != nil { + return err + } + + p, err = filepath.Rel(createCtx, temp.Name()) + if err != nil { + return err + } } - digest, err := createBlob(cmd, client, path) + digest, err := createBlob(cmd, client, fsys, p) if err != nil { return err } @@ -155,42 +164,34 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return nil } -func tempZipFiles(path string) (string, error) { - tempfile, err := os.CreateTemp("", "ollama-tf") - if err != nil { - return "", err - } - defer tempfile.Close() - - detectContentType := func(path string) (string, error) { - f, err := os.Open(path) +func zipFiles(fsys fs.FS, w io.Writer) error { + detectContentType := func(name string) (string, error) { + f, err := fsys.Open(name) if err != nil { return "", err } defer f.Close() - var b bytes.Buffer - b.Grow(512) - - if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) { + bts, err := io.ReadAll(io.LimitReader(f, 512)) + if err != nil { return "", err } - contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";") + contentType, _, _ := strings.Cut(http.DetectContentType(bts), ";") return contentType, nil } glob := func(pattern, contentType string) ([]string, error) { - matches, err := filepath.Glob(pattern) + matches, err := fs.Glob(fsys, pattern) if err != nil { return nil, err } - for _, safetensor := range matches { - if ct, err := detectContentType(safetensor); err != nil { + for _, match := range matches { + if ct, err := detectContentType(match); err != nil { return nil, err } else if ct != contentType { - return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor) + return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match) } } @@ -198,73 +199,73 @@ func tempZipFiles(path string) (string, error) { } var files []string - if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 { + if st, _ := glob("model*.safetensors", "application/octet-stream"); len(st) > 0 { // safetensors files might be unresolved git lfs references; skip if they are // covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors files = append(files, st...) - } else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 { + } else if pt, _ := glob("pytorch_model*.bin", "application/zip"); len(pt) > 0 { // pytorch files might also be unresolved git lfs references; skip if they are // covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin files = append(files, pt...) - } else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 { + } else if pt, _ := glob("consolidated*.pth", "application/zip"); len(pt) > 0 { // pytorch files might also be unresolved git lfs references; skip if they are // covers consolidated.x.pth, consolidated.pth files = append(files, pt...) } else { - return "", errors.New("no safetensors or torch files found") + return errors.New("no safetensors or torch files found") } // add configuration files, json files are detected as text/plain - js, err := glob(filepath.Join(path, "*.json"), "text/plain") + js, err := glob("*.json", "text/plain") if err != nil { - return "", err + return err } files = append(files, js...) - if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 { + if tks, _ := glob("tokenizer.model", "application/octet-stream"); len(tks) > 0 { // add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob // tokenizer.model might be a unresolved git lfs reference; error if it is files = append(files, tks...) - } else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 { + } else if tks, _ := glob("**/tokenizer.model", "text/plain"); len(tks) > 0 { // some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B) files = append(files, tks...) } - zipfile := zip.NewWriter(tempfile) + zipfile := zip.NewWriter(w) defer zipfile.Close() for _, file := range files { - f, err := os.Open(file) + f, err := fsys.Open(file) if err != nil { - return "", err + return err } defer f.Close() fi, err := f.Stat() if err != nil { - return "", err + return err } zfi, err := zip.FileInfoHeader(fi) if err != nil { - return "", err + return err } zf, err := zipfile.CreateHeader(zfi) if err != nil { - return "", err + return err } if _, err := io.Copy(zf, f); err != nil { - return "", err + return err } } - return tempfile.Name(), nil + return nil } -func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) { - bin, err := os.Open(path) +func sha256sum(fsys fs.FS, name string) (string, error) { + bin, err := fsys.Open(name) if err != nil { return "", err } @@ -275,14 +276,25 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er return "", err } - if _, err := bin.Seek(0, io.SeekStart); err != nil { + return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil +} + +func createBlob(cmd *cobra.Command, client *api.Client, fsys fs.FS, name string) (string, error) { + bin, err := fsys.Open(name) + if err != nil { + return "", err + } + defer bin.Close() + + digest, err := sha256sum(fsys, name) + if err != nil { return "", err } - digest := fmt.Sprintf("sha256:%x", hash.Sum(nil)) if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil { return "", err } + return digest, nil } @@ -1226,6 +1238,7 @@ func NewCLI() *cobra.Command { createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile") createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)") + createCmd.Flags().StringP("context", "C", "", "Context for the model") showCmd := &cobra.Command{ Use: "show MODEL",