This commit is contained in:
Josh Yan 2024-07-16 16:35:15 -07:00
parent 703ecccc6b
commit 6ee22d5080
2 changed files with 66 additions and 79 deletions

View File

@ -911,7 +911,6 @@ func ggufWriteKV(ws io.Writer, k string, v any) error {
} }
default: default:
fmt.Println("type is", v)
return fmt.Errorf("improper type for '%s'", k) return fmt.Errorf("improper type for '%s'", k)
} }
@ -919,5 +918,6 @@ func ggufWriteKV(ws io.Writer, k string, v any) error {
} }
func ggufPadding(offset, align int64) int64 { func ggufPadding(offset, align int64) int64 {
// the outer mod maps the result to 0 (instead of align) when offset%align == 0
return (align - offset%align) % align return (align - offset%align) % align
} }

View File

@ -12,10 +12,14 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
) )
// TestGGUFRewrite tests the decoding and rewriting of (unsorted) GGUF files.
// To run, add GGUF files to /llm/testdata and add the name of each file to the tests slice.
// The sort.Sort(tensors) call in gguf.go should be commented out before running.
// This creates a temporary file in /llm/testdata that will be deleted only if the test passes.
func TestGGUFRewrite(t *testing.T) { func TestGGUFRewrite(t *testing.T) {
// to test this GGUF Rewrite, add gguf files to /llm/testdata tests := []string{
// add the name of the file to the tests slice "nutiny.gguf",
tests := []string{} }
for i := range tests { for i := range tests {
tt := tests[i] tt := tests[i]
@ -24,169 +28,152 @@ func TestGGUFRewrite(t *testing.T) {
p := filepath.Join("testdata", tt) p := filepath.Join("testdata", tt)
if _, err := os.Stat(p); err != nil { if _, err := os.Stat(p); err != nil {
t.Fatalf("%s not found", p) t.Skip("file not found", p)
} }
f, err := os.Open(p) wantFile, err := os.Open(p)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer f.Close() defer wantFile.Close()
// decode original gguf // decode original gguf
ggml, _, err := decodeGGML(t, f) _, wantGGML, err := decodeGGML(t, wantFile)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
temp, err := os.CreateTemp("testdata", "2"+tt) gotFile, err := os.CreateTemp("testdata", tt)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer temp.Close() defer func() {
gotFile.Close()
if !t.Failed() {
os.Remove(gotFile.Name())
}
}()
_, ggml2, err := rewriteGGML(t, ggml, temp, f) _, gotGGML, err := rewriteGGML(t, wantGGML, gotFile, wantFile)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if diff, diff2, ok := compareGGML(ggml2, ggml, temp, f); !ok { diff, diff2 := compareGGML(t, gotGGML, wantGGML, gotFile, wantFile)
if cmp.Diff(diff, diff2) != "" { if cmp.Diff(diff, diff2) != "" {
t.Fatalf("diff: \n%s", cmp.Diff(diff, diff2)) t.Fatalf("diff: \n%s", cmp.Diff(diff, diff2))
} }
}
}) })
} }
} }
func compareGGML(ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string, bool) { func compareGGML(t *testing.T, gotGGML, wantGGML *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string) {
diff := make(map[string]string) got := make(map[string]string)
diff2 := make(map[string]string) want := make(map[string]string)
kv1 := ggml1.KV() gotKV := gotGGML.KV()
kv2 := ggml2.KV() wantKV := wantGGML.KV()
if len(kv1) != len(kv2) { if len(gotKV) != len(wantKV) {
diff["lenKV"] = fmt.Sprintf("kv1: %d, kv2: %d", len(kv1), len(kv2)) t.Fatalf("got length: %d != want length: %d", len(gotKV), len(wantKV))
fmt.Println("lenKV", diff["lenKV"])
} }
for k, v := range kv1 { for k, v := range gotKV {
switch t := v.(type) { switch t := v.(type) {
case *array: case *array:
if diffy := cmp.Diff(t.values, kv2[k].(*array).values); diffy != "" { if diffy := cmp.Diff(t.values, wantKV[k].(*array).values); diffy != "" {
diff[k] = diffy got[k] = diffy
} }
default: default:
if v != kv2[k] { if v != wantKV[k] {
diff[k] = fmt.Sprintf("kv1: %v, kv2: %v", v, kv2[k]) got[k] = fmt.Sprintf("kv1: %v, kv2: %v", v, want[k])
} }
} }
} }
t1 := ggml1.Tensors() gotTensors := gotGGML.Tensors().Items
t2 := ggml2.Tensors() gotOffset := gotGGML.Tensors().Offset
wantTensors := wantGGML.Tensors().Items
wantOffset := wantGGML.Tensors().Offset
if len(t1.Items) != len(t2.Items) { if len(gotTensors) != len(wantTensors) {
diff["lenTensors"] = fmt.Sprintf("t1: %d, t2: %d", len(t1.Items), len(t2.Items)) got["lenTensors"] = fmt.Sprintf("t1: %d, t2: %d", len(gotTensors), len(wantTensors))
} }
for _, tensor := range t1.Items { for _, tensor := range gotTensors {
sha256sum := sha256.New() sha256sum := sha256.New()
sr := io.NewSectionReader(f, t1.Offset+int64(tensor.Offset), int64(tensor.Size())) sr := io.NewSectionReader(f, gotOffset+int64(tensor.Offset), int64(tensor.Size()))
var s int64 var s int64
s, err := io.Copy(sha256sum, sr) s, err := io.Copy(sha256sum, sr)
if err != nil { if err != nil {
fmt.Println(err) t.Fatalf("error: %v", err)
} }
diff[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil)) got[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
diff[tensor.Name+" size"] = fmt.Sprintf("%d", s) got[tensor.Name+" size"] = fmt.Sprintf("%d", s)
diff[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset) got[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
} }
/* sha256Sum2 := sha256.New() for _, tensor := range wantTensors {
sr1 := io.NewSectionReader(f2, 0, n)
s1, err := io.Copy(sha256Sum2, sr1)
if err != nil {
return nil, nil, true
}
sha256Sum3 := sha256.New()
sr2 := io.NewSectionReader(f, 0, n)
s2, err := io.Copy(sha256Sum3, sr2)
if err != nil {
return nil, nil, true
}
diff["sha"] = fmt.Sprintf("%d", s1)
diff2["sha"] = fmt.Sprintf("%d", s2) */
for _, tensor := range t2.Items {
sha256sum := sha256.New() sha256sum := sha256.New()
var s int64 var s int64
sr := io.NewSectionReader(f2, t1.Offset+int64(tensor.Offset), int64(tensor.Size())) sr := io.NewSectionReader(f2, wantOffset +int64(tensor.Offset), int64(tensor.Size()))
s, err := io.Copy(sha256sum, sr) s, err := io.Copy(sha256sum, sr)
if err != nil { if err != nil {
fmt.Println(err) t.Fatalf("error: %v", err)
} }
diff2[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil)) want[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
diff2[tensor.Name+" size"] = fmt.Sprintf("%d", s) want[tensor.Name+" size"] = fmt.Sprintf("%d", s)
diff2[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset) want[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
}
return got, want
} }
return diff, diff2, len(diff) == 0
} func decodeGGML(t *testing.T, f *os.File) (int64, *GGML, error) {
func decodeGGML(t *testing.T, f *os.File) (*GGML, int64, error) {
ggml, n, err := DecodeGGML(f, math.MaxInt) ggml, n, err := DecodeGGML(f, math.MaxInt)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
return ggml, n, nil return n, ggml, nil
} }
func rewriteGGML(t *testing.T, ggml *GGML, temp *os.File, f *os.File) (int64, *GGML, error) { func rewriteGGML(t *testing.T, ggml *GGML, gotFile *os.File, wantFile *os.File) (int64, *GGML, error) {
var tensors []*Tensor var tensors []*Tensor
fmt.Println("11111111111111111111111111111111111111111")
for _, tensor := range ggml.Tensors().Items { for _, tensor := range ggml.Tensors().Items {
shape := make([]uint64, len(tensor.Shape)) shape := make([]uint64, len(tensor.Shape))
for i := range len(tensor.Shape) { for i := range len(tensor.Shape) {
shape[i] = tensor.Shape[len(tensor.Shape)-i-1] shape[i] = tensor.Shape[len(tensor.Shape)-i-1]
} }
fmt.Println("tensors", tensor.Name, shape, tensor.Kind, tensor.Offset)
fmt.Println(ggml.Tensors().Offset)
tensors = append(tensors, &Tensor{ tensors = append(tensors, &Tensor{
Name: tensor.Name, Name: tensor.Name,
Kind: tensor.Kind, Kind: tensor.Kind,
Shape: shape, Shape: shape,
WriterTo: TensorWriter{ WriterTo: TensorWriter{
Reader: io.NewSectionReader(f, ggml.Tensors().Offset+int64(tensor.Offset), int64(tensor.Size())), Reader: io.NewSectionReader(wantFile, ggml.Tensors().Offset+int64(tensor.Offset), int64(tensor.Size())),
}, },
}) })
} }
reader := &GGUFWriter{ reader := &GGUFWriter{
KV: ggml.KV(), KV: ggml.KV(),
// Update .Tensors
Tensors: Tensors{ Tensors: Tensors{
Items: tensors, Items: tensors,
Offset: ggml.Tensors().Offset, Offset: ggml.Tensors().Offset,
}, },
} }
n, err := io.Copy(temp, reader) n, err := io.Copy(gotFile, reader)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
fmt.Println(n, " is my offset") file, err := os.Open(gotFile.Name())
file, err := os.Open(temp.Name())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }