add progressbar for model pulls
This commit is contained in:
parent
95cc9a11db
commit
2e1394e405
20
cmd/cmd.go
20
cmd/cmd.go
@ -90,9 +90,27 @@ func RunPull(cmd *cobra.Command, args []string) error {
|
||||
func pull(model string) error {
|
||||
client := api.NewClient()
|
||||
|
||||
var bar *progressbar.ProgressBar
|
||||
|
||||
currentLayer := ""
|
||||
request := api.PullRequest{Name: model}
|
||||
fn := func(resp api.PullProgress) error {
|
||||
fmt.Println(resp.Status)
|
||||
if resp.Digest != currentLayer && resp.Digest != "" {
|
||||
if currentLayer != "" {
|
||||
fmt.Println()
|
||||
}
|
||||
currentLayer = resp.Digest
|
||||
layerStr := resp.Digest[7:23] + "..."
|
||||
bar = progressbar.DefaultBytes(
|
||||
int64(resp.Total),
|
||||
"pulling "+layerStr,
|
||||
)
|
||||
} else if resp.Digest == currentLayer && resp.Digest != "" {
|
||||
bar.Set(resp.Completed)
|
||||
} else {
|
||||
currentLayer = ""
|
||||
fmt.Println(resp.Status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -5,13 +5,16 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
@ -536,7 +539,8 @@ func PullModel(name, username, password string, fn func(status, digest string, T
|
||||
|
||||
for _, layer := range layers {
|
||||
fn("starting download", layer.Digest, total, completed, float64(completed)/float64(total))
|
||||
if err := downloadBlob(DefaultRegistry, repoName, layer.Digest, username, password); err != nil {
|
||||
if err := downloadBlob(DefaultRegistry, repoName, layer.Digest, username, password, fn); err != nil {
|
||||
fn(fmt.Sprintf("error downloading: %v", err), layer.Digest, 0, 0, 0)
|
||||
return err
|
||||
}
|
||||
completed += layer.Size
|
||||
@ -717,7 +721,7 @@ func uploadBlob(location string, layer *Layer, username string, password string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func downloadBlob(registryURL, repoName, digest, username, password string) error {
|
||||
func downloadBlob(registryURL, repoName, digest string, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -732,8 +736,22 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
var size int64
|
||||
|
||||
fi, err := os.Stat(fp + "-partial")
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
// noop, file doesn't exist so create it
|
||||
case err != nil:
|
||||
return fmt.Errorf("stat: %w", err)
|
||||
default:
|
||||
size = fi.Size()
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v2/%s/blobs/%s", registryURL, repoName, digest)
|
||||
headers := map[string]string{}
|
||||
headers := map[string]string{
|
||||
"Range": fmt.Sprintf("bytes=%d-", size),
|
||||
}
|
||||
|
||||
resp, err := makeRequest("GET", url, headers, nil, username, password)
|
||||
if err != nil {
|
||||
@ -742,10 +760,8 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// TODO: handle range requests to make this resumable
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
@ -754,16 +770,34 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro
|
||||
return fmt.Errorf("make blobs directory: %w", err)
|
||||
}
|
||||
|
||||
out, err := os.Create(fp)
|
||||
out, err := os.OpenFile(fp+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
log.Printf("couldn't create %s", fp)
|
||||
return err
|
||||
panic(err)
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
_, err = io.Copy(out, resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
||||
completed := size
|
||||
total := remaining + completed
|
||||
|
||||
for {
|
||||
fn(fmt.Sprintf("Downloading %s", digest), digest, int(total), int(completed), float64(completed)/float64(total))
|
||||
if completed >= total {
|
||||
fmt.Printf("finished downloading\n")
|
||||
err = os.Rename(fp+"-partial", fp)
|
||||
if err != nil {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
fn(fmt.Sprintf("error renaming file: %v", err), digest, int(total), int(completed), 1)
|
||||
return err
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
n, err := io.CopyN(out, resp.Body, 8192)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return err
|
||||
}
|
||||
completed += n
|
||||
}
|
||||
|
||||
log.Printf("success getting %s\n", digest)
|
||||
@ -790,7 +824,7 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader,
|
||||
if len(via) >= 10 {
|
||||
return fmt.Errorf("too many redirects")
|
||||
}
|
||||
log.Printf("redirected to: %s", req.URL)
|
||||
log.Printf("redirected to: %s\n", req.URL)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
@ -1,12 +1,6 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
@ -26,62 +20,3 @@ type Model struct {
|
||||
License string `json:"license"`
|
||||
}
|
||||
|
||||
func saveModel(model *Model, fn func(total, completed int64)) error {
|
||||
// this models cache directory is created by the server on startup
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest("GET", model.URL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download model: %w", err)
|
||||
}
|
||||
|
||||
var size int64
|
||||
|
||||
// completed file doesn't exist, check partial file
|
||||
fi, err := os.Stat(model.TempFile())
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
// noop, file doesn't exist so create it
|
||||
case err != nil:
|
||||
return fmt.Errorf("stat: %w", err)
|
||||
default:
|
||||
size = fi.Size()
|
||||
}
|
||||
|
||||
req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size))
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download model: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("failed to download model: %s", resp.Status)
|
||||
}
|
||||
|
||||
out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
||||
completed := size
|
||||
|
||||
total := remaining + completed
|
||||
|
||||
for {
|
||||
fn(total, completed)
|
||||
if completed >= total {
|
||||
return os.Rename(model.TempFile(), model.FullName())
|
||||
}
|
||||
|
||||
n, err := io.CopyN(out, resp.Body, 8192)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return err
|
||||
}
|
||||
|
||||
completed += n
|
||||
}
|
||||
}
|
||||
|
@ -19,10 +19,6 @@ import (
|
||||
"github.com/jmorganca/ollama/llama"
|
||||
)
|
||||
|
||||
//go:embed templates/*
|
||||
var templatesFS embed.FS
|
||||
var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
|
||||
|
||||
func cacheDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
|
Loading…
x
Reference in New Issue
Block a user