diff --git a/api/types.go b/api/types.go index 87844c67c..419c1d17f 100644 --- a/api/types.go +++ b/api/types.go @@ -267,6 +267,7 @@ type PullRequest struct { type ProgressResponse struct { Status string `json:"status"` Digest string `json:"digest,omitempty"` + Quantize string `json:"quantize,omitempty"` Total int64 `json:"total,omitempty"` Completed int64 `json:"completed,omitempty"` } diff --git a/cmd/cmd.go b/cmd/cmd.go index 4ac689681..1de40bd78 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -125,6 +125,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } bars := make(map[string]*progress.Bar) + var quantizeSpin *progress.Spinner fn := func(resp api.ProgressResponse) error { if resp.Digest != "" { spinner.Stop() @@ -137,6 +138,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } bar.Set(resp.Completed) + } else if resp.Quantize != "" { + if quantizeSpin != nil { + quantizeSpin.SetMessage(resp.Status) + } else { + quantizeSpin = progress.NewSpinner(resp.Status) + p.Add("quantize", quantizeSpin) + } } else if status != resp.Status { spinner.Stop() diff --git a/server/images.go b/server/images.go index 791a81a15..13a3192ae 100644 --- a/server/images.go +++ b/server/images.go @@ -413,7 +413,8 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio return fmt.Errorf("invalid model reference: %s", c.Args) } - for _, baseLayer := range baseLayers { + layerCount := len(baseLayers) + for i, baseLayer := range baseLayers { if quantization != "" && baseLayer.MediaType == "application/vnd.ollama.image.model" && baseLayer.GGML != nil && @@ -427,8 +428,6 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio if !slices.Contains([]string{"F16", "F32"}, ft.String()) { return errors.New("quantization is only supported for F16 and F32 models") } else if want != ft { - fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantization)}) - blob, err := GetBlobsPath(baseLayer.Digest) if err != nil { return err @@ -472,8 +471,18 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture()) } + fn(api.ProgressResponse{ + Status: fmt.Sprintf("quantizing model %d%%", i*100/layerCount), + Quantize: quantization, + }) + layers = append(layers, baseLayer.Layer) } + + fn(api.ProgressResponse{ + Status: fmt.Sprintf("quantizing model %d%%", 100), + Quantize: quantization, + }) case "license", "template", "system": if c.Name != "license" { // replace