This commit is contained in:
Josh Yan 2024-07-12 09:42:10 -07:00
parent 2fdebffc8d
commit e75fb73839
4 changed files with 27 additions and 43 deletions

View File

@ -1,7 +1,6 @@
package llm package llm
import ( import (
"cmp"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -113,26 +112,28 @@ func (kv KV) ChatTemplate() string {
return s return s
} }
// Tensors type as a slice of pointers to Tensor
type Tensors []*Tensor type Tensors []*Tensor
func (ts Tensors) Less(i, j int) bool { // Implement the Len method
var x, y int
if n, err := fmt.Sscanf(ts[i].Name, "blk.%d", &x); err != nil || n != 1 {
return cmp.Less(ts[i].Name, ts[j].Name)
} else if n, err := fmt.Sscanf(ts[j].Name, "blk.%d", &y); err != nil || n != 1 {
return cmp.Less(ts[i].Name, ts[j].Name)
}
return cmp.Less(x, y)
}
func (ts Tensors) Len() int { func (ts Tensors) Len() int {
return len(ts) return len(ts)
} }
// Implement the Swap method
func (ts Tensors) Swap(i, j int) { func (ts Tensors) Swap(i, j int) {
var temp Tensor ts[i], ts[j] = ts[j], ts[i]
}
// Implement the Less method
func (ts Tensors) Less(i, j int) bool {
var x, y int
if n, err := fmt.Sscanf(ts[i].Name, "blk.%d", &x); err != nil || n != 1 {
return ts[i].Name < ts[j].Name
} else if n, err := fmt.Sscanf(ts[j].Name, "blk.%d", &y); err != nil || n != 1 {
return ts[i].Name < ts[j].Name
}
return x < y
} }
func (ts Tensors) Layers() map[string]Layer { func (ts Tensors) Layers() map[string]Layer {

View File

@ -2,12 +2,12 @@ package llm
import ( import (
"bytes" "bytes"
"cmp"
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"slices" "slices"
"sort"
"strings" "strings"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
@ -702,8 +702,8 @@ func (gguf) padding(offset, align int64) int64 {
// Reader and WriterTo // Reader and WriterTo
type GGUFWriter struct { type GGUFWriter struct {
KV KV KV
T []*Tensor Tensors
} }
var _ io.Reader = (*GGUFWriter)(nil) var _ io.Reader = (*GGUFWriter)(nil)
@ -740,19 +740,10 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
} }
} }
slices.SortFunc(gguf.T, func(a, b *Tensor) int { sort.Sort(gguf.Tensors)
var i, j int
if n, err := fmt.Sscanf(a.Name, "blk.%d", &i); err != nil || n != 1 {
return cmp.Compare(a.Name, b.Name)
} else if n, err := fmt.Sscanf(b.Name, "blk.%d", &j); err != nil || n != 1 {
return cmp.Compare(a.Name, b.Name)
}
return cmp.Compare(i, j)
})
var s uint64 var s uint64
for _, t := range gguf.T { for _, t := range gguf.Tensors {
t.Offset = s t.Offset = s
if err := ggufWriteTensorInfo(w, t); err != nil { if err := ggufWriteTensorInfo(w, t); err != nil {
return 0, err return 0, err
@ -761,7 +752,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
} }
var alignment int64 = 32 var alignment int64 = 32
for _, t := range gguf.T { for _, t := range gguf.Tensors {
if err := ggufWriteTensor(w, t, alignment); err != nil { if err := ggufWriteTensor(w, t, alignment); err != nil {
return 0, err return 0, err
} }

View File

@ -29,6 +29,8 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
defer os.Remove(temp.Name()) defer os.Remove(temp.Name())
sha256sum := sha256.New() sha256sum := sha256.New()
if
n, err := io.Copy(io.MultiWriter(temp, sha256sum), r) n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -3,7 +3,6 @@ package server
import ( import (
"archive/zip" "archive/zip"
"bytes" "bytes"
"cmp"
"context" "context"
"errors" "errors"
"fmt" "fmt"
@ -12,7 +11,7 @@ import (
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"slices" "sort"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert" "github.com/ollama/ollama/convert"
@ -244,19 +243,10 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
} }
var reader io.Reader = io.NewSectionReader(file, offset, n) var reader io.Reader = io.NewSectionReader(file, offset, n)
if !slices.IsSortedFunc(ggml.Tensors(), func(a, b *llm.Tensor) int { if !sort.IsSorted(ggml.Tensors()) {
var i, j int
if n, err := fmt.Sscanf(a.Name, "blk.%d", &i); err != nil || n != 1 {
return cmp.Compare(a.Name, b.Name)
} else if n, err := fmt.Sscanf(b.Name, "blk.%d", &j); err != nil || n != 1 {
return cmp.Compare(a.Name, b.Name)
}
return cmp.Compare(i, j)
}) {
reader = &llm.GGUFWriter{ reader = &llm.GGUFWriter{
KV: ggml.KV(), KV: ggml.KV(),
T: ggml.Tensors(), Tensors: ggml.Tensors(),
} }
} }