create blobs in parallel

This commit is contained in:
Michael Yang 2025-04-04 18:26:49 -07:00
parent e9e5f61c45
commit 588a97dbef
2 changed files with 34 additions and 15 deletions

View File

@ -22,6 +22,7 @@ import (
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
@ -31,6 +32,7 @@ import (
"github.com/olekukonko/tablewriter"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
"golang.org/x/term"
"github.com/ollama/ollama/api"
@ -106,7 +108,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
spinner.Stop()
req.Name = args[0]
req.Model = args[0]
quantize, _ := cmd.Flags().GetString("quantize")
if quantize != "" {
req.Quantize = quantize
@ -117,26 +119,43 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
if len(req.Files) > 0 {
fileMap := map[string]string{}
for f, digest := range req.Files {
var mu sync.Mutex
var g errgroup.Group
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
// copy files since we'll be modifying the map
temp := req.Files
req.Files = make(map[string]string, len(temp))
for f, digest := range temp {
g.Go(func() error {
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err
}
fileMap[filepath.Base(f)] = digest
}
req.Files = fileMap
mu.Lock()
req.Files[filepath.Base(f)] = digest
mu.Unlock()
return nil
})
}
if len(req.Adapters) > 0 {
fileMap := map[string]string{}
for f, digest := range req.Adapters {
// copy files since we'll be modifying the map
temp = req.Adapters
req.Adapters = make(map[string]string, len(temp))
for f, digest := range temp {
g.Go(func() error {
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err
}
fileMap[filepath.Base(f)] = digest
}
req.Adapters = fileMap
mu.Lock()
req.Adapters[filepath.Base(f)] = digest
mu.Unlock()
return nil
})
}
if err := g.Wait(); err != nil {
return err
}
bars := make(map[string]*progress.Bar)
@ -213,7 +232,7 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, digest stri
}
}()
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err
}
return digest, nil

View File

@ -690,7 +690,7 @@ func TestCreateHandler(t *testing.T) {
return
}
if req.Name != "test-model" {
if req.Model != "test-model" {
t.Errorf("expected model name 'test-model', got %s", req.Name)
}