From 903e9df46fc0ebd0cca188080f6d0804bd0f3d88 Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Mon, 15 Jul 2024 11:46:49 -0700 Subject: [PATCH] test --- llm/gguf_test.go | 102 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 llm/gguf_test.go diff --git a/llm/gguf_test.go b/llm/gguf_test.go new file mode 100644 index 000000000..faff66ae4 --- /dev/null +++ b/llm/gguf_test.go @@ -0,0 +1,102 @@ +package llm + +import ( + "io" + "math" + "os" + "path/filepath" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestGGUFRewrite(t *testing.T) { + tests := []string{ + "glm2.gguf", + "nutiny.gguf", + } + + for i := range tests { + tt := tests[i] + t.Run(tt, func(t *testing.T) { + t.Parallel() + p := filepath.Join("testdata", tt) + + if _, err := os.Stat(p); err != nil { + t.Fatalf("%s not found", p) + } + + ggml, err := decodeGGML(t, p) + if err != nil { + t.Fatal(err) + } + + ggml2, err := rewriteGGML(t, ggml, p) + if err != nil { + t.Fatal(err) + } + + if cmp.Diff(ggml, ggml2) != "" { + t.Fatal(cmp.Diff(ggml, ggml2)) + } + }) + } +} + +func decodeGGML(t *testing.T, p string) (*GGML, error) { + f, err := os.Open(p) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + ggml, _, err := DecodeGGML(f, math.MaxInt) + if err != nil { + t.Fatal(err) + } + return ggml, nil +} + +func rewriteGGML(t *testing.T, ggml *GGML, path string) (*GGML, error) { + var tensors Tensors + temp, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + defer temp.Close() + + for _, tensor := range ggml.Tensors() { + shape := make([]uint64, len(tensor.Shape)) + for i := range len(tensor.Shape) { + shape[i] = tensor.Shape[len(tensor.Shape)-i-1] + } + + tensors = append(tensors, &Tensor{ + Name: tensor.Name, + Kind: tensor.Kind, + Shape: shape, + + WriterTo: TensorWriter{ + Reader: io.NewSectionReader(temp, int64(tensor.Offset), int64(tensor.Size())), + }, + }) + } + + reader := &GGUFWriter{ + KV: ggml.KV(), + // Update .Tensors + Tensors: tensors, + } + + _, err = io.Copy(temp, reader) + if err != nil { + t.Fatal(err) + } + + ggml2, _, err := DecodeGGML(temp, -1) + if err != nil { + t.Fatal(err) + } + + return ggml2, nil +}