diff --git a/llm/llm.go b/llm/llm.go index d700a2cb6..f9da47586 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -17,8 +17,8 @@ package llm import "C" import ( "fmt" - "unsafe" "time" + "unsafe" "github.com/ollama/ollama/api" ) @@ -28,19 +28,21 @@ func SystemInfo() string { return C.GoString(C.llama_print_system_info()) } -func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressResponse), tensorCount int) error { +func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressResponse), tensorCount int) error { cinfile := C.CString(infile) defer C.free(unsafe.Pointer(cinfile)) coutfile := C.CString(outfile) defer C.free(unsafe.Pointer(coutfile)) - - params := C.llama_model_quantize_default_params() params.nthread = -1 params.ftype = ftype.Value() + // race condition with `store` + // use atomicint/float idk yet + // use set in the C. + // Initialize "global" to store progress store := C.malloc(C.sizeof_float) defer C.free(store) @@ -63,7 +65,7 @@ func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressR fn(api.ProgressResponse{ Status: fmt.Sprintf("quantizing model tensors %d/%d", int(*((*C.float)(store))), tensorCount), Quantize: "quant", - }) + }) fmt.Println("Progress: ", *((*C.float)(store))) case <-done: fn(api.ProgressResponse{