Josh Yan 2024-07-10 09:58:30 -07:00
parent e59453982d
commit b0e4e8d76c
2 changed files with 4 additions and 3 deletions

View File

@@ -28,7 +28,7 @@ func SystemInfo() string {
 return C.GoString(C.llama_print_system_info())
 }
 
-func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressResponse)) error {
+func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressResponse), tensorCount int) error {
 cinfile := C.CString(infile)
 defer C.free(unsafe.Pointer(cinfile))
@@ -59,7 +59,7 @@ func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressR
 select {
 case <-ticker.C:
 fn(api.ProgressResponse{
-Status: fmt.Sprintf("quantizing model %d%%", int(*((*C.float)(store))*100)),
+Status: fmt.Sprintf("quantizing model %d/%d", int(*((*C.float)(store))), tensorCount),
 Quantize: "quant",
 })
 fmt.Println("Progress: ", *((*C.float)(store)))

View File

@@ -422,6 +422,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 if err != nil {
 return err
 }
+tensorCount := len(baseLayer.GGML.Tensors())
 
 ft := baseLayer.GGML.KV().FileType()
 if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
@@ -441,7 +441,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 // Quantizes per layer
 // Save total quantized tensors
-if err := llm.Quantize(blob, temp.Name(), want, fn); err != nil {
+if err := llm.Quantize(blob, temp.Name(), want, fn, tensorCount); err != nil {
 return err
 }
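Taken together, the two files thread the tensor count from the model's GGML metadata into the quantize progress callback. The following caller-side sketch shows that flow under stated assumptions: the quantize function, the file names, and the tensor list are hypothetical stand-ins for llm.Quantize and len(baseLayer.GGML.Tensors()).

```go
package main

import "fmt"

// ProgressResponse is a simplified stand-in for api.ProgressResponse.
type ProgressResponse struct {
	Status   string
	Quantize string
}

// quantize is a hypothetical stand-in for llm.Quantize(blob, outfile, want, fn, tensorCount);
// it simply walks the tensors and reports each one through the callback.
func quantize(infile, outfile string, tensorCount int, fn func(ProgressResponse)) error {
	for i := 1; i <= tensorCount; i++ {
		fn(ProgressResponse{
			Status:   fmt.Sprintf("quantizing model %d/%d", i, tensorCount),
			Quantize: "quant",
		})
	}
	return nil
}

func main() {
	// Hypothetical tensor names; in the diff the count comes from len(baseLayer.GGML.Tensors()).
	tensors := []string{"blk.0.attn_q.weight", "blk.0.attn_k.weight", "output.weight"}
	if err := quantize("model-f16.gguf", "model-q4_0.gguf", len(tensors), func(r ProgressResponse) {
		fmt.Println(r.Status)
	}); err != nil {
		fmt.Println("quantize failed:", err)
	}
}
```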