atomic for race

This commit is contained in:
Josh Yan 2024-07-15 10:44:35 -07:00
parent 4c9a160a08
commit 8476ef2bd8

View File

@ -9,14 +9,18 @@ package llm
// #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src // #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src // #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
// #include <stdlib.h> // #include <stdlib.h>
// #include <stdatomic.h>
// #include "llama.h" // #include "llama.h"
// bool update_quantize_progress(float progress, void* data) { // bool update_quantize_progress(float progress, void* data) {
// *((float*)data) = progress; // atomic_int* atomicData = (atomic_int*)data;
// return true; // int intProgress = *((int*)&progress);
// atomic_store(atomicData, intProgress);
// return true;
// } // }
import "C" import "C"
import ( import (
"fmt" "fmt"
"sync/atomic"
"time" "time"
"unsafe" "unsafe"
@ -39,21 +43,17 @@ func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressR
params.nthread = -1 params.nthread = -1
params.ftype = ftype.Value() params.ftype = ftype.Value()
// race condition with `store`
// use atomicint/float idk yet
// use set in the C.
// Initialize "global" to store progress // Initialize "global" to store progress
store := C.malloc(C.sizeof_float) store := (*int32)(C.malloc(C.sizeof_int))
defer C.free(store) defer C.free(unsafe.Pointer(store))
// Initialize store value, e.g., setting initial progress to 0 // Initialize store value, e.g., setting initial progress to 0
*(*C.float)(store) = 0.0 atomic.StoreInt32(store, 0)
params.quantize_callback_data = store params.quantize_callback_data = unsafe.Pointer(store)
params.quantize_callback = (C.llama_progress_callback)(C.update_quantize_progress) params.quantize_callback = (C.llama_progress_callback)(C.update_quantize_progress)
ticker := time.NewTicker(60 * time.Millisecond) ticker := time.NewTicker(30 * time.Millisecond)
done := make(chan struct{}) done := make(chan struct{})
defer close(done) defer close(done)
@ -62,11 +62,13 @@ func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressR
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
progressInt := atomic.LoadInt32(store)
progress := *(*float32)(unsafe.Pointer(&progressInt))
fn(api.ProgressResponse{ fn(api.ProgressResponse{
Status: fmt.Sprintf("quantizing model tensors %d/%d", int(*((*C.float)(store))), tensorCount), Status: fmt.Sprintf("quantizing model tensors %d/%d", int(progress), tensorCount),
Quantize: "quant", Quantize: "quant",
}) })
fmt.Println("Progress: ", *((*C.float)(store))) fmt.Println("Progress: ", progress)
case <-done: case <-done:
fn(api.ProgressResponse{ fn(api.ProgressResponse{
Status: fmt.Sprintf("quantizing model tensors %d/%d", tensorCount, tensorCount), Status: fmt.Sprintf("quantizing model tensors %d/%d", tensorCount, tensorCount),