model: load non-repeated tensors into multiple backends

Some tensors are expected to be used in repeating layers but are not
themselves repeated. This change copies these tensors into the same
backends as their repeating counterparts to minimize copying of tensors
between backends.
This commit is contained in:
Michael Yang 2025-02-24 15:48:42 -08:00
parent bab6f34dc0
commit bfce55db3d
2 changed files with 58 additions and 42 deletions

View File

@ -25,11 +25,13 @@ import (
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
fs "github.com/ollama/ollama/fs/ggml" fs "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
func devices() iter.Seq[*C.struct_ggml_backend_device] { func devices() iter.Seq[*C.struct_ggml_backend_device] {
return func(yield func(*C.struct_ggml_backend_device) bool) { return func(yield func(*C.struct_ggml_backend_device) bool) {
ggml.OnceLoad()
for i := range C.ggml_backend_dev_count() { for i := range C.ggml_backend_dev_count() {
if !yield(C.ggml_backend_dev_get(i)) { if !yield(C.ggml_backend_dev_get(i)) {
return return
@ -146,8 +148,15 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
slog.Info("max tensors", "max_tensors", maxTensors) slog.Info("max tensors", "max_tensors", maxTensors)
type tensor struct {
source *fs.Tensor
target string
}
targets := make(map[string][]string)
ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context) ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
createTensor := func(t *fs.Tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor { createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
for _, bt := range bts { for _, bt := range bts {
if _, ok := ctxs[bt]; !ok { if _, ok := ctxs[bt]; !ok {
ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{ ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
@ -156,16 +165,23 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
}) })
} }
cname := C.CString(t.Name) targets[t.source.Name] = append(targets[t.source.Name], t.target)
name := t.source.Name
if t.target != "" {
name = t.target
}
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname)) defer C.free(unsafe.Pointer(cname))
if tt := C.ggml_get_tensor(ctxs[bt], cname); tt != nil { if tt := C.ggml_get_tensor(ctxs[bt], cname); tt != nil {
return tt return tt
} }
tt := C.ggml_new_tensor(ctxs[bt], t.Kind, C.int(len(t.Shape)), (*C.int64_t)(unsafe.Pointer(&t.Shape[0]))) tt := C.ggml_new_tensor(ctxs[bt], t.source.Kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
C.ggml_set_name(tt, cname) C.ggml_set_name(tt, cname)
slog.Debug("created tensor", "name", t.Name, "shape", t.Shape, "dtype", t.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt))) slog.Debug("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
//nolint:staticcheck // TODO: check if buffer type supports this tensor //nolint:staticcheck // TODO: check if buffer type supports this tensor
return tt return tt
} }
@ -187,9 +203,9 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
for _, t := range meta.Tensors().Items() { for _, t := range meta.Tensors().Items() {
switch { switch {
case hasPart(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"): case hasPart(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
createTensor(t, input.bts) createTensor(tensor{source: t}, input.bts)
case hasPart(t.Name, "cls", "output", "output_norm"): case hasPart(t.Name, "cls", "output", "output_norm"):
createTensor(t, output.bts) createTensor(tensor{source: t}, output.bts)
default: default:
if i := func() int { if i := func() int {
if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 { if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
@ -200,10 +216,13 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
return -1 return -1
}(); i >= 0 { }(); i >= 0 {
createTensor(t, layers[i].bts) createTensor(tensor{source: t}, layers[i].bts)
} else { } else {
for _, layer := range layers { for i, layer := range layers {
createTensor(t, layer.bts) createTensor(tensor{
source: t,
target: "blk." + strconv.Itoa(i) + "." + t.Name,
}, layer.bts)
} }
} }
} }
@ -237,8 +256,13 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset)) sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
var g errgroup.Group var g errgroup.Group
for _, t := range meta.Tensors().Items() { for _, t := range meta.Tensors().Items() {
for _, target := range targets[t.Name] {
g.Go(func() error { g.Go(func() error {
tt, ok := tensors[t.Name] if target == "" {
target = t.Name
}
tt, ok := tensors[target]
if !ok { if !ok {
return fmt.Errorf("unassigned tensor: %s", t.Name) return fmt.Errorf("unassigned tensor: %s", t.Name)
} }
@ -260,6 +284,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
return nil return nil
}) })
} }
}
if g.Wait() != nil { if g.Wait() != nil {
return nil, err return nil, err

View File

@ -207,7 +207,13 @@ struct ggml_backend_registry {
for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) { for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
register_device(ggml_backend_reg_dev_get(reg, i), score); register_device(ggml_backend_reg_dev_get(reg, i), score);
} }
}
void register_device(ggml_backend_dev_t device, int score = -1) {
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
#endif
devices.push_back({device, score});
std::stable_sort(devices.begin(), devices.end(), std::stable_sort(devices.begin(), devices.end(),
[](const auto & a, const auto & b) { [](const auto & a, const auto & b) {
return a.second > b.second; return a.second > b.second;
@ -215,21 +221,6 @@ struct ggml_backend_registry {
); );
} }
void register_device(ggml_backend_dev_t device, int score = -1) {
switch (ggml_backend_dev_type(device)) {
case GGML_BACKEND_DEVICE_TYPE_CPU:
case GGML_BACKEND_DEVICE_TYPE_GPU:
score += 1 << 16;
case GGML_BACKEND_DEVICE_TYPE_ACCEL:
score += 1 << 20;
}
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
#endif
devices.push_back({device, score});
}
ggml_backend_reg_t load_backend(const std::filesystem::path & path, bool silent) { ggml_backend_reg_t load_backend(const std::filesystem::path & path, bool silent) {
dl_handle_ptr handle { dl_load_library(path) }; dl_handle_ptr handle { dl_load_library(path) };
if (!handle) { if (!handle) {