From 588a97dbef60a2410cb5201c4b742d10c3761f61 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 4 Apr 2025 18:26:49 -0700 Subject: [PATCH] create blobs in parallel --- cmd/cmd.go | 47 +++++++++++++++++++++++++++++++++-------------- cmd/cmd_test.go | 2 +- 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index befe578d6..b4593c47a 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -22,6 +22,7 @@ import ( "sort" "strconv" "strings" + "sync" "sync/atomic" "syscall" "time" @@ -31,6 +32,7 @@ import ( "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" "golang.org/x/crypto/ssh" + "golang.org/x/sync/errgroup" "golang.org/x/term" "github.com/ollama/ollama/api" @@ -106,7 +108,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } spinner.Stop() - req.Name = args[0] + req.Model = args[0] quantize, _ := cmd.Flags().GetString("quantize") if quantize != "" { req.Quantize = quantize @@ -117,26 +119,43 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - if len(req.Files) > 0 { - fileMap := map[string]string{} - for f, digest := range req.Files { + var mu sync.Mutex + var g errgroup.Group + g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1)) + // copy files since we'll be modifying the map + temp := req.Files + req.Files = make(map[string]string, len(temp)) + for f, digest := range temp { + g.Go(func() error { if _, err := createBlob(cmd, client, f, digest, p); err != nil { return err } - fileMap[filepath.Base(f)] = digest - } - req.Files = fileMap + + mu.Lock() + req.Files[filepath.Base(f)] = digest + mu.Unlock() + return nil + }) } - if len(req.Adapters) > 0 { - fileMap := map[string]string{} - for f, digest := range req.Adapters { + // copy files since we'll be modifying the map + temp = req.Adapters + req.Adapters = make(map[string]string, len(temp)) + for f, digest := range temp { + g.Go(func() error { if _, err := createBlob(cmd, client, f, digest, p); err != nil { return err } - fileMap[filepath.Base(f)] = digest - } - req.Adapters = fileMap + + mu.Lock() + req.Adapters[filepath.Base(f)] = digest + mu.Unlock() + return nil + }) + } + + if err := g.Wait(); err != nil { + return err } bars := make(map[string]*progress.Bar) @@ -213,7 +232,7 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, digest stri } }() - if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil { + if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil { return "", err } return digest, nil diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 367a35b6b..1cd6ddb40 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -690,7 +690,7 @@ func TestCreateHandler(t *testing.T) { return } - if req.Name != "test-model" { + if req.Model != "test-model" { t.Errorf("expected model name 'test-model', got %s", req.Name) }