diff --git a/llm/gguf.go b/llm/gguf.go index cccfc3686..d9a949f94 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -911,7 +911,6 @@ func ggufWriteKV(ws io.Writer, k string, v any) error { } default: - fmt.Println("type is", v) return fmt.Errorf("improper type for '%s'", k) } @@ -919,5 +918,6 @@ func ggufWriteKV(ws io.Writer, k string, v any) error { } func ggufPadding(offset, align int64) int64 { + // we mod twice so that the padding is 0 (not align) in the case offset%align == 0 return (align - offset%align) % align } diff --git a/llm/gguf_test.go b/llm/gguf_test.go index 658276b72..6c0e1685d 100644 --- a/llm/gguf_test.go +++ b/llm/gguf_test.go @@ -12,10 +12,14 @@ import ( "github.com/google/go-cmp/cmp" ) +// TestGGUFRewrite tests the decoding and rewriting of (unsorted) GGUF files. +// To run, add GGUF files to /llm/testdata and add the name of the file to the tests slice. +// sort.Sort(tensors) in gguf.go should be commented out before running. +// This creates a temporary file in /llm/testdata that will be deleted only if the test passes. func TestGGUFRewrite(t *testing.T) { - // to test this GGUF Rewrite, add gguf files to /llm/testdata - // add the name of the file to the tests slice - tests := []string{} + tests := []string{ + "nutiny.gguf", + } for i := range tests { tt := tests[i] @@ -24,169 +28,152 @@ func TestGGUFRewrite(t *testing.T) { p := filepath.Join("testdata", tt) if _, err := os.Stat(p); err != nil { - t.Fatalf("%s not found", p) + t.Skip("file not found", p) } - f, err := os.Open(p) + wantFile, err := os.Open(p) if err != nil { t.Fatal(err) } - defer f.Close() + defer wantFile.Close() // decode original gguf - ggml, _, err := decodeGGML(t, f) + _, wantGGML, err := decodeGGML(t, wantFile) if err != nil { t.Fatal(err) } - temp, err := os.CreateTemp("testdata", "2"+tt) + gotFile, err := os.CreateTemp("testdata", tt) if err != nil { t.Fatal(err) } - defer temp.Close() - - _, ggml2, err := rewriteGGML(t, ggml, temp, f) - - if err != nil { - t.Fatal(err) - } - - if diff, diff2, ok := compareGGML(ggml2, ggml, temp, f); 
!ok { - if cmp.Diff(diff, diff2) != "" { - t.Fatalf("diff: \n%s", cmp.Diff(diff, diff2)) + defer func() { + gotFile.Close() + if !t.Failed() { + os.Remove(gotFile.Name()) } + }() + + _, gotGGML, err := rewriteGGML(t, wantGGML, gotFile, wantFile) + + if err != nil { + t.Fatal(err) + } + + diff, diff2 := compareGGML(t, gotGGML, wantGGML, gotFile, wantFile) + if cmp.Diff(diff, diff2) != "" { + t.Fatalf("diff: \n%s", cmp.Diff(diff, diff2)) } }) } } -func compareGGML(ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string, bool) { - diff := make(map[string]string) - diff2 := make(map[string]string) +func compareGGML(t *testing.T, gotGGML, wantGGML *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string) { + got := make(map[string]string) + want := make(map[string]string) - kv1 := ggml1.KV() - kv2 := ggml2.KV() + gotKV := gotGGML.KV() + wantKV := wantGGML.KV() - if len(kv1) != len(kv2) { - diff["lenKV"] = fmt.Sprintf("kv1: %d, kv2: %d", len(kv1), len(kv2)) - fmt.Println("lenKV", diff["lenKV"]) + if len(gotKV) != len(wantKV) { + t.Fatalf("got length: %d != want length: %d", len(gotKV), len(wantKV)) } - for k, v := range kv1 { + for k, v := range gotKV { switch t := v.(type) { case *array: - if diffy := cmp.Diff(t.values, kv2[k].(*array).values); diffy != "" { - diff[k] = diffy + if diffy := cmp.Diff(t.values, wantKV[k].(*array).values); diffy != "" { + got[k] = diffy } default: - if v != kv2[k] { - diff[k] = fmt.Sprintf("kv1: %v, kv2: %v", v, kv2[k]) + if v != wantKV[k] { + got[k] = fmt.Sprintf("kv1: %v, kv2: %v", v, want[k]) } } } - t1 := ggml1.Tensors() - t2 := ggml2.Tensors() + gotTensors := gotGGML.Tensors().Items + gotOffset := gotGGML.Tensors().Offset + wantTensors := wantGGML.Tensors().Items + wantOffset := wantGGML.Tensors().Offset - if len(t1.Items) != len(t2.Items) { - diff["lenTensors"] = fmt.Sprintf("t1: %d, t2: %d", len(t1.Items), len(t2.Items)) + if len(gotTensors) != len(wantTensors) { + got["lenTensors"] = 
fmt.Sprintf("t1: %d, t2: %d", len(gotTensors), len(wantTensors)) } - for _, tensor := range t1.Items { + for _, tensor := range gotTensors { sha256sum := sha256.New() - sr := io.NewSectionReader(f, t1.Offset+int64(tensor.Offset), int64(tensor.Size())) + sr := io.NewSectionReader(f, gotOffset+int64(tensor.Offset), int64(tensor.Size())) var s int64 s, err := io.Copy(sha256sum, sr) if err != nil { - fmt.Println(err) + t.Fatalf("error: %v", err) } - diff[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil)) - diff[tensor.Name+" size"] = fmt.Sprintf("%d", s) - diff[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset) + got[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil)) + got[tensor.Name+" size"] = fmt.Sprintf("%d", s) + got[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset) } - /* sha256Sum2 := sha256.New() - sr1 := io.NewSectionReader(f2, 0, n) - s1, err := io.Copy(sha256Sum2, sr1) - if err != nil { - return nil, nil, true - } - - sha256Sum3 := sha256.New() - sr2 := io.NewSectionReader(f, 0, n) - s2, err := io.Copy(sha256Sum3, sr2) - if err != nil { - return nil, nil, true - } - - diff["sha"] = fmt.Sprintf("%d", s1) - diff2["sha"] = fmt.Sprintf("%d", s2) */ - - for _, tensor := range t2.Items { + for _, tensor := range wantTensors { sha256sum := sha256.New() var s int64 - sr := io.NewSectionReader(f2, t1.Offset+int64(tensor.Offset), int64(tensor.Size())) + sr := io.NewSectionReader(f2, wantOffset +int64(tensor.Offset), int64(tensor.Size())) s, err := io.Copy(sha256sum, sr) if err != nil { - fmt.Println(err) + t.Fatalf("error: %v", err) } - diff2[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil)) - diff2[tensor.Name+" size"] = fmt.Sprintf("%d", s) - diff2[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset) + want[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil)) + want[tensor.Name+" size"] = fmt.Sprintf("%d", s) + want[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset) } - return diff, diff2, len(diff) == 0 - + return got, want } -func 
decodeGGML(t *testing.T, f *os.File) (*GGML, int64, error) { + +func decodeGGML(t *testing.T, f *os.File) (int64, *GGML, error) { ggml, n, err := DecodeGGML(f, math.MaxInt) if err != nil { t.Fatal(err) } - return ggml, n, nil + return n, ggml, nil } -func rewriteGGML(t *testing.T, ggml *GGML, temp *os.File, f *os.File) (int64, *GGML, error) { +func rewriteGGML(t *testing.T, ggml *GGML, gotFile *os.File, wantFile *os.File) (int64, *GGML, error) { var tensors []*Tensor - fmt.Println("11111111111111111111111111111111111111111") for _, tensor := range ggml.Tensors().Items { shape := make([]uint64, len(tensor.Shape)) for i := range len(tensor.Shape) { shape[i] = tensor.Shape[len(tensor.Shape)-i-1] } - fmt.Println("tensors", tensor.Name, shape, tensor.Kind, tensor.Offset) - fmt.Println(ggml.Tensors().Offset) tensors = append(tensors, &Tensor{ Name: tensor.Name, Kind: tensor.Kind, Shape: shape, WriterTo: TensorWriter{ - Reader: io.NewSectionReader(f, ggml.Tensors().Offset+int64(tensor.Offset), int64(tensor.Size())), + Reader: io.NewSectionReader(wantFile, ggml.Tensors().Offset+int64(tensor.Offset), int64(tensor.Size())), }, }) } reader := &GGUFWriter{ KV: ggml.KV(), - // Update .Tensors Tensors: Tensors{ Items: tensors, Offset: ggml.Tensors().Offset, }, } - n, err := io.Copy(temp, reader) + n, err := io.Copy(gotFile, reader) if err != nil { t.Fatal(err) } - fmt.Println(n, " is my offset") - file, err := os.Open(temp.Name()) + file, err := os.Open(gotFile.Name()) if err != nil { t.Fatal(err) }