This commit is contained in:
Josh Yan 2024-07-09 15:35:44 -07:00
parent ee2b9b076c
commit c63b4ecbf7
3 changed files with 48 additions and 8 deletions

@@ -1 +1 @@
Subproject commit a8db2a9ce64cd4417f6a312ab61858f17f0f8584
Subproject commit 7c26775adb579e92b59c82e8084c07a1d0f75e9c

View File

@@ -10,6 +10,10 @@ package llm
// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
// #include <stdlib.h>
// #include "llama.h"
// bool update_quantize_progress(int progress, void* data) {
// *((int*)data) = progress;
// return true;
// }
import "C"
import (
"fmt"
@@ -21,7 +25,7 @@ func SystemInfo() string {
return C.GoString(C.llama_print_system_info())
}
func Quantize(infile, outfile string, ftype fileType) error {
func Quantize(infile, outfile string, ftype fileType, count *int) error {
cinfile := C.CString(infile)
defer C.free(unsafe.Pointer(cinfile))
@@ -32,6 +36,23 @@ func Quantize(infile, outfile string, ftype fileType) error {
params.nthread = -1
params.ftype = ftype.Value()
// Initialize "global" to store progress
store := C.malloc(C.sizeof(int))
params.quantize_callback_data = store
params.quantize_callback = C.update_quantize_progress
go func () {
for {
time.Sleep(60 * time.Millisecond)
if params.quantize_callback_data == nil {
return
} else {
*count = int(*(*C.int)(store))
}
}
}()
if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
return fmt.Errorf("llama_model_quantize: %d", rc)
}

View File

@@ -21,6 +21,7 @@ import (
"slices"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
@@ -413,6 +414,8 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
return fmt.Errorf("invalid model reference: %s", c.Args)
}
var quantized int
tensorCount := 0
for _, baseLayer := range baseLayers {
if quantization != "" &&
baseLayer.MediaType == "application/vnd.ollama.image.model" &&
@@ -423,11 +426,27 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
return err
}
tensorCount := len(baseLayer.GGML.Tensors())
fn(api.ProgressResponse{
Status: fmt.Sprintf("quantizing model %d tensors", tensorCount),
Quantize: quantization,
})
tensorCount = len(baseLayer.GGML.Tensors())
ticker := time.NewTicker(60 * time.Millisecond)
done := make(chan struct{})
defer close(done)
go func() {
defer ticker.Stop()
for {
select {
case <-ticker.C:
fn(api.ProgressResponse{
Status: fmt.Sprintf("quantizing model %d%%", quantized*100/tensorCount),
Quantize: quantization})
case <-done:
fn(api.ProgressResponse{
Status: "quantizing model",
Quantize: quantization})
}
return
}
}()
ft := baseLayer.GGML.KV().FileType()
if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
@@ -447,7 +466,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
// Quantizes per layer
// Save total quantized tensors
if err := llm.Quantize(blob, temp.Name(), want); err != nil {
if err := llm.Quantize(blob, temp.Name(), want, &quantized); err != nil {
return err
}