clean
This commit is contained in:
parent
873f334783
commit
703ecccc6b
@ -7,17 +7,15 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGGUFRewrite(t *testing.T) {
|
func TestGGUFRewrite(t *testing.T) {
|
||||||
tests := []string{
|
// to test this GGUF Rewrite, add gguf files to /llm/testdata
|
||||||
"phi3.gguf",
|
// add the name of the file to the tests slice
|
||||||
"nutiny.gguf",
|
tests := []string{}
|
||||||
}
|
|
||||||
|
|
||||||
for i := range tests {
|
for i := range tests {
|
||||||
tt := tests[i]
|
tt := tests[i]
|
||||||
@ -35,6 +33,7 @@ func TestGGUFRewrite(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
|
// decode original gguf
|
||||||
ggml, _, err := decodeGGML(t, f)
|
ggml, _, err := decodeGGML(t, f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -46,35 +45,22 @@ func TestGGUFRewrite(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer temp.Close()
|
defer temp.Close()
|
||||||
|
|
||||||
n, ggml2, err := rewriteGGML(t, ggml, temp, f)
|
_, ggml2, err := rewriteGGML(t, ggml, temp, f)
|
||||||
|
|
||||||
/* if n != m {
|
|
||||||
t.Fatalf("n: %d, m: %d", n, m)
|
|
||||||
} */
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
//t.Fatal("FULL SIZE JFAKFJJEFJAJFLAEJJAFAJKLFJ", n)
|
|
||||||
|
|
||||||
if diff, diff2, ok := compareGGML(n, ggml2, ggml, temp, f); !ok {
|
if diff, diff2, ok := compareGGML(ggml2, ggml, temp, f); !ok {
|
||||||
if cmp.Diff(diff, diff2) != "" {
|
if cmp.Diff(diff, diff2) != "" {
|
||||||
t.Fatalf("\n%s,\n%s\n%s\n%s\ndiff: %s", diff["token_embd.weight"], diff2["token_embd.weight"], diff["token_embd.weight size"], diff["token_embd.weight offset"], cmp.Diff(diff, diff2))
|
t.Fatalf("diff: \n%s", cmp.Diff(diff, diff2))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func formatDiff(diff map[string]string) string {
|
func compareGGML(ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string, bool) {
|
||||||
var builder strings.Builder
|
|
||||||
for k, v := range diff {
|
|
||||||
builder.WriteString(fmt.Sprintf("%s: %s\n", k, v))
|
|
||||||
}
|
|
||||||
return builder.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func compareGGML(n int64, ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string, bool) {
|
|
||||||
diff := make(map[string]string)
|
diff := make(map[string]string)
|
||||||
diff2 := make(map[string]string)
|
diff2 := make(map[string]string)
|
||||||
|
|
||||||
@ -87,15 +73,6 @@ func compareGGML(n int64, ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range kv1 {
|
for k, v := range kv1 {
|
||||||
// if v2, ok := kv2[k]; !ok {
|
|
||||||
// diff[k] = fmt.Sprintf("missing key %s", k)
|
|
||||||
// } else if v != v2 {
|
|
||||||
// diff[fmt.Sprintf("%s type diff", k)] = fmt.Sprintf("kv1 type: %T, kv2 type: %T", v.(*array).values, v2.(*array).values)
|
|
||||||
// diff[k] = fmt.Sprintf("kv1: %d, kv2: %d", len(v.(*array).values), len(v2.(*array).values))
|
|
||||||
// diff[fmt.Sprintf("%s values first 10", k)] = fmt.Sprintf("\nkv1: %#v, \nkv2: %#v", v.(*array).values[0:10], v2.(*array).values[0:10])
|
|
||||||
// diff[fmt.Sprintf("%s values last 10", k)] = fmt.Sprintf("\nkv1: %#v, \nkv2: %#v", v.(*array).values[len(v.(*array).values)-10:], v2.(*array).values[len(v2.(*array).values)-10:])
|
|
||||||
// diff[fmt.Sprintf("%s diff", k)] = cmp.Diff(v.(*array).values, v2.(*array).values)
|
|
||||||
|
|
||||||
switch t := v.(type) {
|
switch t := v.(type) {
|
||||||
case *array:
|
case *array:
|
||||||
if diffy := cmp.Diff(t.values, kv2[k].(*array).values); diffy != "" {
|
if diffy := cmp.Diff(t.values, kv2[k].(*array).values); diffy != "" {
|
||||||
@ -106,8 +83,6 @@ func compareGGML(n int64, ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[stri
|
|||||||
diff[k] = fmt.Sprintf("kv1: %v, kv2: %v", v, kv2[k])
|
diff[k] = fmt.Sprintf("kv1: %v, kv2: %v", v, kv2[k])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
t1 := ggml1.Tensors()
|
t1 := ggml1.Tensors()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user