diff --git a/llm/ggml.go b/llm/ggml.go index c714011c7..1fdd3c071 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -476,3 +476,11 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui return } + +type TensorWriter struct { + io.Reader +} + +func (tw TensorWriter) WriteTo(w io.Writer) (int64, error) { + return io.Copy(w, tw.Reader) +} diff --git a/llm/gguf.go b/llm/gguf.go index 745826942..ccbaa1c2a 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -723,7 +723,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) { return 0, err } - if err := binary.Write(w, binary.LittleEndian, uint64(len(gguf.T))); err != nil { + if err := binary.Write(w, binary.LittleEndian, uint64(len(gguf.Tensors))); err != nil { return 0, err } @@ -736,7 +736,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) { for _, key := range keys { if err := ggufWriteKV(w, key, gguf.KV[key]); err != nil { - return err + return 0, err } } @@ -762,9 +762,13 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) { } func ggufWriteTensor(io.Writer, *Tensor, int64) error { - + return nil } func ggufWriteTensorInfo(io.Writer, *Tensor) error { - + return nil +} + +func ggufWriteKV(io.Writer, string, any) error { + return nil } diff --git a/server/model.go b/server/model.go index a4c1bfab2..b83166cda 100644 --- a/server/model.go +++ b/server/model.go @@ -245,11 +245,24 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap var reader io.Reader = io.NewSectionReader(file, offset, n) if !sort.IsSorted(ggml.Tensors()) { // create a new Tensors containing Tensors that have a writeTo + var tensors llm.Tensors + + for _, tensor := range ggml.Tensors() { + tensors = append(tensors, &llm.Tensor{ + Name: tensor.Name, + Kind: tensor.Kind, + Shape: tensor.Shape, + + WriterTo: &llm.TensorWriter{ + Reader: io.NewSectionReader(file, int64(tensor.Offset), int64(tensor.Size())), + }, + }) + } reader = &llm.GGUFWriter{ KV: ggml.KV(), // Update .Tensors - Tensors: ggml.Tensors(), + Tensors: tensors, } }