This commit is contained in:
Josh Yan 2024-07-12 09:42:10 -07:00
parent 2fdebffc8d
commit e75fb73839
4 changed files with 27 additions and 43 deletions

View File

@ -1,7 +1,6 @@
package llm
import (
"cmp"
"encoding/binary"
"errors"
"fmt"
@ -113,26 +112,28 @@ func (kv KV) ChatTemplate() string {
return s
}
// Tensors is a slice of pointers to Tensor. Its Len, Swap and Less methods
// allow it to be handed to the sort package (sort.Sort, sort.IsSorted).
type Tensors []*Tensor
// Less reports whether the tensor at index i should sort before the tensor
// at index j.
//
// When both names start with "blk.<number>", the tensors are ordered by that
// number. If either name does not match the pattern, the two names are
// compared as plain strings instead.
func (ts Tensors) Less(i, j int) bool {
	left, right := ts[i].Name, ts[j].Name

	var leftBlock, rightBlock int
	if n, err := fmt.Sscanf(left, "blk.%d", &leftBlock); err != nil || n != 1 {
		return cmp.Less(left, right)
	}
	// The right-hand name is only parsed once the left-hand one has matched.
	if n, err := fmt.Sscanf(right, "blk.%d", &rightBlock); err != nil || n != 1 {
		return cmp.Less(left, right)
	}
	return cmp.Less(leftBlock, rightBlock)
}
// Len returns the number of tensors in the slice.
func (ts Tensors) Len() int { return len(ts) }
// Swap exchanges the tensors stored at indices i and j.
//
// Only the pointers held in the slice are exchanged; the Tensor values they
// point to are not copied or modified.
func (ts Tensors) Swap(i, j int) {
	ts[i], ts[j] = ts[j], ts[i]
}
// Less reports whether the tensor at index i should sort before the tensor
// at index j.
//
// When both names start with "blk.<number>", the tensors are ordered by that
// number. If either name does not match the pattern, the two names are
// compared as plain strings instead.
func (ts Tensors) Less(i, j int) bool {
	left, right := ts[i].Name, ts[j].Name

	var leftBlock, rightBlock int
	if n, err := fmt.Sscanf(left, "blk.%d", &leftBlock); err != nil || n != 1 {
		return left < right
	}
	// The right-hand name is only parsed once the left-hand one has matched.
	if n, err := fmt.Sscanf(right, "blk.%d", &rightBlock); err != nil || n != 1 {
		return left < right
	}
	return leftBlock < rightBlock
}
func (ts Tensors) Layers() map[string]Layer {

View File

@ -2,12 +2,12 @@ package llm
import (
"bytes"
"cmp"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"slices"
"sort"
"strings"
"golang.org/x/exp/maps"
@ -702,8 +702,8 @@ func (gguf) padding(offset, align int64) int64 {
// GGUFWriter bundles a KV value with a list of tensors. Its WriteTo method
// sorts the tensors and writes them to an io.Writer, and the type is also
// asserted to satisfy io.Reader.
//
// NOTE(review): this listing contains both the named fields (KV, T) and the
// embedded fields (KV, Tensors), which looks like the before/after of a
// refactor shown together. Go rejects two fields named KV in one struct, so
// confirm which pair is the current one.
type GGUFWriter struct {
KV KV
T []*Tensor
KV
Tensors
}
var _ io.Reader = (*GGUFWriter)(nil)
@ -740,19 +740,10 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
}
}
slices.SortFunc(gguf.T, func(a, b *Tensor) int {
var i, j int
if n, err := fmt.Sscanf(a.Name, "blk.%d", &i); err != nil || n != 1 {
return cmp.Compare(a.Name, b.Name)
} else if n, err := fmt.Sscanf(b.Name, "blk.%d", &j); err != nil || n != 1 {
return cmp.Compare(a.Name, b.Name)
}
return cmp.Compare(i, j)
})
sort.Sort(gguf.Tensors)
var s uint64
for _, t := range gguf.T {
for _, t := range gguf.Tensors {
t.Offset = s
if err := ggufWriteTensorInfo(w, t); err != nil {
return 0, err
@ -761,7 +752,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
}
var alignment int64 = 32
for _, t := range gguf.T {
for _, t := range gguf.Tensors {
if err := ggufWriteTensor(w, t, alignment); err != nil {
return 0, err
}

View File

@ -29,6 +29,8 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
defer os.Remove(temp.Name())
sha256sum := sha256.New()
if
n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
if err != nil {
return nil, err

View File

@ -3,7 +3,6 @@ package server
import (
"archive/zip"
"bytes"
"cmp"
"context"
"errors"
"fmt"
@ -12,7 +11,7 @@ import (
"net/http"
"os"
"path/filepath"
"slices"
"sort"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
@ -244,19 +243,10 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
}
var reader io.Reader = io.NewSectionReader(file, offset, n)
if !slices.IsSortedFunc(ggml.Tensors(), func(a, b *llm.Tensor) int {
var i, j int
if n, err := fmt.Sscanf(a.Name, "blk.%d", &i); err != nil || n != 1 {
return cmp.Compare(a.Name, b.Name)
} else if n, err := fmt.Sscanf(b.Name, "blk.%d", &j); err != nil || n != 1 {
return cmp.Compare(a.Name, b.Name)
}
return cmp.Compare(i, j)
}) {
if !sort.IsSorted(ggml.Tensors()) {
reader = &llm.GGUFWriter{
KV: ggml.KV(),
T: ggml.Tensors(),
KV: ggml.KV(),
Tensors: ggml.Tensors(),
}
}