This commit is contained in:
Josh Yan 2024-07-12 15:37:27 -07:00
parent cf57246aba
commit faa3c937cf
2 changed files with 72 additions and 23 deletions

View File

@ -351,6 +351,7 @@ func writeGGUFString(w io.Writer, s string) error {
type array struct { type array struct {
size int size int
values []any values []any
datatype uint32
} }
func (a *array) MarshalJSON() ([]byte, error) { func (a *array) MarshalJSON() ([]byte, error) {
@ -430,7 +431,7 @@ func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
return nil, err return nil, err
} }
a := &array{size: int(n)} a := &array{size: int(n), datatype: t}
if llm.canCollectArray(int(n)) { if llm.canCollectArray(int(n)) {
a.values = make([]any, int(n)) a.values = make([]any, int(n))
} }
@ -707,6 +708,17 @@ type GGUFWriter struct {
Tensors Tensors
} }
type writeOffset struct {
io.Writer
offset int
}
func (wo *writeOffset) Write(p []byte) (int, error) {
n, err := wo.Writer.Write(p)
wo.offset += n
return n, err
}
var _ io.Reader = (*GGUFWriter)(nil) var _ io.Reader = (*GGUFWriter)(nil)
var _ io.WriterTo = (*GGUFWriter)(nil) var _ io.WriterTo = (*GGUFWriter)(nil)
@ -716,19 +728,21 @@ func (GGUFWriter) Read([]byte) (int, error) {
} }
func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) { func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
if err := binary.Write(w, binary.LittleEndian, []byte("GGUF")); err != nil { wo := &writeOffset{Writer: w}
if err := binary.Write(wo, binary.LittleEndian, []byte("GGUF")); err != nil {
return 0, err return 0, err
} }
if err := binary.Write(w, binary.LittleEndian, uint32(3)); err != nil { if err := binary.Write(wo, binary.LittleEndian, uint32(3)); err != nil {
return 0, err return 0, err
} }
if err := binary.Write(w, binary.LittleEndian, uint64(len(gguf.Tensors))); err != nil { if err := binary.Write(wo, binary.LittleEndian, uint64(len(gguf.Tensors))); err != nil {
return 0, err return 0, err
} }
if err := binary.Write(w, binary.LittleEndian, uint64(len(gguf.KV))); err != nil { if err := binary.Write(wo, binary.LittleEndian, uint64(len(gguf.KV))); err != nil {
return 0, err return 0, err
} }
@ -736,25 +750,30 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
slices.Sort(keys) slices.Sort(keys)
for _, key := range keys { for _, key := range keys {
if err := ggufWriteKV(w, key, gguf.KV[key]); err != nil { fmt.Println(key)
switch key {
case "general.parameter_count":
continue
default:
if err := ggufWriteKV(wo, key, gguf.KV[key]); err != nil {
return 0, err return 0, err
} }
} }
}
sort.Sort(gguf.Tensors) sort.Sort(gguf.Tensors)
var s uint64 var s uint64
for _, t := range gguf.Tensors { for _, t := range gguf.Tensors {
t.Offset = s t.Offset = s
if err := ggufWriteTensorInfo(w, t); err != nil { if err := ggufWriteTensorInfo(wo, t); err != nil {
return 0, err return 0, err
} }
s += t.Size() s += t.Size()
} }
var alignment int64 = 32
for _, t := range gguf.Tensors { for _, t := range gguf.Tensors {
if err := ggufWriteTensor(w, t, alignment); err != nil { if err := ggufWriteTensor(wo, t, wo.offset); err != nil {
return 0, err return 0, err
} }
} }
@ -762,7 +781,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
return 0, nil return 0, nil
} }
func ggufWriteTensorInfo(ws io.Writer, t *Tensor, alignment int64) error { func ggufWriteTensorInfo(ws io.Writer, t *Tensor) error {
if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil { if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil {
return err return err
} }
@ -788,22 +807,17 @@ func ggufWriteTensorInfo(ws io.Writer, t *Tensor, alignment int64) error {
return binary.Write(ws, binary.LittleEndian, t.Offset) return binary.Write(ws, binary.LittleEndian, t.Offset)
} }
func ggufWriteTensor(ws io.Writer, t *Tensor, alignment int64) error { func ggufWriteTensor(ws io.Writer, t *Tensor, offset int) error {
slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset) slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset)
offset, err := ws.Seek(0, io.SeekCurrent) if err := binary.Write(ws, binary.LittleEndian, bytes.Repeat([]byte{0}, int(ggufPadding(int64(offset), 32)))); err != nil {
if err != nil {
return err return err
} }
if err := binary.Write(ws, binary.LittleEndian, bytes.Repeat([]byte{0}, int(ggufPadding(offset, alignment)))); err != nil { _, err := t.WriteTo(ws)
return err return err
} }
_, err = t.WriteTo(ws) func ggufWriteKV(ws io.Writer, k string, v any) error {
return err
}
func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
slog.Debug(k, "type", fmt.Sprintf("%T", v)) slog.Debug(k, "type", fmt.Sprintf("%T", v))
if err := binary.Write(ws, binary.LittleEndian, uint64(len(k))); err != nil { if err := binary.Write(ws, binary.LittleEndian, uint64(len(k))); err != nil {
return err return err
@ -851,9 +865,44 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
return err return err
} }
} }
case *array:
if v.size > 0 {
switch v.values[0].(type) {
case string:
if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
return err
}
if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
return err
}
if err := binary.Write(ws, binary.LittleEndian, uint64(v.size)); err != nil {
return err
}
for _, e := range v.values {
if err := binary.Write(ws, binary.LittleEndian, uint64(len(e.(string)))); err != nil {
return err
}
if err := binary.Write(ws, binary.LittleEndian, []byte(e.(string))); err != nil {
return err
}
}
default: default:
err = writeGGUFArray(ws, v.datatype, v.values)
}
}
default:
fmt.Println("type is", v)
return fmt.Errorf("improper type for '%s'", k) return fmt.Errorf("improper type for '%s'", k)
} }
return err return err
} }
func ggufPadding(offset, align int64) int64 {
return (align - offset%align) % align
}

View File

@ -228,7 +228,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
var offset int64 var offset int64
for offset < stat.Size() { for offset < stat.Size() {
ggml, n, err := llm.DecodeGGML(file, 0) ggml, n, err := llm.DecodeGGML(file, -1)
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
break break
} else if err != nil { } else if err != nil {