ml/backend/ggml: offload vision to cpu

temporary until tensor loading can accurately account for vision models
This commit is contained in:
Michael Yang 2025-02-27 16:46:01 -08:00
parent b5312f30e8
commit 2dc60d4620

View File

@ -134,13 +134,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
cpuDeviceBufferTypes := deviceBufferType{C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU), cpuBufferTypes} cpuDeviceBufferTypes := deviceBufferType{C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU), cpuBufferTypes}
input := cpuDeviceBufferTypes input := cpuDeviceBufferTypes
var blocks int blocks := int(meta.KV().BlockCount())
for key, value := range meta.KV() {
if strings.HasSuffix(key, ".block_count") {
blocks += int(value.(uint32))
}
}
assignLayer := func(i int) (temp deviceBufferType) { assignLayer := func(i int) (temp deviceBufferType) {
if i >= params.NumGPULayers { if i >= params.NumGPULayers {
return cpuDeviceBufferTypes return cpuDeviceBufferTypes
@ -206,7 +200,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
return nil return nil
} }
hasPart := func(s string, parts ...string) bool { contains := func(s string, parts ...string) bool {
split := strings.Split(s, ".") split := strings.Split(s, ".")
for _, part := range parts { for _, part := range parts {
if slices.Contains(split, part) { if slices.Contains(split, part) {
@ -219,10 +213,12 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
for _, t := range meta.Tensors().Items() { for _, t := range meta.Tensors().Items() {
switch { switch {
case hasPart(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"): case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
createTensor(tensor{source: t}, input.bts) createTensor(tensor{source: t}, input.bts)
case hasPart(t.Name, "cls", "output", "output_norm"): case contains(t.Name, "cls", "output", "output_norm"):
createTensor(tensor{source: t}, output.bts) createTensor(tensor{source: t}, output.bts)
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
createTensor(tensor{source: t}, input.bts)
default: default:
if i := func() int { if i := func() int {
if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 { if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {