concurrent load tensors
This commit is contained in:
parent
b7943d941d
commit
e699b8f5b9
@ -4,7 +4,7 @@ import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@ -24,9 +24,9 @@ type Backend interface {
|
||||
NewContext() Context
|
||||
}
|
||||
|
||||
var backends = make(map[string]func(io.ReadSeeker) (Backend, error))
|
||||
var backends = make(map[string]func(*os.File) (Backend, error))
|
||||
|
||||
func RegisterBackend(name string, f func(io.ReadSeeker) (Backend, error)) {
|
||||
func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
|
||||
if _, ok := backends[name]; ok {
|
||||
panic("backend: backend already registered")
|
||||
}
|
||||
@ -34,9 +34,9 @@ func RegisterBackend(name string, f func(io.ReadSeeker) (Backend, error)) {
|
||||
backends[name] = f
|
||||
}
|
||||
|
||||
func NewBackend(r io.ReadSeeker) (Backend, error) {
|
||||
func NewBackend(f *os.File) (Backend, error) {
|
||||
if backend, ok := backends["ggml"]; ok {
|
||||
return backend(r)
|
||||
return backend(f)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported backend")
|
||||
|
@ -12,8 +12,11 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/ml"
|
||||
@ -28,7 +31,7 @@ type Backend struct {
|
||||
ggml.Tensors
|
||||
}
|
||||
|
||||
func New(r io.ReadSeeker) (ml.Backend, error) {
|
||||
func New(r *os.File) (ml.Backend, error) {
|
||||
f, _, err := ggml.Decode(r, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -62,22 +65,20 @@ func New(r io.ReadSeeker) (ml.Backend, error) {
|
||||
|
||||
b := newBackend()
|
||||
bb := C.ggml_backend_alloc_ctx_tensors(c, b)
|
||||
|
||||
var g errgroup.Group
|
||||
for _, t := range f.Tensors().Items {
|
||||
if _, err := r.Seek(int64(f.Tensors().Offset+t.Offset), io.SeekStart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
g.Go(func() error {
|
||||
var b bytes.Buffer
|
||||
n, err := io.Copy(&b, io.NewSectionReader(r, int64(f.Tensors().Offset+t.Offset), int64(t.Size())))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
n, err := io.CopyN(&b, r, int64(t.Size()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n != int64(t.Size()) {
|
||||
return fmt.Errorf("expected %d bytes, got %d", t.Size(), n)
|
||||
}
|
||||
|
||||
if n != int64(t.Size()) {
|
||||
return nil, fmt.Errorf("expected %d bytes, got %d", t.Size(), n)
|
||||
}
|
||||
|
||||
func() {
|
||||
cname := C.CString(t.Name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
|
||||
@ -85,7 +86,12 @@ func New(r io.ReadSeeker) (ml.Backend, error) {
|
||||
defer C.free(cbytes)
|
||||
|
||||
C.ggml_backend_tensor_set(C.ggml_get_tensor(c, cname), cbytes, 0, C.size_t(n))
|
||||
}()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Backend{c, b, bb, f.KV(), f.Tensors()}, nil
|
||||
|
Loading…
x
Reference in New Issue
Block a user