From 4c9a160a0850b1cf57bdccddc6cf2f5d35aeabbc Mon Sep 17 00:00:00 2001
From: Josh Yan
Date: Fri, 12 Jul 2024 11:52:10 -0700
Subject: [PATCH] race

---
 llm/llm.go | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/llm/llm.go b/llm/llm.go
index d700a2cb6..f9da47586 100644
--- a/llm/llm.go
+++ b/llm/llm.go
@@ -17,8 +17,8 @@ package llm
 import "C"
 import (
 	"fmt"
-	"unsafe"
 	"time"
+	"unsafe"
 
 	"github.com/ollama/ollama/api"
 )
@@ -28,19 +28,21 @@ func SystemInfo() string {
 	return C.GoString(C.llama_print_system_info())
 }
 
-func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressResponse), tensorCount int) error {
+func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressResponse), tensorCount int) error {
 	cinfile := C.CString(infile)
 	defer C.free(unsafe.Pointer(cinfile))
 
 	coutfile := C.CString(outfile)
 	defer C.free(unsafe.Pointer(coutfile))
-
-
 	params := C.llama_model_quantize_default_params()
 	params.nthread = -1
 	params.ftype = ftype.Value()
 
+	// race condition with `store`
+	// use atomicint/float idk yet
+	// use set in the C.
+
 	// Initialize "global" to store progress
 	store := C.malloc(C.sizeof_float)
 	defer C.free(store)
 
@@ -63,7 +65,7 @@ func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressR
 			fn(api.ProgressResponse{
 				Status:   fmt.Sprintf("quantizing model tensors %d/%d", int(*((*C.float)(store))), tensorCount),
 				Quantize: "quant",
-			})
+			})
 			fmt.Println("Progress: ", *((*C.float)(store)))
 		case <-done:
 			fn(api.ProgressResponse{
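
Note on the TODO comments in the patch: the race they describe is between the C-side quantize progress callback writing into the malloc'd `store` float and the Go ticker goroutine dereferencing that pointer without synchronization. Below is a minimal Go-only sketch of the "use atomic int/float" option, assuming the progress value can be surfaced to Go through a small setter. The names here (progress, setProgress, getProgress) are illustrative and not part of the patch, and wiring the actual C callback to the setter is out of scope.

// Sketch only: the progress value's float32 bits are packed into an
// atomic.Uint32 so the writer and the ticker-based reader never touch an
// unsynchronized raw pointer.
package main

import (
	"fmt"
	"math"
	"sync/atomic"
	"time"
)

var progress atomic.Uint32 // float32 progress, stored as its bit pattern

// setProgress is what a progress callback (the writer side) would call.
func setProgress(p float32) {
	progress.Store(math.Float32bits(p))
}

// getProgress is what the reporting goroutine (the reader side) would call.
func getProgress() float32 {
	return math.Float32frombits(progress.Load())
}

func main() {
	const tensorCount = 100
	done := make(chan struct{})

	// Writer: stands in for the quantize callback updating progress.
	go func() {
		for i := 0; i <= tensorCount; i++ {
			setProgress(float32(i))
			time.Sleep(5 * time.Millisecond)
		}
		close(done)
	}()

	// Reader: mirrors the ticker/select loop in Quantize.
	ticker := time.NewTicker(60 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			fmt.Printf("quantizing model tensors %d/%d\n", int(getProgress()), tensorCount)
		case <-done:
			fmt.Println("quantize complete")
			return
		}
	}
}

The other option floated in the TODO ("use set in the C.") would instead keep the value on the C side behind a thread-safe setter/getter, so the Go code never dereferences the shared pointer directly; either way the point is that only synchronized accessors touch the shared progress value.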