add progressbar for model pulls

This commit is contained in:
Patrick Devine 2023-07-16 15:18:57 -07:00
parent 95cc9a11db
commit 2e1394e405
4 changed files with 67 additions and 84 deletions

View File

@ -90,9 +90,27 @@ func RunPull(cmd *cobra.Command, args []string) error {
func pull(model string) error {
client := api.NewClient()
var bar *progressbar.ProgressBar
currentLayer := ""
request := api.PullRequest{Name: model}
fn := func(resp api.PullProgress) error {
fmt.Println(resp.Status)
if resp.Digest != currentLayer && resp.Digest != "" {
if currentLayer != "" {
fmt.Println()
}
currentLayer = resp.Digest
layerStr := resp.Digest[7:23] + "..."
bar = progressbar.DefaultBytes(
int64(resp.Total),
"pulling "+layerStr,
)
} else if resp.Digest == currentLayer && resp.Digest != "" {
bar.Set(resp.Completed)
} else {
currentLayer = ""
fmt.Println(resp.Status)
}
return nil
}

View File

@ -5,13 +5,16 @@ import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"github.com/jmorganca/ollama/api"
@ -536,7 +539,8 @@ func PullModel(name, username, password string, fn func(status, digest string, T
for _, layer := range layers {
fn("starting download", layer.Digest, total, completed, float64(completed)/float64(total))
if err := downloadBlob(DefaultRegistry, repoName, layer.Digest, username, password); err != nil {
if err := downloadBlob(DefaultRegistry, repoName, layer.Digest, username, password, fn); err != nil {
fn(fmt.Sprintf("error downloading: %v", err), layer.Digest, 0, 0, 0)
return err
}
completed += layer.Size
@ -717,7 +721,7 @@ func uploadBlob(location string, layer *Layer, username string, password string)
return nil
}
func downloadBlob(registryURL, repoName, digest, username, password string) error {
func downloadBlob(registryURL, repoName, digest string, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
home, err := os.UserHomeDir()
if err != nil {
return err
@ -732,8 +736,22 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro
return nil
}
var size int64
fi, err := os.Stat(fp + "-partial")
switch {
case errors.Is(err, os.ErrNotExist):
// noop, file doesn't exist so create it
case err != nil:
return fmt.Errorf("stat: %w", err)
default:
size = fi.Size()
}
url := fmt.Sprintf("%s/v2/%s/blobs/%s", registryURL, repoName, digest)
headers := map[string]string{}
headers := map[string]string{
"Range": fmt.Sprintf("bytes=%d-", size),
}
resp, err := makeRequest("GET", url, headers, nil, username, password)
if err != nil {
@ -742,10 +760,8 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro
}
defer resp.Body.Close()
// TODO: handle range requests to make this resumable
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
body, _ := ioutil.ReadAll(resp.Body)
return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
}
@ -754,16 +770,34 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro
return fmt.Errorf("make blobs directory: %w", err)
}
out, err := os.Create(fp)
out, err := os.OpenFile(fp+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
if err != nil {
log.Printf("couldn't create %s", fp)
return err
panic(err)
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
if err != nil {
return err
remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
completed := size
total := remaining + completed
for {
fn(fmt.Sprintf("Downloading %s", digest), digest, int(total), int(completed), float64(completed)/float64(total))
if completed >= total {
fmt.Printf("finished downloading\n")
err = os.Rename(fp+"-partial", fp)
if err != nil {
fmt.Printf("error: %v\n", err)
fn(fmt.Sprintf("error renaming file: %v", err), digest, int(total), int(completed), 1)
return err
}
break
}
n, err := io.CopyN(out, resp.Body, 8192)
if err != nil && !errors.Is(err, io.EOF) {
return err
}
completed += n
}
log.Printf("success getting %s\n", digest)
@ -790,7 +824,7 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader,
if len(via) >= 10 {
return fmt.Errorf("too many redirects")
}
log.Printf("redirected to: %s", req.URL)
log.Printf("redirected to: %s\n", req.URL)
return nil
},
}

View File

@ -1,12 +1,6 @@
package server
import (
"fmt"
"os"
"path"
"path/filepath"
"strconv"
"github.com/jmorganca/ollama/api"
)
@ -26,62 +20,3 @@ type Model struct {
License string `json:"license"`
}
func saveModel(model *Model, fn func(total, completed int64)) error {
// this models cache directory is created by the server on startup
client := &http.Client{}
req, err := http.NewRequest("GET", model.URL, nil)
if err != nil {
return fmt.Errorf("failed to download model: %w", err)
}
var size int64
// completed file doesn't exist, check partial file
fi, err := os.Stat(model.TempFile())
switch {
case errors.Is(err, os.ErrNotExist):
// noop, file doesn't exist so create it
case err != nil:
return fmt.Errorf("stat: %w", err)
default:
size = fi.Size()
}
req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size))
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to download model: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return fmt.Errorf("failed to download model: %s", resp.Status)
}
out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
if err != nil {
panic(err)
}
defer out.Close()
remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
completed := size
total := remaining + completed
for {
fn(total, completed)
if completed >= total {
return os.Rename(model.TempFile(), model.FullName())
}
n, err := io.CopyN(out, resp.Body, 8192)
if err != nil && !errors.Is(err, io.EOF) {
return err
}
completed += n
}
}

View File

@ -19,10 +19,6 @@ import (
"github.com/jmorganca/ollama/llama"
)
//go:embed templates/*
var templatesFS embed.FS
var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
func cacheDir() string {
home, err := os.UserHomeDir()
if err != nil {