From 2e1394e40526aba0fa868250bcafaf884a71bf71 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Sun, 16 Jul 2023 15:18:57 -0700 Subject: [PATCH] add progressbar for model pulls --- cmd/cmd.go | 20 ++++++++++++++- server/images.go | 62 ++++++++++++++++++++++++++++++++++----------- server/models.go | 65 ------------------------------------------------ server/routes.go | 4 --- 4 files changed, 67 insertions(+), 84 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 913a48930..99033614f 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -90,9 +90,27 @@ func RunPull(cmd *cobra.Command, args []string) error { func pull(model string) error { client := api.NewClient() + var bar *progressbar.ProgressBar + + currentLayer := "" request := api.PullRequest{Name: model} fn := func(resp api.PullProgress) error { - fmt.Println(resp.Status) + if resp.Digest != currentLayer && resp.Digest != "" { + if currentLayer != "" { + fmt.Println() + } + currentLayer = resp.Digest + layerStr := resp.Digest[7:23] + "..." + bar = progressbar.DefaultBytes( + int64(resp.Total), + "pulling "+layerStr, + ) + } else if resp.Digest == currentLayer && resp.Digest != "" { + bar.Set(resp.Completed) + } else { + currentLayer = "" + fmt.Println(resp.Status) + } return nil } diff --git a/server/images.go b/server/images.go index e10e95302..d92ce095c 100644 --- a/server/images.go +++ b/server/images.go @@ -5,13 +5,16 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "fmt" "io" + "io/ioutil" "log" "net/http" "os" "path" "path/filepath" + "strconv" "strings" "github.com/jmorganca/ollama/api" @@ -536,7 +539,8 @@ func PullModel(name, username, password string, fn func(status, digest string, T for _, layer := range layers { fn("starting download", layer.Digest, total, completed, float64(completed)/float64(total)) - if err := downloadBlob(DefaultRegistry, repoName, layer.Digest, username, password); err != nil { + if err := downloadBlob(DefaultRegistry, repoName, layer.Digest, username, password, fn); err != nil { + fn(fmt.Sprintf("error downloading: %v", err), layer.Digest, 0, 0, 0) return err } completed += layer.Size @@ -717,7 +721,7 @@ func uploadBlob(location string, layer *Layer, username string, password string) return nil } -func downloadBlob(registryURL, repoName, digest, username, password string) error { +func downloadBlob(registryURL, repoName, digest string, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error { home, err := os.UserHomeDir() if err != nil { return err @@ -732,8 +736,22 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro return nil } + var size int64 + + fi, err := os.Stat(fp + "-partial") + switch { + case errors.Is(err, os.ErrNotExist): + // noop, file doesn't exist so create it + case err != nil: + return fmt.Errorf("stat: %w", err) + default: + size = fi.Size() + } + url := fmt.Sprintf("%s/v2/%s/blobs/%s", registryURL, repoName, digest) - headers := map[string]string{} + headers := map[string]string{ + "Range": fmt.Sprintf("bytes=%d-", size), + } resp, err := makeRequest("GET", url, headers, nil, username, password) if err != nil { @@ -742,10 +760,8 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro } defer resp.Body.Close() - // TODO: handle range requests to make this resumable - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { + body, _ := ioutil.ReadAll(resp.Body) return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body)) } @@ -754,16 +770,34 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro return fmt.Errorf("make blobs directory: %w", err) } - out, err := os.Create(fp) + out, err := os.OpenFile(fp+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) if err != nil { - log.Printf("couldn't create %s", fp) - return err + panic(err) } defer out.Close() - _, err = io.Copy(out, resp.Body) - if err != nil { - return err + remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) + completed := size + total := remaining + completed + + for { + fn(fmt.Sprintf("Downloading %s", digest), digest, int(total), int(completed), float64(completed)/float64(total)) + if completed >= total { + fmt.Printf("finished downloading\n") + err = os.Rename(fp+"-partial", fp) + if err != nil { + fmt.Printf("error: %v\n", err) + fn(fmt.Sprintf("error renaming file: %v", err), digest, int(total), int(completed), 1) + return err + } + break + } + + n, err := io.CopyN(out, resp.Body, 8192) + if err != nil && !errors.Is(err, io.EOF) { + return err + } + completed += n } log.Printf("success getting %s\n", digest) @@ -790,7 +824,7 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader, if len(via) >= 10 { return fmt.Errorf("too many redirects") } - log.Printf("redirected to: %s", req.URL) + log.Printf("redirected to: %s\n", req.URL) return nil }, } diff --git a/server/models.go b/server/models.go index a8d175f63..c76c09e22 100644 --- a/server/models.go +++ b/server/models.go @@ -1,12 +1,6 @@ package server import ( - "fmt" - "os" - "path" - "path/filepath" - "strconv" - "github.com/jmorganca/ollama/api" ) @@ -26,62 +20,3 @@ type Model struct { License string `json:"license"` } -func saveModel(model *Model, fn func(total, completed int64)) error { - // this models cache directory is created by the server on startup - - client := &http.Client{} - req, err := http.NewRequest("GET", model.URL, nil) - if err != nil { - return fmt.Errorf("failed to download model: %w", err) - } - - var size int64 - - // completed file doesn't exist, check partial file - fi, err := os.Stat(model.TempFile()) - switch { - case errors.Is(err, os.ErrNotExist): - // noop, file doesn't exist so create it - case err != nil: - return fmt.Errorf("stat: %w", err) - default: - size = fi.Size() - } - - req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size)) - - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("failed to download model: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode >= 400 { - return fmt.Errorf("failed to download model: %s", resp.Status) - } - - out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) - if err != nil { - panic(err) - } - defer out.Close() - - remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) - completed := size - - total := remaining + completed - - for { - fn(total, completed) - if completed >= total { - return os.Rename(model.TempFile(), model.FullName()) - } - - n, err := io.CopyN(out, resp.Body, 8192) - if err != nil && !errors.Is(err, io.EOF) { - return err - } - - completed += n - } -} diff --git a/server/routes.go b/server/routes.go index e59b00b14..fea96df8c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -19,10 +19,6 @@ import ( "github.com/jmorganca/ollama/llama" ) -//go:embed templates/* -var templatesFS embed.FS -var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt")) - func cacheDir() string { home, err := os.UserHomeDir() if err != nil {