TensorWriter

This commit is contained in:
Josh Yan 2024-07-12 12:18:46 -07:00
parent 554f3bdc0e
commit 3d0fd31f0e
3 changed files with 30 additions and 5 deletions

View File

@ -476,3 +476,11 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
return
}
type TensorWriter struct {
io.Reader
}
func (tw TensorWriter) WriteTo(w io.Writer) (int64, error) {
return io.Copy(w, tw.Reader)
}

View File

@ -723,7 +723,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
return 0, err
}
if err := binary.Write(w, binary.LittleEndian, uint64(len(gguf.T))); err != nil {
if err := binary.Write(w, binary.LittleEndian, uint64(len(gguf.Tensors))); err != nil {
return 0, err
}
@ -736,7 +736,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
for _, key := range keys {
if err := ggufWriteKV(w, key, gguf.KV[key]); err != nil {
return err
return 0, err
}
}
@ -762,9 +762,13 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
}
func ggufWriteTensor(io.Writer, *Tensor, int64) error {
return nil
}
func ggufWriteTensorInfo(io.Writer, *Tensor) error {
return nil
}
func ggufWriteKV(io.Writer, string, any) error {
return nil
}

View File

@ -245,11 +245,24 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
var reader io.Reader = io.NewSectionReader(file, offset, n)
if !sort.IsSorted(ggml.Tensors()) {
// create a new Tensors containing Tensors that have a writeTo
var tensors llm.Tensors
for _, tensor := range ggml.Tensors() {
tensors = append(tensors, &llm.Tensor{
Name: tensor.Name,
Kind: tensor.Kind,
Shape: tensor.Shape,
WriterTo: &llm.TensorWriter{
Reader: io.NewSectionReader(file, int64(tensor.Offset), int64(tensor.Size())),
},
})
}
reader = &llm.GGUFWriter{
KV: ggml.KV(),
// Update .Tensors
Tensors: ggml.Tensors(),
Tensors: tensors,
}
}