This commit is contained in:
Josh Yan 2024-07-03 11:22:23 -07:00
parent a037528bba
commit ae65cc8dea

View File

@ -77,7 +77,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
status := "transferring model data"
status := ""
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
@ -113,7 +113,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
path = tempfile
}
spinner.Stop()
digest, err := createBlob(cmd, client, path)
if err != nil {
return err
@ -274,6 +274,13 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
}
defer bin.Close()
// Get file info to retrieve the size
fileInfo, err := bin.Stat()
if err != nil {
return "", err
}
fileSize := fileInfo.Size()
hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil {
return "", err
@ -283,6 +290,29 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
return "", err
}
var pw progressWriter
// Create a progress bar and start a goroutine to update it
p := progress.NewProgress(os.Stderr)
bar := progress.NewBar("transferring model data...", fileSize, 0)
p.Add("", bar)
ticker := time.NewTicker(60 * time.Millisecond)
done := make(chan struct{})
defer p.Stop()
go func() {
defer ticker.Stop()
for {
select {
case <-ticker.C:
bar.Set(pw.n)
case <-done:
bar.Set(fileSize)
return
}
}
}()
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
// We check if we can find the models directory locally
@ -312,12 +342,21 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
}
// If at any point copying the blob over locally fails, we default to the copy through the server
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err
}
return digest, nil
}
type progressWriter struct {
n int64
}
func (w *progressWriter) Write(p []byte) (n int, err error) {
w.n += int64(len(p))
return len(p), nil
}
func getLocalPath(ctx context.Context, digest string) (string, error) {
ollamaHost := envconfig.Host