From c63b4ecbf77974648ac58e1c1614a6089c26ca71 Mon Sep 17 00:00:00 2001
From: Josh Yan
Date: Tue, 9 Jul 2024 15:35:44 -0700
Subject: [PATCH] quantize

---
 llm/llama.cpp    |  2 +-
 llm/llm.go       | 34 +++++++++++++++++++++++++++++++++-
 server/images.go | 37 ++++++++++++++++++++++++++++++-----
 3 files changed, 65 insertions(+), 8 deletions(-)

diff --git a/llm/llama.cpp b/llm/llama.cpp
index a8db2a9ce..7c26775ad 160000
--- a/llm/llama.cpp
+++ b/llm/llama.cpp
@@ -1 +1 @@
-Subproject commit a8db2a9ce64cd4417f6a312ab61858f17f0f8584
+Subproject commit 7c26775adb579e92b59c82e8084c07a1d0f75e9c
diff --git a/llm/llm.go b/llm/llm.go
index f2a5e557a..37eee2bc3 100644
--- a/llm/llm.go
+++ b/llm/llm.go
@@ -10,7 +10,12 @@ package llm
 // #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
 // #include <stdlib.h>
 // #include "llama.h"
+// bool update_quantize_progress(int progress, void* data) {
+//   *((int*)data) = progress;
+//   return true;
+// }
 import "C"
 import (
 	"fmt"
+	"time"
 	"unsafe"
@@ -21,7 +26,7 @@ func SystemInfo() string {
 	return C.GoString(C.llama_print_system_info())
 }
 
-func Quantize(infile, outfile string, ftype fileType) error {
+func Quantize(infile, outfile string, ftype fileType, count *int) error {
 	cinfile := C.CString(infile)
 	defer C.free(unsafe.Pointer(cinfile))
 
@@ -32,6 +37,33 @@ func Quantize(infile, outfile string, ftype fileType, count *int) error {
 	params.nthread = -1
 	params.ftype = ftype.Value()
 
+	// C-allocated int that update_quantize_progress writes into; it is
+	// zero-initialized here and freed by the polling goroutine on exit.
+	// NOTE: cgo spells sizeof as C.sizeof_int, not C.sizeof(int).
+	store := C.malloc(C.sizeof_int)
+	*(*C.int)(store) = 0
+
+	params.quantize_callback_data = store
+	params.quantize_callback = C.update_quantize_progress
+
+	// Poll the progress value into *count until llama_model_quantize
+	// returns. Closing done (deferred) stops the goroutine — previously
+	// it spun forever because quantize_callback_data was never nil'd.
+	done := make(chan struct{})
+	defer close(done)
+	go func() {
+		defer C.free(store)
+		ticker := time.NewTicker(60 * time.Millisecond)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ticker.C:
+				*count = int(*(*C.int)(store))
+			case <-done:
+				return
+			}
+		}
+	}()
+
 	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
 		return fmt.Errorf("llama_model_quantize: %d", rc)
 	}
diff --git a/server/images.go b/server/images.go
index 
40d1a8e8b..a4dcfc9da 100644
--- a/server/images.go
+++ b/server/images.go
@@ -21,6 +21,7 @@ import (
 	"slices"
 	"strconv"
 	"strings"
+	"time"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/auth"
@@ -413,6 +414,8 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 		return fmt.Errorf("invalid model reference: %s", c.Args)
 	}
 
+	var quantized int
+	tensorCount := 0
 	for _, baseLayer := range baseLayers {
 		if quantization != "" &&
 			baseLayer.MediaType == "application/vnd.ollama.image.model" &&
@@ -423,11 +426,33 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 				return err
 			}
 
-			tensorCount := len(baseLayer.GGML.Tensors())
-			fn(api.ProgressResponse{
-				Status: fmt.Sprintf("quantizing model %d tensors", tensorCount),
-				Quantize: quantization,
-			})
+			tensorCount = len(baseLayer.GGML.Tensors())
+			done := make(chan struct{})
+			defer close(done)
+
+			// Report percentage progress while llm.Quantize runs. The
+			// return lives inside the done case so the ticker keeps
+			// firing until done is closed (a trailing return after the
+			// select would exit after the first event).
+			go func() {
+				ticker := time.NewTicker(60 * time.Millisecond)
+				defer ticker.Stop()
+				for {
+					select {
+					case <-ticker.C:
+						if tensorCount > 0 { // guard divide-by-zero
+							fn(api.ProgressResponse{
+								Status:   fmt.Sprintf("quantizing model %d%%", quantized*100/tensorCount),
+								Quantize: quantization,
+							})
+						}
+					case <-done:
+						fn(api.ProgressResponse{
+							Status:   "quantizing model",
+							Quantize: quantization,
+						})
+						return
+					}
+				}
+			}()
 
 			ft := baseLayer.GGML.KV().FileType()
 			if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
@@ -447,7 +472,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 			// Quantizes per layer
 			// Save total quantized tensors
-			if err := llm.Quantize(blob, temp.Name(), want); err != nil {
+			if err := llm.Quantize(blob, temp.Name(), want, &quantized); err != nil {
 				return err
 			}