ml/backend/ggml: handle tensor split
This commit is contained in:
parent
26c2e0bd35
commit
b5312f30e8
@ -93,16 +93,8 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
}
|
||||
}
|
||||
|
||||
var sum uint64
|
||||
var cumsum []uint64
|
||||
|
||||
var gpuDeviceBufferTypes []deviceBufferType
|
||||
for _, d := range gpus {
|
||||
var free, total C.size_t
|
||||
C.ggml_backend_dev_memory(d, &free, &total)
|
||||
sum += uint64(free)
|
||||
cumsum = append(cumsum, sum)
|
||||
|
||||
bt := C.ggml_backend_dev_buffer_type(d)
|
||||
gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
|
||||
d: d,
|
||||
@ -110,9 +102,33 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
})
|
||||
}
|
||||
|
||||
splits := make([]float64, len(cumsum))
|
||||
splits := make([]float32, len(gpus))
|
||||
if func() bool {
|
||||
for _, s := range params.TensorSplit {
|
||||
if s != 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}() {
|
||||
splits = params.TensorSplit
|
||||
} else {
|
||||
for i := range splits {
|
||||
var free, total C.size_t
|
||||
C.ggml_backend_dev_memory(gpus[i], &free, &total)
|
||||
splits[i] = float32(free)
|
||||
}
|
||||
}
|
||||
|
||||
var sum float32
|
||||
for i := range splits {
|
||||
splits[i] = float64(cumsum[i]) / float64(sum)
|
||||
sum += splits[i]
|
||||
splits[i] = sum
|
||||
}
|
||||
|
||||
for i := range splits {
|
||||
splits[i] /= sum
|
||||
}
|
||||
|
||||
cpuDeviceBufferTypes := deviceBufferType{C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU), cpuBufferTypes}
|
||||
@ -130,9 +146,12 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
return cpuDeviceBufferTypes
|
||||
}
|
||||
|
||||
return gpuDeviceBufferTypes[slices.IndexFunc(splits, func(f float64) bool {
|
||||
return float64(i)/float64(blocks+1) < f
|
||||
})]
|
||||
index := slices.IndexFunc(splits, func(f float32) bool { return float32(i)/float32(blocks+1) < f })
|
||||
if index < 0 || index >= len(gpuDeviceBufferTypes) {
|
||||
return cpuDeviceBufferTypes
|
||||
}
|
||||
|
||||
return gpuDeviceBufferTypes[index]
|
||||
}
|
||||
|
||||
layers := make([]deviceBufferType, blocks)
|
||||
|
Loading…
x
Reference in New Issue
Block a user