From 25be20949c4ca425d662c406eabb966c6fafac42 Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Mon, 15 Jul 2024 15:08:24 -0700 Subject: [PATCH] test --- llm/ggml.go | 7 +++ llm/gguf.go | 12 ++-- llm/gguf_test.go | 154 ++++++++++++++++++++++++++++++++++++++++------- 3 files changed, 142 insertions(+), 31 deletions(-) diff --git a/llm/ggml.go b/llm/ggml.go index 1fdd3c071..b2973250b 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -312,10 +312,12 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) { maxArraySize = 1024 } + fmt.Println("errored 1") rs = bufioutil.NewBufferedSeeker(rs, 32<<10) var magic uint32 if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil { + fmt.Println("errored 2") return nil, 0, err } @@ -330,19 +332,24 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) { case FILE_MAGIC_GGUF_BE: c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize} default: + fmt.Println("errored 3") return nil, 0, errors.New("invalid file magic") } + fmt.Println("valid magic") model, err := c.Decode(rs) if err != nil { return nil, 0, err } + fmt.Println("valid decode") offset, err := rs.Seek(0, io.SeekCurrent) if err != nil { + fmt.Println("invalid seek") return nil, 0, err } + fmt.Println("valid seek") // final model type return &GGML{ container: c, diff --git a/llm/gguf.go b/llm/gguf.go index ca94d8fdc..a283b5386 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -8,7 +8,6 @@ import ( "io" "log/slog" "slices" - "sort" "strings" "golang.org/x/exp/maps" @@ -141,13 +140,11 @@ func (llm *gguf) numKV() uint64 { func (llm *gguf) Decode(rs io.ReadSeeker) error { // decode key-values - fmt.Println(llm.numKV()) for i := 0; uint64(i) < llm.numKV(); i++ { k, err := readGGUFString(llm, rs) if err != nil { return err } - fmt.Printf("k: %#v\n", k) t, err := readGGUF[uint32](llm, rs) if err != nil { @@ -214,7 +211,6 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error { } shape = append(shape, shapeVal) } - fmt.Println("tensor ", name, " shape ", shape) kind, err := readGGUF[uint32](llm, rs) if err != nil { @@ -226,6 +222,7 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error { return fmt.Errorf("failed to read tensor offset: %w", err) } + fmt.Println("tensor", name, shape, kind, offset) tensor := Tensor{ Name: name, Kind: kind, @@ -764,8 +761,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) { } } } - - sort.Sort(gguf.Tensors) + //sort.Sort(gguf.Tensors) var s uint64 for _, t := range gguf.Tensors { @@ -775,6 +771,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) { } s += t.Size() } + tensorOffset := wo.offset for _, t := range gguf.Tensors { if err := ggufWriteTensor(wo, t, wo.offset); err != nil { @@ -782,7 +779,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) { } } - return 0, nil + return int64(tensorOffset), nil } func ggufWriteTensorInfo(ws io.Writer, t *Tensor) error { @@ -797,7 +794,6 @@ func ggufWriteTensorInfo(ws io.Writer, t *Tensor) error { if err := binary.Write(ws, binary.LittleEndian, uint32(len(t.Shape))); err != nil { return err } - fmt.Println("tensor ", t.Name, " shape ", t.Shape) for i := range len(t.Shape) { if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil { diff --git a/llm/gguf_test.go b/llm/gguf_test.go index faff66ae4..67b163a42 100644 --- a/llm/gguf_test.go +++ b/llm/gguf_test.go @@ -1,10 +1,13 @@ package llm import ( + "crypto/sha256" + "fmt" "io" "math" "os" "path/filepath" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -13,7 +16,6 @@ import ( func TestGGUFRewrite(t *testing.T) { tests := []string{ "glm2.gguf", - "nutiny.gguf", } for i := range tests { @@ -26,44 +28,144 @@ func TestGGUFRewrite(t *testing.T) { t.Fatalf("%s not found", p) } - ggml, err := decodeGGML(t, p) + f, err := os.Open(p) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + ggml, m, err := decodeGGML(t, f) if err != nil { t.Fatal(err) } - ggml2, err := rewriteGGML(t, ggml, p) + temp, err := os.CreateTemp("testdata", "2"+tt) + if err != nil { + t.Fatal(err) + } + defer temp.Close() + + n, ggml2, err := rewriteGGML(t, ggml, temp) + + if n != m { + t.Fatalf("n: %d, m: %d", n, m) + } + if err != nil { t.Fatal(err) } - if cmp.Diff(ggml, ggml2) != "" { - t.Fatal(cmp.Diff(ggml, ggml2)) + if diff, diff2, ok := compareGGML(n, ggml2, ggml, temp, f); !ok { + if cmp.Diff(diff, diff2) != "" { + t.Fatalf("\n%s,\n%s\ndiff: %s", diff["token_embd.weight"], diff2["token_embd.weight"], cmp.Diff(diff, diff2)) + } } + + /* // Reset the file offset to the beginning + if _, err := f.Seek(0, io.SeekStart); err != nil { + t.Fatal(err) + } + if _, err := temp.Seek(0, io.SeekStart); err != nil { + t.Fatal(err) + } + + content1, err := io.ReadAll(f) + if err != nil { + t.Fatalf("failed to read file1: %v", err) + } + + content2, err := io.ReadAll(temp) + if err != nil { + t.Fatalf("failed to read file1: %v", err) + } + + if byteCmp := cmp.Diff(content1, content2); byteCmp != "" { + t.Fatalf("content diff: %s", byteCmp) + } */ }) } } -func decodeGGML(t *testing.T, p string) (*GGML, error) { - f, err := os.Open(p) - if err != nil { - t.Fatal(err) +func formatDiff(diff map[string]string) string { + var builder strings.Builder + for k, v := range diff { + builder.WriteString(fmt.Sprintf("%s: %s\n", k, v)) } - defer f.Close() - - ggml, _, err := DecodeGGML(f, math.MaxInt) - if err != nil { - t.Fatal(err) - } - return ggml, nil + return builder.String() } -func rewriteGGML(t *testing.T, ggml *GGML, path string) (*GGML, error) { - var tensors Tensors - temp, err := os.Create(path) +func compareGGML(n int64, ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string, bool) { + diff := make(map[string]string) + diff2 := make(map[string]string) + + kv1 := ggml1.KV() + kv2 := ggml2.KV() + + if len(kv1) != len(kv2) { + diff["lenKV"] = fmt.Sprintf("kv1: %d, kv2: %d", len(kv1), len(kv2)) + fmt.Println("lenKV", diff["lenKV"]) + } + + for k, v := range kv1 { + // if v2, ok := kv2[k]; !ok { + // diff[k] = fmt.Sprintf("missing key %s", k) + // } else if v != v2 { + // diff[fmt.Sprintf("%s type diff", k)] = fmt.Sprintf("kv1 type: %T, kv2 type: %T", v.(*array).values, v2.(*array).values) + // diff[k] = fmt.Sprintf("kv1: %d, kv2: %d", len(v.(*array).values), len(v2.(*array).values)) + // diff[fmt.Sprintf("%s values first 10", k)] = fmt.Sprintf("\nkv1: %#v, \nkv2: %#v", v.(*array).values[0:10], v2.(*array).values[0:10]) + // diff[fmt.Sprintf("%s values last 10", k)] = fmt.Sprintf("\nkv1: %#v, \nkv2: %#v", v.(*array).values[len(v.(*array).values)-10:], v2.(*array).values[len(v2.(*array).values)-10:]) + // diff[fmt.Sprintf("%s diff", k)] = cmp.Diff(v.(*array).values, v2.(*array).values) + + switch t := v.(type) { + case *array: + if diffy := cmp.Diff(t.values, kv2[k].(*array).values); diffy != "" { + diff[k] = diffy + } + } + + // } + } + + t1 := ggml1.Tensors() + t2 := ggml2.Tensors() + + if len(t1) != len(t2) { + diff["lenTensors"] = fmt.Sprintf("t1: %d, t2: %d", len(t1), len(t2)) + } + + for _, tensor := range t1 { + sha256sum := sha256.New() + sr := io.NewSectionReader(f, n+int64(tensor.Offset), int64(tensor.Size())) + if _, err := io.Copy(sha256sum, sr); err != nil { + fmt.Println(err) + } + + diff[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil)) + } + + for _, tensor := range t2 { + sha256sum := sha256.New() + sr := io.NewSectionReader(f2, n+int64(tensor.Offset), int64(tensor.Size())) + if _, err := io.Copy(sha256sum, sr); err != nil { + fmt.Println(err) + } + + diff2[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil)) + } + return diff, diff2, len(diff) == 0 + +} +func decodeGGML(t *testing.T, f *os.File) (*GGML, int64, error) { + + ggml, n, err := DecodeGGML(f, math.MaxInt) if err != nil { t.Fatal(err) } - defer temp.Close() + return ggml, n, nil +} + +func rewriteGGML(t *testing.T, ggml *GGML, temp *os.File) (int64, *GGML, error) { + var tensors Tensors for _, tensor := range ggml.Tensors() { shape := make([]uint64, len(tensor.Shape)) @@ -88,15 +190,21 @@ func rewriteGGML(t *testing.T, ggml *GGML, path string) (*GGML, error) { Tensors: tensors, } - _, err = io.Copy(temp, reader) + n, err := io.Copy(temp, reader) if err != nil { t.Fatal(err) } - ggml2, _, err := DecodeGGML(temp, -1) + fmt.Println(n) + temp.Seek(0, io.SeekStart) + file, err := os.Open(temp.Name()) + if err != nil { + t.Fatal(err) + } + ggml2, n, err := DecodeGGML(file, math.MaxInt) if err != nil { t.Fatal(err) } - return ggml2, nil + return n, ggml2, nil }