diff --git a/llm/gguf_test.go b/llm/gguf_test.go index 244549d48..658276b72 100644 --- a/llm/gguf_test.go +++ b/llm/gguf_test.go @@ -7,17 +7,15 @@ import ( "math" "os" "path/filepath" - "strings" "testing" "github.com/google/go-cmp/cmp" ) func TestGGUFRewrite(t *testing.T) { - tests := []string{ - "phi3.gguf", - "nutiny.gguf", - } + // to test this GGUF Rewrite, add gguf files to /llm/testdata + // add the name of the file to the tests slice + tests := []string{} for i := range tests { tt := tests[i] @@ -35,6 +33,7 @@ func TestGGUFRewrite(t *testing.T) { } defer f.Close() + // decode original gguf ggml, _, err := decodeGGML(t, f) if err != nil { t.Fatal(err) @@ -46,35 +45,22 @@ func TestGGUFRewrite(t *testing.T) { } defer temp.Close() - n, ggml2, err := rewriteGGML(t, ggml, temp, f) - - /* if n != m { - t.Fatalf("n: %d, m: %d", n, m) - } */ + _, ggml2, err := rewriteGGML(t, ggml, temp, f) if err != nil { t.Fatal(err) } - //t.Fatal("FULL SIZE JFAKFJJEFJAJFLAEJJAFAJKLFJ", n) - if diff, diff2, ok := compareGGML(n, ggml2, ggml, temp, f); !ok { + if diff, diff2, ok := compareGGML(ggml2, ggml, temp, f); !ok { if cmp.Diff(diff, diff2) != "" { - t.Fatalf("\n%s,\n%s\n%s\n%s\ndiff: %s", diff["token_embd.weight"], diff2["token_embd.weight"], diff["token_embd.weight size"], diff["token_embd.weight offset"], cmp.Diff(diff, diff2)) + t.Fatalf("diff: \n%s", cmp.Diff(diff, diff2)) } } }) } } -func formatDiff(diff map[string]string) string { - var builder strings.Builder - for k, v := range diff { - builder.WriteString(fmt.Sprintf("%s: %s\n", k, v)) - } - return builder.String() -} - -func compareGGML(n int64, ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string, bool) { +func compareGGML(ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string, bool) { diff := make(map[string]string) diff2 := make(map[string]string) @@ -87,15 +73,6 @@ func compareGGML(n int64, ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[stri } for k, v := range kv1 { - // if v2, ok := kv2[k]; !ok { - // diff[k] = fmt.Sprintf("missing key %s", k) - // } else if v != v2 { - // diff[fmt.Sprintf("%s type diff", k)] = fmt.Sprintf("kv1 type: %T, kv2 type: %T", v.(*array).values, v2.(*array).values) - // diff[k] = fmt.Sprintf("kv1: %d, kv2: %d", len(v.(*array).values), len(v2.(*array).values)) - // diff[fmt.Sprintf("%s values first 10", k)] = fmt.Sprintf("\nkv1: %#v, \nkv2: %#v", v.(*array).values[0:10], v2.(*array).values[0:10]) - // diff[fmt.Sprintf("%s values last 10", k)] = fmt.Sprintf("\nkv1: %#v, \nkv2: %#v", v.(*array).values[len(v.(*array).values)-10:], v2.(*array).values[len(v2.(*array).values)-10:]) - // diff[fmt.Sprintf("%s diff", k)] = cmp.Diff(v.(*array).values, v2.(*array).values) - switch t := v.(type) { case *array: if diffy := cmp.Diff(t.values, kv2[k].(*array).values); diffy != "" { @@ -106,8 +83,6 @@ func compareGGML(n int64, ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[stri diff[k] = fmt.Sprintf("kv1: %v, kv2: %v", v, kv2[k]) } } - - // } } t1 := ggml1.Tensors()