tensor count

This commit is contained in:
Josh Yan 2024-07-09 11:02:58 -07:00
parent 1344843515
commit bec9100f32

View File

@ -413,8 +413,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
return fmt.Errorf("invalid model reference: %s", c.Args)
}
layerCount := len(baseLayers)
for i, baseLayer := range baseLayers {
for _, baseLayer := range baseLayers {
if quantization != "" &&
baseLayer.MediaType == "application/vnd.ollama.image.model" &&
baseLayer.GGML != nil &&
@ -424,6 +423,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
return err
}
tensorCount := len(baseLayer.GGML.Tensors())
fn(api.ProgressResponse{
Status: fmt.Sprintf("quantizing model %d tensors", tensorCount),
Quantize: quantization,
})
ft := baseLayer.GGML.KV().FileType()
if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
return errors.New("quantization is only supported for F16 and F32 models")
@ -463,6 +468,11 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
baseLayer.Layer = layer
baseLayer.GGML = ggml
}
fn(api.ProgressResponse{
Status: "quantizing model done",
Quantize: quantization,
})
}
if baseLayer.GGML != nil {
@ -473,18 +483,14 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
}
fn(api.ProgressResponse{
/* fn(api.ProgressResponse{
Status: fmt.Sprintf("quantizing model %d%%", i*100/layerCount),
Quantize: quantization,
})
}) */
layers = append(layers, baseLayer.Layer)
}
fn(api.ProgressResponse{
Status: "quantizing model done",
Quantize: quantization,
})
case "license", "template", "system":
if c.Name != "license" {
// replace