From ae65cc8dea60a5f09ed0a2a8a8f19d3ca439094d Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Wed, 3 Jul 2024 11:22:23 -0700 Subject: [PATCH] progress --- cmd/cmd.go | 45 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index d99a2b4ae..ad9b2023d 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -77,7 +77,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - status := "transferring model data" + status := "" spinner := progress.NewSpinner(status) p.Add(status, spinner) @@ -113,7 +113,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { path = tempfile } - + spinner.Stop() digest, err := createBlob(cmd, client, path) if err != nil { return err @@ -274,6 +274,13 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er } defer bin.Close() + // Get file info to retrieve the size + fileInfo, err := bin.Stat() + if err != nil { + return "", err + } + fileSize := fileInfo.Size() + hash := sha256.New() if _, err := io.Copy(hash, bin); err != nil { return "", err @@ -283,6 +290,29 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er return "", err } + var pw progressWriter + // Create a progress bar and start a goroutine to update it + p := progress.NewProgress(os.Stderr) + bar := progress.NewBar("transferring model data...", fileSize, 0) + p.Add("", bar) + + ticker := time.NewTicker(60 * time.Millisecond) + done := make(chan struct{}) + defer p.Stop() + + go func() { + defer ticker.Stop() + for { + select { + case <-ticker.C: + bar.Set(pw.n) + case <-done: + bar.Set(fileSize) + return + } + } + }() + digest := fmt.Sprintf("sha256:%x", hash.Sum(nil)) // We check if we can find the models directory locally @@ -312,12 +342,21 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er } // If at any point copying the blob over locally fails, we default to the copy through the server - if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil { + if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil { return "", err } return digest, nil } +type progressWriter struct { + n int64 +} + +func (w *progressWriter) Write(p []byte) (n int, err error) { + w.n += int64(len(p)) + return len(p), nil +} + func getLocalPath(ctx context.Context, digest string) (string, error) { ollamaHost := envconfig.Host