From 8476ef2bd87a1585e0ba33b33fa6a63a4dace7f0 Mon Sep 17 00:00:00 2001
From: Josh Yan
Date: Mon, 15 Jul 2024 10:44:35 -0700
Subject: [PATCH] atomic for race

The quantize progress callback stores the latest progress value into
memory that the ticker loop in Quantize reads. Make both sides atomic:
the C callback publishes the value with atomic_store and Go reads it
with sync/atomic, replacing the plain float store and plain float read.

The float is carried as its 32-bit pattern in an int-sized slot. C
copies the bits with memcpy before storing them; Go decodes them with
math.Float32frombits after atomic.LoadInt32.
---
Review notes (text in this section is ignored by git am):

- update_quantize_progress: the bits of progress are copied with memcpy
  instead of *((int*)&progress), which reads a float object through an
  int lvalue (strict-aliasing violation).
- Quantize: the loaded bits are decoded with math.Float32frombits
  instead of an unsafe.Pointer cast.
- Quantize: the per-tick fmt.Println("Progress: ", ...) debug print is
  removed rather than carried over.
- The copy under review had lost its line breaks, the header names
  after #include, and the author address. Line breaks, indentation and
  the <stdlib.h>/<stdatomic.h> names are reconstructed; the address is
  left out.
- Hunk offsets and the diffstat are recomputed. The commit id on the
  first line and the blob ids on the index line still belong to the
  pre-review revision.
- NOTE(review): the slot is C.sizeof_int bytes but Go accesses it as
  int32; this presumes a 32-bit C int -- confirm for every target.

 llm/llm.go | 30 +++++++++++++++++-------------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/llm/llm.go b/llm/llm.go
index f9da47586..b9ae53063 100644
--- a/llm/llm.go
+++ b/llm/llm.go
@@ -9,14 +9,22 @@ package llm
 // #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
 // #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
 // #include <stdlib.h>
+// #include <stdatomic.h>
+// #include <string.h>
 // #include "llama.h"
 // bool update_quantize_progress(float progress, void* data) {
-//     *((float*)data) = progress;
-//     return true;
+//     // Copy the bit pattern with memcpy: reading progress through an
+//     // int* would break C's strict-aliasing rule.
+//     int bits;
+//     memcpy(&bits, &progress, sizeof(bits));
+//     atomic_store((atomic_int*)data, bits);
+//     return true;
 // }
 import "C"
 import (
 	"fmt"
+	"math"
+	"sync/atomic"
 	"time"
 	"unsafe"
 
@@ -39,21 +47,17 @@ func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressR
 	params.nthread = -1
 	params.ftype = ftype.Value()
 
-	// race condition with `store`
-	// use atomicint/float idk yet
-	// use set in the C.
-	// Initialize "global" to store progress
-	store := C.malloc(C.sizeof_float)
-	defer C.free(store)
+	store := (*int32)(C.malloc(C.sizeof_int))
+	defer C.free(unsafe.Pointer(store))
 
 	// Initialize store value, e.g., setting initial progress to 0
-	*(*C.float)(store) = 0.0
+	atomic.StoreInt32(store, 0)
 
-	params.quantize_callback_data = store
+	params.quantize_callback_data = unsafe.Pointer(store)
 	params.quantize_callback = (C.llama_progress_callback)(C.update_quantize_progress)
 
-	ticker := time.NewTicker(60 * time.Millisecond)
+	ticker := time.NewTicker(30 * time.Millisecond)
 
 	done := make(chan struct{})
 	defer close(done)
 
@@ -62,11 +66,11 @@ func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressR
 		for {
 			select {
 			case <-ticker.C:
+				progress := math.Float32frombits(uint32(atomic.LoadInt32(store)))
 				fn(api.ProgressResponse{
-					Status:   fmt.Sprintf("quantizing model tensors %d/%d", int(*((*C.float)(store))), tensorCount),
+					Status:   fmt.Sprintf("quantizing model tensors %d/%d", int(progress), tensorCount),
 					Quantize: "quant",
 				})
-				fmt.Println("Progress: ", *((*C.float)(store)))
 			case <-done:
 				fn(api.ProgressResponse{
 					Status: fmt.Sprintf("quantizing model tensors %d/%d", tensorCount, tensorCount),