From e75fb73839c4fd3b7abbde1b6c9a274898c9277b Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Fri, 12 Jul 2024 09:42:10 -0700 Subject: [PATCH] types --- llm/ggml.go | 29 +++++++++++++++-------------- llm/gguf.go | 21 ++++++--------------- server/layer.go | 2 ++ server/model.go | 18 ++++-------------- 4 files changed, 27 insertions(+), 43 deletions(-) diff --git a/llm/ggml.go b/llm/ggml.go index 126139641..c714011c7 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -1,7 +1,6 @@ package llm import ( - "cmp" "encoding/binary" "errors" "fmt" @@ -113,26 +112,28 @@ func (kv KV) ChatTemplate() string { return s } +// Tensors type as a slice of pointers to Tensor type Tensors []*Tensor -func (ts Tensors) Less(i, j int) bool { - var x, y int - if n, err := fmt.Sscanf(ts[i].Name, "blk.%d", &x); err != nil || n != 1 { - return cmp.Less(ts[i].Name, ts[j].Name) - } else if n, err := fmt.Sscanf(ts[j].Name, "blk.%d", &y); err != nil || n != 1 { - return cmp.Less(ts[i].Name, ts[j].Name) - } - - return cmp.Less(x, y) -} - +// Implement the Len method func (ts Tensors) Len() int { return len(ts) } +// Implement the Swap method func (ts Tensors) Swap(i, j int) { - var temp Tensor - + ts[i], ts[j] = ts[j], ts[i] +} + +// Implement the Less method +func (ts Tensors) Less(i, j int) bool { + var x, y int + if n, err := fmt.Sscanf(ts[i].Name, "blk.%d", &x); err != nil || n != 1 { + return ts[i].Name < ts[j].Name + } else if n, err := fmt.Sscanf(ts[j].Name, "blk.%d", &y); err != nil || n != 1 { + return ts[i].Name < ts[j].Name + } + return x < y } func (ts Tensors) Layers() map[string]Layer { diff --git a/llm/gguf.go b/llm/gguf.go index 4a3f23e51..745826942 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -2,12 +2,12 @@ package llm import ( "bytes" - "cmp" "encoding/binary" "encoding/json" "fmt" "io" "slices" + "sort" "strings" "golang.org/x/exp/maps" @@ -702,8 +702,8 @@ func (gguf) padding(offset, align int64) int64 { // Reader and WriterTo type GGUFWriter struct { - KV KV - T []*Tensor + KV + Tensors } var _ io.Reader = (*GGUFWriter)(nil) @@ -740,19 +740,10 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) { } } - slices.SortFunc(gguf.T, func(a, b *Tensor) int { - var i, j int - if n, err := fmt.Sscanf(a.Name, "blk.%d", &i); err != nil || n != 1 { - return cmp.Compare(a.Name, b.Name) - } else if n, err := fmt.Sscanf(b.Name, "blk.%d", &j); err != nil || n != 1 { - return cmp.Compare(a.Name, b.Name) - } - - return cmp.Compare(i, j) - }) + sort.Sort(gguf.Tensors) var s uint64 - for _, t := range gguf.T { + for _, t := range gguf.Tensors { t.Offset = s if err := ggufWriteTensorInfo(w, t); err != nil { return 0, err @@ -761,7 +752,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) { } var alignment int64 = 32 - for _, t := range gguf.T { + for _, t := range gguf.Tensors { if err := ggufWriteTensor(w, t, alignment); err != nil { return 0, err } diff --git a/server/layer.go b/server/layer.go index cc6709d24..d4c56ee6d 100644 --- a/server/layer.go +++ b/server/layer.go @@ -29,6 +29,8 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) { defer os.Remove(temp.Name()) sha256sum := sha256.New() + if + n, err := io.Copy(io.MultiWriter(temp, sha256sum), r) if err != nil { return nil, err diff --git a/server/model.go b/server/model.go index c200b909b..ec7e94449 100644 --- a/server/model.go +++ b/server/model.go @@ -3,7 +3,6 @@ package server import ( "archive/zip" "bytes" - "cmp" "context" "errors" "fmt" @@ -12,7 +11,7 @@ import ( "net/http" "os" "path/filepath" - "slices" + "sort" "github.com/ollama/ollama/api" "github.com/ollama/ollama/convert" @@ -244,19 +243,10 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap } var reader io.Reader = io.NewSectionReader(file, offset, n) - if !slices.IsSortedFunc(ggml.Tensors(), func(a, b *llm.Tensor) int { - var i, j int - if n, err := fmt.Sscanf(a.Name, "blk.%d", &i); err != nil || n != 1 { - return cmp.Compare(a.Name, b.Name) - } else if n, err := fmt.Sscanf(b.Name, "blk.%d", &j); err != nil || n != 1 { - return cmp.Compare(a.Name, b.Name) - } - - return cmp.Compare(i, j) - }) { + if !sort.IsSorted(ggml.Tensors()) { reader = &llm.GGUFWriter{ - KV: ggml.KV(), - T: ggml.Tensors(), + KV: ggml.KV(), + Tensors: ggml.Tensors(), } }