quantize
This commit is contained in:
parent
ee2b9b076c
commit
c63b4ecbf7
@ -1 +1 @@
|
|||||||
Subproject commit a8db2a9ce64cd4417f6a312ab61858f17f0f8584
|
Subproject commit 7c26775adb579e92b59c82e8084c07a1d0f75e9c
|
23
llm/llm.go
23
llm/llm.go
@ -10,6 +10,10 @@ package llm
|
|||||||
// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
|
// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
|
||||||
// #include <stdlib.h>
|
// #include <stdlib.h>
|
||||||
// #include "llama.h"
|
// #include "llama.h"
|
||||||
|
// bool update_quantize_progress(int progress, void* data) {
|
||||||
|
// *((int*)data) = progress;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -21,7 +25,7 @@ func SystemInfo() string {
|
|||||||
return C.GoString(C.llama_print_system_info())
|
return C.GoString(C.llama_print_system_info())
|
||||||
}
|
}
|
||||||
|
|
||||||
func Quantize(infile, outfile string, ftype fileType) error {
|
func Quantize(infile, outfile string, ftype fileType, count *int) error {
|
||||||
cinfile := C.CString(infile)
|
cinfile := C.CString(infile)
|
||||||
defer C.free(unsafe.Pointer(cinfile))
|
defer C.free(unsafe.Pointer(cinfile))
|
||||||
|
|
||||||
@ -32,6 +36,23 @@ func Quantize(infile, outfile string, ftype fileType) error {
|
|||||||
params.nthread = -1
|
params.nthread = -1
|
||||||
params.ftype = ftype.Value()
|
params.ftype = ftype.Value()
|
||||||
|
|
||||||
|
// Initialize "global" to store progress
|
||||||
|
store := C.malloc(C.sizeof(int))
|
||||||
|
|
||||||
|
params.quantize_callback_data = store
|
||||||
|
params.quantize_callback = C.update_quantize_progress
|
||||||
|
|
||||||
|
go func () {
|
||||||
|
for {
|
||||||
|
time.Sleep(60 * time.Millisecond)
|
||||||
|
if params.quantize_callback_data == nil {
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
*count = int(*(*C.int)(store))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
|
if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
|
||||||
return fmt.Errorf("llama_model_quantize: %d", rc)
|
return fmt.Errorf("llama_model_quantize: %d", rc)
|
||||||
}
|
}
|
||||||
|
@ -21,6 +21,7 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/auth"
|
"github.com/ollama/ollama/auth"
|
||||||
@ -413,6 +414,8 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
|||||||
return fmt.Errorf("invalid model reference: %s", c.Args)
|
return fmt.Errorf("invalid model reference: %s", c.Args)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var quantized int
|
||||||
|
tensorCount := 0
|
||||||
for _, baseLayer := range baseLayers {
|
for _, baseLayer := range baseLayers {
|
||||||
if quantization != "" &&
|
if quantization != "" &&
|
||||||
baseLayer.MediaType == "application/vnd.ollama.image.model" &&
|
baseLayer.MediaType == "application/vnd.ollama.image.model" &&
|
||||||
@ -423,11 +426,27 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorCount := len(baseLayer.GGML.Tensors())
|
tensorCount = len(baseLayer.GGML.Tensors())
|
||||||
fn(api.ProgressResponse{
|
ticker := time.NewTicker(60 * time.Millisecond)
|
||||||
Status: fmt.Sprintf("quantizing model %d tensors", tensorCount),
|
done := make(chan struct{})
|
||||||
Quantize: quantization,
|
defer close(done)
|
||||||
})
|
|
||||||
|
go func() {
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
fn(api.ProgressResponse{
|
||||||
|
Status: fmt.Sprintf("quantizing model %d%%", quantized*100/tensorCount),
|
||||||
|
Quantize: quantization})
|
||||||
|
case <-done:
|
||||||
|
fn(api.ProgressResponse{
|
||||||
|
Status: "quantizing model",
|
||||||
|
Quantize: quantization})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
ft := baseLayer.GGML.KV().FileType()
|
ft := baseLayer.GGML.KV().FileType()
|
||||||
if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
|
if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
|
||||||
@ -447,7 +466,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
|||||||
|
|
||||||
// Quantizes per layer
|
// Quantizes per layer
|
||||||
// Save total quantized tensors
|
// Save total quantized tensors
|
||||||
if err := llm.Quantize(blob, temp.Name(), want); err != nil {
|
if err := llm.Quantize(blob, temp.Name(), want, &quantized); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user