load tensors concurrently

This commit is contained in:
Michael Yang 2024-12-03 14:35:07 -08:00
parent b7943d941d
commit e699b8f5b9
2 changed files with 26 additions and 20 deletions

View File

@ -4,7 +4,7 @@ import (
"bytes"
"encoding/binary"
"fmt"
"io"
"os"
"strings"
)
@ -24,9 +24,9 @@ type Backend interface {
NewContext() Context
}
var backends = make(map[string]func(io.ReadSeeker) (Backend, error))
var backends = make(map[string]func(*os.File) (Backend, error))
func RegisterBackend(name string, f func(io.ReadSeeker) (Backend, error)) {
func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
if _, ok := backends[name]; ok {
panic("backend: backend already registered")
}
@ -34,9 +34,9 @@ func RegisterBackend(name string, f func(io.ReadSeeker) (Backend, error)) {
backends[name] = f
}
func NewBackend(r io.ReadSeeker) (Backend, error) {
func NewBackend(f *os.File) (Backend, error) {
if backend, ok := backends["ggml"]; ok {
return backend(r)
return backend(f)
}
return nil, fmt.Errorf("unsupported backend")

View File

@ -12,8 +12,11 @@ import (
"fmt"
"io"
"log/slog"
"os"
"unsafe"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
@ -28,7 +31,7 @@ type Backend struct {
ggml.Tensors
}
func New(r io.ReadSeeker) (ml.Backend, error) {
func New(r *os.File) (ml.Backend, error) {
f, _, err := ggml.Decode(r, -1)
if err != nil {
return nil, err
@ -62,22 +65,20 @@ func New(r io.ReadSeeker) (ml.Backend, error) {
b := newBackend()
bb := C.ggml_backend_alloc_ctx_tensors(c, b)
var g errgroup.Group
for _, t := range f.Tensors().Items {
if _, err := r.Seek(int64(f.Tensors().Offset+t.Offset), io.SeekStart); err != nil {
return nil, err
}
g.Go(func() error {
var b bytes.Buffer
n, err := io.Copy(&b, io.NewSectionReader(r, int64(f.Tensors().Offset+t.Offset), int64(t.Size())))
if err != nil {
return err
}
var b bytes.Buffer
n, err := io.CopyN(&b, r, int64(t.Size()))
if err != nil {
return nil, err
}
if n != int64(t.Size()) {
return fmt.Errorf("expected %d bytes, got %d", t.Size(), n)
}
if n != int64(t.Size()) {
return nil, fmt.Errorf("expected %d bytes, got %d", t.Size(), n)
}
func() {
cname := C.CString(t.Name)
defer C.free(unsafe.Pointer(cname))
@ -85,7 +86,12 @@ func New(r io.ReadSeeker) (ml.Backend, error) {
defer C.free(cbytes)
C.ggml_backend_tensor_set(C.ggml_get_tensor(c, cname), cbytes, 0, C.size_t(n))
}()
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
}
return &Backend{c, b, bb, f.KV(), f.Tensors()}, nil