This commit is contained in:
Josh Yan 2024-07-05 13:23:15 -07:00
parent 784958a1cb
commit b48420b74b
2 changed files with 21 additions and 10 deletions

View File

@ -78,6 +78,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
status := "starting model create..." status := "starting model create..."
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
defer p.Stop()
for i := range modelfile.Commands { for i := range modelfile.Commands {
switch modelfile.Commands[i].Name { switch modelfile.Commands[i].Name {
@ -112,10 +115,12 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
path = tempfile path = tempfile
} }
digest, err := createBlob(cmd, client, path) // spinner.Stop()
digest, err := createBlob(cmd, client, path, spinner)
if err != nil { if err != nil {
return err return err
} }
spinner.SetMessage("transferring model data 100%")
modelfile.Commands[i].Args = "@" + digest modelfile.Commands[i].Args = "@" + digest
} }
@ -261,7 +266,7 @@ func tempZipFiles(path string) (string, error) {
var ErrBlobExists = errors.New("blob exists") var ErrBlobExists = errors.New("blob exists")
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) { func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
bin, err := os.Open(path) bin, err := os.Open(path)
if err != nil { if err != nil {
return "", err return "", err
@ -286,22 +291,26 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
var pw progressWriter var pw progressWriter
// Create a progress bar and start a goroutine to update it // Create a progress bar and start a goroutine to update it
p := progress.NewProgress(os.Stderr) // JK Let's use a percetage
bar := progress.NewBar("transferring model data...", fileSize, 0)
p.Add("transferring model data", bar) //bar := progress.NewBar("transferring model data...", fileSize, 0)
//p.Add("transferring model data", bar)
status := "transferring model data 0%"
spinner.SetMessage(status)
ticker := time.NewTicker(60 * time.Millisecond) ticker := time.NewTicker(60 * time.Millisecond)
done := make(chan struct{}) done := make(chan struct{})
defer p.Stop() defer close(done)
go func() { go func() {
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
bar.Set(pw.n) spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n/fileSize)))
case <-done: case <-done:
bar.Set(fileSize) spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", 100))
return return
} }
} }
@ -339,8 +348,6 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil { if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err return "", err
} }
bar.Set(fileSize)
close(done)
return digest, nil return digest, nil
} }

View File

@ -31,6 +31,10 @@ func NewSpinner(message string) *Spinner {
return s return s
} }
func (s *Spinner) SetMessage(message string) {
s.message = message
}
func (s *Spinner) String() string { func (s *Spinner) String() string {
var sb strings.Builder var sb strings.Builder
if len(s.message) > 0 { if len(s.message) > 0 {