test
This commit is contained in:
parent
903e9df46f
commit
25be20949c
@ -312,10 +312,12 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
|||||||
maxArraySize = 1024
|
maxArraySize = 1024
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("errored 1")
|
||||||
rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
|
rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
|
||||||
|
|
||||||
var magic uint32
|
var magic uint32
|
||||||
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
|
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
|
||||||
|
fmt.Println("errored 2")
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -330,19 +332,24 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
|||||||
case FILE_MAGIC_GGUF_BE:
|
case FILE_MAGIC_GGUF_BE:
|
||||||
c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
|
c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
|
||||||
default:
|
default:
|
||||||
|
fmt.Println("errored 3")
|
||||||
return nil, 0, errors.New("invalid file magic")
|
return nil, 0, errors.New("invalid file magic")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("valid magic")
|
||||||
model, err := c.Decode(rs)
|
model, err := c.Decode(rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("valid decode")
|
||||||
offset, err := rs.Seek(0, io.SeekCurrent)
|
offset, err := rs.Seek(0, io.SeekCurrent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
fmt.Println("invalid seek")
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("valid seek")
|
||||||
// final model type
|
// final model type
|
||||||
return &GGML{
|
return &GGML{
|
||||||
container: c,
|
container: c,
|
||||||
|
12
llm/gguf.go
12
llm/gguf.go
@ -8,7 +8,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"slices"
|
"slices"
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
@ -141,13 +140,11 @@ func (llm *gguf) numKV() uint64 {
|
|||||||
|
|
||||||
func (llm *gguf) Decode(rs io.ReadSeeker) error {
|
func (llm *gguf) Decode(rs io.ReadSeeker) error {
|
||||||
// decode key-values
|
// decode key-values
|
||||||
fmt.Println(llm.numKV())
|
|
||||||
for i := 0; uint64(i) < llm.numKV(); i++ {
|
for i := 0; uint64(i) < llm.numKV(); i++ {
|
||||||
k, err := readGGUFString(llm, rs)
|
k, err := readGGUFString(llm, rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fmt.Printf("k: %#v\n", k)
|
|
||||||
|
|
||||||
t, err := readGGUF[uint32](llm, rs)
|
t, err := readGGUF[uint32](llm, rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -214,7 +211,6 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
|
|||||||
}
|
}
|
||||||
shape = append(shape, shapeVal)
|
shape = append(shape, shapeVal)
|
||||||
}
|
}
|
||||||
fmt.Println("tensor ", name, " shape ", shape)
|
|
||||||
|
|
||||||
kind, err := readGGUF[uint32](llm, rs)
|
kind, err := readGGUF[uint32](llm, rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -226,6 +222,7 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
|
|||||||
return fmt.Errorf("failed to read tensor offset: %w", err)
|
return fmt.Errorf("failed to read tensor offset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("tensor", name, shape, kind, offset)
|
||||||
tensor := Tensor{
|
tensor := Tensor{
|
||||||
Name: name,
|
Name: name,
|
||||||
Kind: kind,
|
Kind: kind,
|
||||||
@ -764,8 +761,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
//sort.Sort(gguf.Tensors)
|
||||||
sort.Sort(gguf.Tensors)
|
|
||||||
|
|
||||||
var s uint64
|
var s uint64
|
||||||
for _, t := range gguf.Tensors {
|
for _, t := range gguf.Tensors {
|
||||||
@ -775,6 +771,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
|
|||||||
}
|
}
|
||||||
s += t.Size()
|
s += t.Size()
|
||||||
}
|
}
|
||||||
|
tensorOffset := wo.offset
|
||||||
|
|
||||||
for _, t := range gguf.Tensors {
|
for _, t := range gguf.Tensors {
|
||||||
if err := ggufWriteTensor(wo, t, wo.offset); err != nil {
|
if err := ggufWriteTensor(wo, t, wo.offset); err != nil {
|
||||||
@ -782,7 +779,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0, nil
|
return int64(tensorOffset), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ggufWriteTensorInfo(ws io.Writer, t *Tensor) error {
|
func ggufWriteTensorInfo(ws io.Writer, t *Tensor) error {
|
||||||
@ -797,7 +794,6 @@ func ggufWriteTensorInfo(ws io.Writer, t *Tensor) error {
|
|||||||
if err := binary.Write(ws, binary.LittleEndian, uint32(len(t.Shape))); err != nil {
|
if err := binary.Write(ws, binary.LittleEndian, uint32(len(t.Shape))); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fmt.Println("tensor ", t.Name, " shape ", t.Shape)
|
|
||||||
|
|
||||||
for i := range len(t.Shape) {
|
for i := range len(t.Shape) {
|
||||||
if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
|
if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
|
||||||
|
154
llm/gguf_test.go
154
llm/gguf_test.go
@ -1,10 +1,13 @@
|
|||||||
package llm
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
@ -13,7 +16,6 @@ import (
|
|||||||
func TestGGUFRewrite(t *testing.T) {
|
func TestGGUFRewrite(t *testing.T) {
|
||||||
tests := []string{
|
tests := []string{
|
||||||
"glm2.gguf",
|
"glm2.gguf",
|
||||||
"nutiny.gguf",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range tests {
|
for i := range tests {
|
||||||
@ -26,44 +28,144 @@ func TestGGUFRewrite(t *testing.T) {
|
|||||||
t.Fatalf("%s not found", p)
|
t.Fatalf("%s not found", p)
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml, err := decodeGGML(t, p)
|
f, err := os.Open(p)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
ggml, m, err := decodeGGML(t, f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml2, err := rewriteGGML(t, ggml, p)
|
temp, err := os.CreateTemp("testdata", "2"+tt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer temp.Close()
|
||||||
|
|
||||||
|
n, ggml2, err := rewriteGGML(t, ggml, temp)
|
||||||
|
|
||||||
|
if n != m {
|
||||||
|
t.Fatalf("n: %d, m: %d", n, m)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmp.Diff(ggml, ggml2) != "" {
|
if diff, diff2, ok := compareGGML(n, ggml2, ggml, temp, f); !ok {
|
||||||
t.Fatal(cmp.Diff(ggml, ggml2))
|
if cmp.Diff(diff, diff2) != "" {
|
||||||
|
t.Fatalf("\n%s,\n%s\ndiff: %s", diff["token_embd.weight"], diff2["token_embd.weight"], cmp.Diff(diff, diff2))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* // Reset the file offset to the beginning
|
||||||
|
if _, err := f.Seek(0, io.SeekStart); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := temp.Seek(0, io.SeekStart); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content1, err := io.ReadAll(f)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read file1: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content2, err := io.ReadAll(temp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read file1: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if byteCmp := cmp.Diff(content1, content2); byteCmp != "" {
|
||||||
|
t.Fatalf("content diff: %s", byteCmp)
|
||||||
|
} */
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeGGML(t *testing.T, p string) (*GGML, error) {
|
func formatDiff(diff map[string]string) string {
|
||||||
f, err := os.Open(p)
|
var builder strings.Builder
|
||||||
if err != nil {
|
for k, v := range diff {
|
||||||
t.Fatal(err)
|
builder.WriteString(fmt.Sprintf("%s: %s\n", k, v))
|
||||||
}
|
}
|
||||||
defer f.Close()
|
return builder.String()
|
||||||
|
|
||||||
ggml, _, err := DecodeGGML(f, math.MaxInt)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
return ggml, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func rewriteGGML(t *testing.T, ggml *GGML, path string) (*GGML, error) {
|
func compareGGML(n int64, ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string, bool) {
|
||||||
var tensors Tensors
|
diff := make(map[string]string)
|
||||||
temp, err := os.Create(path)
|
diff2 := make(map[string]string)
|
||||||
|
|
||||||
|
kv1 := ggml1.KV()
|
||||||
|
kv2 := ggml2.KV()
|
||||||
|
|
||||||
|
if len(kv1) != len(kv2) {
|
||||||
|
diff["lenKV"] = fmt.Sprintf("kv1: %d, kv2: %d", len(kv1), len(kv2))
|
||||||
|
fmt.Println("lenKV", diff["lenKV"])
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range kv1 {
|
||||||
|
// if v2, ok := kv2[k]; !ok {
|
||||||
|
// diff[k] = fmt.Sprintf("missing key %s", k)
|
||||||
|
// } else if v != v2 {
|
||||||
|
// diff[fmt.Sprintf("%s type diff", k)] = fmt.Sprintf("kv1 type: %T, kv2 type: %T", v.(*array).values, v2.(*array).values)
|
||||||
|
// diff[k] = fmt.Sprintf("kv1: %d, kv2: %d", len(v.(*array).values), len(v2.(*array).values))
|
||||||
|
// diff[fmt.Sprintf("%s values first 10", k)] = fmt.Sprintf("\nkv1: %#v, \nkv2: %#v", v.(*array).values[0:10], v2.(*array).values[0:10])
|
||||||
|
// diff[fmt.Sprintf("%s values last 10", k)] = fmt.Sprintf("\nkv1: %#v, \nkv2: %#v", v.(*array).values[len(v.(*array).values)-10:], v2.(*array).values[len(v2.(*array).values)-10:])
|
||||||
|
// diff[fmt.Sprintf("%s diff", k)] = cmp.Diff(v.(*array).values, v2.(*array).values)
|
||||||
|
|
||||||
|
switch t := v.(type) {
|
||||||
|
case *array:
|
||||||
|
if diffy := cmp.Diff(t.values, kv2[k].(*array).values); diffy != "" {
|
||||||
|
diff[k] = diffy
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
t1 := ggml1.Tensors()
|
||||||
|
t2 := ggml2.Tensors()
|
||||||
|
|
||||||
|
if len(t1) != len(t2) {
|
||||||
|
diff["lenTensors"] = fmt.Sprintf("t1: %d, t2: %d", len(t1), len(t2))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tensor := range t1 {
|
||||||
|
sha256sum := sha256.New()
|
||||||
|
sr := io.NewSectionReader(f, n+int64(tensor.Offset), int64(tensor.Size()))
|
||||||
|
if _, err := io.Copy(sha256sum, sr); err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
diff[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tensor := range t2 {
|
||||||
|
sha256sum := sha256.New()
|
||||||
|
sr := io.NewSectionReader(f2, n+int64(tensor.Offset), int64(tensor.Size()))
|
||||||
|
if _, err := io.Copy(sha256sum, sr); err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
diff2[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
|
||||||
|
}
|
||||||
|
return diff, diff2, len(diff) == 0
|
||||||
|
|
||||||
|
}
|
||||||
|
func decodeGGML(t *testing.T, f *os.File) (*GGML, int64, error) {
|
||||||
|
|
||||||
|
ggml, n, err := DecodeGGML(f, math.MaxInt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer temp.Close()
|
return ggml, n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func rewriteGGML(t *testing.T, ggml *GGML, temp *os.File) (int64, *GGML, error) {
|
||||||
|
var tensors Tensors
|
||||||
|
|
||||||
for _, tensor := range ggml.Tensors() {
|
for _, tensor := range ggml.Tensors() {
|
||||||
shape := make([]uint64, len(tensor.Shape))
|
shape := make([]uint64, len(tensor.Shape))
|
||||||
@ -88,15 +190,21 @@ func rewriteGGML(t *testing.T, ggml *GGML, path string) (*GGML, error) {
|
|||||||
Tensors: tensors,
|
Tensors: tensors,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = io.Copy(temp, reader)
|
n, err := io.Copy(temp, reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml2, _, err := DecodeGGML(temp, -1)
|
fmt.Println(n)
|
||||||
|
temp.Seek(0, io.SeekStart)
|
||||||
|
file, err := os.Open(temp.Name())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ggml2, n, err := DecodeGGML(file, math.MaxInt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ggml2, nil
|
return n, ggml2, nil
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user