quantize
This commit is contained in:
parent
ee2b9b076c
commit
c63b4ecbf7
@ -1 +1 @@
|
||||
Subproject commit a8db2a9ce64cd4417f6a312ab61858f17f0f8584
|
||||
Subproject commit 7c26775adb579e92b59c82e8084c07a1d0f75e9c
|
23
llm/llm.go
23
llm/llm.go
@ -10,6 +10,10 @@ package llm
|
||||
// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
|
||||
// #include <stdlib.h>
|
||||
// #include "llama.h"
|
||||
// bool update_quantize_progress(int progress, void* data) {
|
||||
// *((int*)data) = progress;
|
||||
// return true;
|
||||
// }
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
@ -21,7 +25,7 @@ func SystemInfo() string {
|
||||
return C.GoString(C.llama_print_system_info())
|
||||
}
|
||||
|
||||
func Quantize(infile, outfile string, ftype fileType) error {
|
||||
func Quantize(infile, outfile string, ftype fileType, count *int) error {
|
||||
cinfile := C.CString(infile)
|
||||
defer C.free(unsafe.Pointer(cinfile))
|
||||
|
||||
@ -32,6 +36,23 @@ func Quantize(infile, outfile string, ftype fileType) error {
|
||||
params.nthread = -1
|
||||
params.ftype = ftype.Value()
|
||||
|
||||
// Initialize "global" to store progress
|
||||
store := C.malloc(C.sizeof(int))
|
||||
|
||||
params.quantize_callback_data = store
|
||||
params.quantize_callback = C.update_quantize_progress
|
||||
|
||||
go func () {
|
||||
for {
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
if params.quantize_callback_data == nil {
|
||||
return
|
||||
} else {
|
||||
*count = int(*(*C.int)(store))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if rc := C.llama_model_quantize(cinfile, coutfile, ¶ms); rc != 0 {
|
||||
return fmt.Errorf("llama_model_quantize: %d", rc)
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
@ -413,6 +414,8 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||
return fmt.Errorf("invalid model reference: %s", c.Args)
|
||||
}
|
||||
|
||||
var quantized int
|
||||
tensorCount := 0
|
||||
for _, baseLayer := range baseLayers {
|
||||
if quantization != "" &&
|
||||
baseLayer.MediaType == "application/vnd.ollama.image.model" &&
|
||||
@ -423,11 +426,27 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||
return err
|
||||
}
|
||||
|
||||
tensorCount := len(baseLayer.GGML.Tensors())
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("quantizing model %d tensors", tensorCount),
|
||||
Quantize: quantization,
|
||||
})
|
||||
tensorCount = len(baseLayer.GGML.Tensors())
|
||||
ticker := time.NewTicker(60 * time.Millisecond)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("quantizing model %d%%", quantized*100/tensorCount),
|
||||
Quantize: quantization})
|
||||
case <-done:
|
||||
fn(api.ProgressResponse{
|
||||
Status: "quantizing model",
|
||||
Quantize: quantization})
|
||||
}
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
ft := baseLayer.GGML.KV().FileType()
|
||||
if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
|
||||
@ -447,7 +466,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||
|
||||
// Quantizes per layer
|
||||
// Save total quantized tensors
|
||||
if err := llm.Quantize(blob, temp.Name(), want); err != nil {
|
||||
if err := llm.Quantize(blob, temp.Name(), want, &quantized); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user