This commit is contained in:
Josh Yan 2024-07-12 11:52:10 -07:00
parent 657a1102fc
commit 4c9a160a08

View File

@ -17,8 +17,8 @@ package llm
import "C" import "C"
import ( import (
"fmt" "fmt"
"unsafe"
"time" "time"
"unsafe"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@ -28,19 +28,21 @@ func SystemInfo() string {
return C.GoString(C.llama_print_system_info()) return C.GoString(C.llama_print_system_info())
} }
func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressResponse), tensorCount int) error { func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressResponse), tensorCount int) error {
cinfile := C.CString(infile) cinfile := C.CString(infile)
defer C.free(unsafe.Pointer(cinfile)) defer C.free(unsafe.Pointer(cinfile))
coutfile := C.CString(outfile) coutfile := C.CString(outfile)
defer C.free(unsafe.Pointer(coutfile)) defer C.free(unsafe.Pointer(coutfile))
params := C.llama_model_quantize_default_params() params := C.llama_model_quantize_default_params()
params.nthread = -1 params.nthread = -1
params.ftype = ftype.Value() params.ftype = ftype.Value()
// TODO: there is a race condition on `store`.
// Possible fix: use an atomic int/float (not decided yet),
// or set the value on the C side instead.
// Initialize "global" to store progress // Initialize "global" to store progress
store := C.malloc(C.sizeof_float) store := C.malloc(C.sizeof_float)
defer C.free(store) defer C.free(store)
@ -63,7 +65,7 @@ func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressR
fn(api.ProgressResponse{ fn(api.ProgressResponse{
Status: fmt.Sprintf("quantizing model tensors %d/%d", int(*((*C.float)(store))), tensorCount), Status: fmt.Sprintf("quantizing model tensors %d/%d", int(*((*C.float)(store))), tensorCount),
Quantize: "quant", Quantize: "quant",
}) })
fmt.Println("Progress: ", *((*C.float)(store))) fmt.Println("Progress: ", *((*C.float)(store)))
case <-done: case <-done:
fn(api.ProgressResponse{ fn(api.ProgressResponse{