test
This commit is contained in:
parent
40c0f9612e
commit
903e9df46f
102
llm/gguf_test.go
Normal file
102
llm/gguf_test.go
Normal file
@ -0,0 +1,102 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestGGUFRewrite(t *testing.T) {
|
||||
tests := []string{
|
||||
"glm2.gguf",
|
||||
"nutiny.gguf",
|
||||
}
|
||||
|
||||
for i := range tests {
|
||||
tt := tests[i]
|
||||
t.Run(tt, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := filepath.Join("testdata", tt)
|
||||
|
||||
if _, err := os.Stat(p); err != nil {
|
||||
t.Fatalf("%s not found", p)
|
||||
}
|
||||
|
||||
ggml, err := decodeGGML(t, p)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ggml2, err := rewriteGGML(t, ggml, p)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if cmp.Diff(ggml, ggml2) != "" {
|
||||
t.Fatal(cmp.Diff(ggml, ggml2))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func decodeGGML(t *testing.T, p string) (*GGML, error) {
|
||||
f, err := os.Open(p)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
ggml, _, err := DecodeGGML(f, math.MaxInt)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return ggml, nil
|
||||
}
|
||||
|
||||
func rewriteGGML(t *testing.T, ggml *GGML, path string) (*GGML, error) {
|
||||
var tensors Tensors
|
||||
temp, err := os.Create(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer temp.Close()
|
||||
|
||||
for _, tensor := range ggml.Tensors() {
|
||||
shape := make([]uint64, len(tensor.Shape))
|
||||
for i := range len(tensor.Shape) {
|
||||
shape[i] = tensor.Shape[len(tensor.Shape)-i-1]
|
||||
}
|
||||
|
||||
tensors = append(tensors, &Tensor{
|
||||
Name: tensor.Name,
|
||||
Kind: tensor.Kind,
|
||||
Shape: shape,
|
||||
|
||||
WriterTo: TensorWriter{
|
||||
Reader: io.NewSectionReader(temp, int64(tensor.Offset), int64(tensor.Size())),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
reader := &GGUFWriter{
|
||||
KV: ggml.KV(),
|
||||
// Update .Tensors
|
||||
Tensors: tensors,
|
||||
}
|
||||
|
||||
_, err = io.Copy(temp, reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ggml2, _, err := DecodeGGML(temp, -1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return ggml2, nil
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user