diff --git a/ml/backend.go b/ml/backend.go index b9efad8c2..51d93f628 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/binary" "fmt" - "io" + "os" "strings" ) @@ -24,9 +24,9 @@ type Backend interface { NewContext() Context } -var backends = make(map[string]func(io.ReadSeeker) (Backend, error)) +var backends = make(map[string]func(*os.File) (Backend, error)) -func RegisterBackend(name string, f func(io.ReadSeeker) (Backend, error)) { +func RegisterBackend(name string, f func(*os.File) (Backend, error)) { if _, ok := backends[name]; ok { panic("backend: backend already registered") } @@ -34,9 +34,9 @@ func RegisterBackend(name string, f func(io.ReadSeeker) (Backend, error)) { backends[name] = f } -func NewBackend(r io.ReadSeeker) (Backend, error) { +func NewBackend(f *os.File) (Backend, error) { if backend, ok := backends["ggml"]; ok { - return backend(r) + return backend(f) } return nil, fmt.Errorf("unsupported backend") diff --git a/ml/backend/ggml/backend.go b/ml/backend/ggml/backend.go index 50a01e818..982096932 100644 --- a/ml/backend/ggml/backend.go +++ b/ml/backend/ggml/backend.go @@ -12,8 +12,11 @@ import ( "fmt" "io" "log/slog" + "os" "unsafe" + "golang.org/x/sync/errgroup" + "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/ml" @@ -28,7 +31,7 @@ type Backend struct { ggml.Tensors } -func New(r io.ReadSeeker) (ml.Backend, error) { +func New(r *os.File) (ml.Backend, error) { f, _, err := ggml.Decode(r, -1) if err != nil { return nil, err @@ -62,22 +65,20 @@ func New(r io.ReadSeeker) (ml.Backend, error) { b := newBackend() bb := C.ggml_backend_alloc_ctx_tensors(c, b) + + var g errgroup.Group for _, t := range f.Tensors().Items { - if _, err := r.Seek(int64(f.Tensors().Offset+t.Offset), io.SeekStart); err != nil { - return nil, err - } + g.Go(func() error { + var b bytes.Buffer + n, err := io.Copy(&b, io.NewSectionReader(r, int64(f.Tensors().Offset+t.Offset), int64(t.Size()))) + if err != nil { + return err + } - var b bytes.Buffer - n, err := io.CopyN(&b, r, int64(t.Size())) - if err != nil { - return nil, err - } + if n != int64(t.Size()) { + return fmt.Errorf("expected %d bytes, got %d", t.Size(), n) + } - if n != int64(t.Size()) { - return nil, fmt.Errorf("expected %d bytes, got %d", t.Size(), n) - } - - func() { cname := C.CString(t.Name) defer C.free(unsafe.Pointer(cname)) @@ -85,7 +86,12 @@ func New(r io.ReadSeeker) (ml.Backend, error) { defer C.free(cbytes) C.ggml_backend_tensor_set(C.ggml_get_tensor(c, cname), cbytes, 0, C.size_t(n)) - }() + return nil + }) + } + + if err := g.Wait(); err != nil { + return nil, err } return &Backend{c, b, bb, f.KV(), f.Tensors()}, nil