Report quantization progress as a percentage

This commit is contained in:
Josh Yan 2024-07-08 14:51:58 -07:00
parent 6bab0e2368
commit e87eafe5cd
3 changed files with 21 additions and 3 deletions

View File

@ -267,6 +267,7 @@ type PullRequest struct {
// ProgressResponse is a single progress update, serialized as JSON.
// NOTE(review): this text is a side-by-side diff rendering, so each unchanged
// line shows the old and the new column; Quantize is the one added field.
type ProgressResponse struct { type ProgressResponse struct {
// Status is the human-readable status message. It is the only field without
// omitempty, so it is always present in the encoded JSON.
Status string `json:"status"` Status string `json:"status"`
// Digest, when non-empty, makes the create handler treat the update as
// progress for a per-digest progress bar.
Digest string `json:"digest,omitempty"` Digest string `json:"digest,omitempty"`
// Quantize is set by CreateModel to the requested quantization on its
// "quantizing model N%" updates; when non-empty (and Digest is empty) the
// create handler shows Status on a dedicated quantize spinner.
Quantize string `json:"quantize,omitempty"`
// Total is omitted from the JSON when zero.
// NOTE(review): presumably the overall size the bar counts toward; its use is
// not visible in this view, confirm against the handler.
Total int64 `json:"total,omitempty"` Total int64 `json:"total,omitempty"`
// Completed is the amount done so far; the create handler passes it to
// bar.Set for the bar associated with Digest. Omitted from the JSON when zero.
Completed int64 `json:"completed,omitempty"` Completed int64 `json:"completed,omitempty"`
} }

View File

@ -125,6 +125,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
bars := make(map[string]*progress.Bar) bars := make(map[string]*progress.Bar)
var quantizeSpin *progress.Spinner
fn := func(resp api.ProgressResponse) error { fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" { if resp.Digest != "" {
spinner.Stop() spinner.Stop()
@ -137,6 +138,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
bar.Set(resp.Completed) bar.Set(resp.Completed)
} else if resp.Quantize != "" {
if quantizeSpin != nil {
quantizeSpin.SetMessage(resp.Status)
} else {
quantizeSpin = progress.NewSpinner(resp.Status)
p.Add("quantize", quantizeSpin)
}
} else if status != resp.Status { } else if status != resp.Status {
spinner.Stop() spinner.Stop()

View File

@ -413,7 +413,8 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
return fmt.Errorf("invalid model reference: %s", c.Args) return fmt.Errorf("invalid model reference: %s", c.Args)
} }
for _, baseLayer := range baseLayers { layerCount := len(baseLayers)
for i, baseLayer := range baseLayers {
if quantization != "" && if quantization != "" &&
baseLayer.MediaType == "application/vnd.ollama.image.model" && baseLayer.MediaType == "application/vnd.ollama.image.model" &&
baseLayer.GGML != nil && baseLayer.GGML != nil &&
@ -427,8 +428,6 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
if !slices.Contains([]string{"F16", "F32"}, ft.String()) { if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
return errors.New("quantization is only supported for F16 and F32 models") return errors.New("quantization is only supported for F16 and F32 models")
} else if want != ft { } else if want != ft {
fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantization)})
blob, err := GetBlobsPath(baseLayer.Digest) blob, err := GetBlobsPath(baseLayer.Digest)
if err != nil { if err != nil {
return err return err
@ -472,8 +471,18 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture()) config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
} }
fn(api.ProgressResponse{
Status: fmt.Sprintf("quantizing model %d%%", i*100/layerCount),
Quantize: quantization,
})
layers = append(layers, baseLayer.Layer) layers = append(layers, baseLayer.Layer)
} }
fn(api.ProgressResponse{
Status: fmt.Sprintf("quantizing model %d%%", 100),
Quantize: quantization,
})
case "license", "template", "system": case "license", "template", "system":
if c.Name != "license" { if c.Name != "license" {
// replace // replace