ggml: Support heterogeneous KV cache layer sizes in memory estimation

Gemma3 uses sliding window attention for its context on 5 out of every 6
layers, significantly reducing memory usage but leaving KV cache usage
uneven across layers, which makes allocating layers to the correct GPU
difficult. We currently estimate very conservatively by assuming every
layer is the maximum size.
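
For a rough sense of the imbalance, here is a minimal sketch with illustrative
Gemma3-like numbers; the helper and constants below are assumptions for the
example, not values taken from this change:

package main

import "fmt"

// layerKVBytes shows the per-layer KV cache size formula used by the estimate:
// tokens * (K head dims + V head dims) * KV heads * bytes per cache element.
func layerKVBytes(tokens, embeddingHeadsK, embeddingHeadsV, headsKV uint64, bytesPerElement float64) uint64 {
    return uint64(float64(tokens*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
}

func main() {
    // Assumed head geometry and an f16 cache (2 bytes per element).
    const embeddingHeadsK, embeddingHeadsV, headsKV uint64 = 256, 256, 4
    const bytesPerElement = 2.0

    global := layerKVBytes(32768, embeddingHeadsK, embeddingHeadsV, headsKV, bytesPerElement)   // full-context layer
    local := layerKVBytes(1024+512, embeddingHeadsK, embeddingHeadsV, headsKV, bytesPerElement) // sliding window + batch

    fmt.Printf("global layer: %d MiB, sliding-window layer: %d MiB\n", global>>20, local>>20)
}

Sizing every layer at the global figure overstates the cache for the 5 smaller
layers in each group of 6.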

Llama3.2-vision is also inconsistent between self attention and cross
attention layers - at the moment, we calculate the correct total size
and then average it across layers. In some cases, this may lead
to crashes if a large layer is placed on a GPU sized by the average.
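
A tiny made-up example of why sizing by the average can fail (the numbers are
invented purely for illustration):

package main

import "fmt"

func main() {
    // Hypothetical per-layer KV sizes in MiB: mostly small self attention
    // layers plus a couple of much larger cross attention layers.
    layers := []uint64{100, 100, 100, 100, 100, 100, 100, 400, 100, 400}

    var total uint64
    for _, size := range layers {
        total += size
    }
    average := total / uint64(len(layers))

    // A GPU budgeted at the average per-layer size has no headroom for the
    // 400 MiB layers, which is how an oversized layer can fail at load time.
    fmt.Printf("average per-layer budget: %d MiB, largest layer: %d MiB\n", average, 400)
}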

This change allows memory estimation to calculate the KV cache size on a
per-layer basis and take it into account when placing layers onto GPUs.
We already do this for weights that vary per-tensor, so this is a logical
extension.
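
A condensed sketch of the resulting placement idea follows; the function below
is a simplified stand-in, not the actual EstimateGPULayers logic, and its names
are invented for the example:

package main

import "fmt"

// placeLayers assigns layers to GPUs one at a time using that layer's own
// weight and KV cache sizes rather than a uniform per-layer estimate.
func placeLayers(weights, kv, gpuFree []uint64) []int {
    perGPU := make([]int, len(gpuFree))
    g := 0
    for i := range kv {
        layerSize := weights[i] + kv[i] // this layer's weights plus its KV cache
        for g < len(gpuFree) && gpuFree[g] < layerSize {
            g++ // current GPU is full, move on to the next one
        }
        if g == len(gpuFree) {
            break // remaining layers fall back to the CPU
        }
        gpuFree[g] -= layerSize
        perGPU[g]++
    }
    return perGPU
}

func main() {
    // Mixed layer sizes, e.g. small sliding-window layers and one larger global layer.
    weights := []uint64{10, 10, 10, 10, 10, 10}
    kv := []uint64{6, 6, 6, 6, 6, 128}
    gpuFree := []uint64{80, 200}
    fmt.Println(placeLayers(weights, kv, gpuFree)) // layers placed on each GPU
}

The real code still sums the per-layer values for totals (as the diff below
does with kvTotal) but places each layer using its own size.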

Fixes #9730
Fixes #9890
Jesse Gross 2025-03-24 13:39:07 -07:00 committed by Jesse Gross
parent f4f0992b6e
commit f66216e399
5 changed files with 49 additions and 28 deletions

View File

@@ -413,7 +413,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
     }, offset, nil
 }
 
-func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
+func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
     embedding := f.KV().EmbeddingLength()
     heads := f.KV().HeadCount()
     headsKV := f.KV().HeadCountKV()
@@ -426,7 +426,10 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
     layers := f.Tensors().GroupLayers()
 
     bytesPerElement := kvCacheBytesPerElement(kvCacheType)
-    kv = uint64(float64(context*f.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+    kv = make([]uint64, f.KV().BlockCount())
+    for i := range kv {
+        kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+    }
 
     switch f.KV().Architecture() {
     case "llama":
@@ -460,16 +463,14 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
     case "mllama":
         var visionTokens, tiles uint64 = 1601, 4
 
-        if crossAttentionLayers, ok := f.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
-            kv = headsKV *
-                (embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
-                (2* // sizeof(float16)
-                    (f.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
-                    context +
-                    4* // sizeof(float32)
-                        uint64(crossAttentionLayers.size)* // num cross attention layers
-                        visionTokens*
-                        tiles)
+        crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
+        for i := range kv {
+            if slices.Contains(crossAttentionLayers, uint32(i)) {
+                kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
+                    4 * // sizeof(float32)
+                    visionTokens *
+                    tiles
+            }
         }
 
         fullOffload = max(
@@ -505,6 +506,20 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
                 4*embeddingHeadsK*context*8+
                 embedding*embeddingHeadsK*heads*9/16,
         )
+
+        // Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
+        // engine. Gemma3 always uses the Ollama engine.
+        if f.KV().Architecture() == "gemma3" {
+            const gemma3GlobalCacheCount = 6
+            slidingWindow := (uint64(numParallel) * uint64(f.KV().Uint("attention.sliding_window"))) + batch
+            for i := range kv {
+                // Every 6th layer is a global layer, which is the full context size that has already been set. The other
+                // layers are the smaller local (sliding) layers.
+                if (i+1)%gemma3GlobalCacheCount != 0 {
+                    kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+                }
+            }
+        }
     case "command-r":
         fullOffload = max(
             4*batch*(embedding+vocab),

View File

@@ -15,12 +15,12 @@ import (
 )
 
 // This algorithm looks for a complete fit to determine if we need to unload other models
-func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
+func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
     // Split up the GPUs by type and try them
     var estimatedVRAM uint64
     for _, gpus := range allGpus.ByLibrary() {
         var layerCount int
-        estimate := EstimateGPULayers(gpus, f, projectors, opts)
+        estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
         layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
         if opts.NumGPU < 0 {
             if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
@@ -71,7 +71,7 @@ type MemoryEstimate struct {
 
 // Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
 // The GPUs provided must all be the same Library
-func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {
+func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
     // Graph size for a partial offload, applies to all GPUs
     var graphPartialOffload uint64
 
@@ -137,13 +137,19 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
         }
     }
 
-    kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), kvct)
+    kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)
 
-    // KV is proportional to the number of layers
-    layerSize += kv / f.KV().BlockCount()
+    if len(kv) > 0 {
+        layerSize += kv[0]
+    }
+
+    var kvTotal uint64
+    for _, kvLayer := range kv {
+        kvTotal += kvLayer
+    }
 
     if graphPartialOffload == 0 {
-        graphPartialOffload = f.KV().GQA() * kv / 6
+        graphPartialOffload = f.KV().GQA() * kvTotal / 6
     }
     if graphFullOffload == 0 {
         graphFullOffload = graphPartialOffload
@@ -217,7 +223,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 
         // Some models have inconsistent layer sizes
         if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
             layerSize = blk.Size()
-            layerSize += kv / f.KV().BlockCount()
+            layerSize += kv[i]
             memoryWeights += blk.Size()
         }
@@ -315,7 +321,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
         layersRequested:   opts.NumGPU,
         layersModel:       int(f.KV().BlockCount()) + 1,
         availableList:     availableList,
-        kv:                kv,
+        kv:                kvTotal,
         allocationsList:   allocationsList,
         memoryWeights:     memoryWeights,
         memoryLayerOutput: memoryLayerOutput,

View File

@@ -61,7 +61,7 @@ func TestEstimateGPULayers(t *testing.T) {
     projectors := []string{}
     opts := api.DefaultOptions()
     t.Run("cpu", func(t *testing.T) {
-        estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
+        estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
         assert.Equal(t, 0, estimate.Layers)
         assert.Equal(t, uint64(0), estimate.Graph)
     })
@@ -112,7 +112,7 @@ func TestEstimateGPULayers(t *testing.T) {
         gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
         gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
         gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
-        estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
+        estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
         assert.Equal(t, int(s.expect0+s.expect1), estimate.Layers, "scenario %d: %v", i, s)
         assert.Equal(t, fmt.Sprintf("%d,%d", s.expect0, s.expect1), estimate.TensorSplit, "scenario %d: %v", i, s)
         var layerSums uint64

View File

@@ -109,7 +109,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
         gpus = discover.GetCPUInfo()
     }
 
-    estimate := EstimateGPULayers(gpus, f, projectors, opts)
+    estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
     if len(gpus) > 1 || gpus[0].Library != "cpu" {
         switch {
         case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:

View File

@@ -711,7 +711,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuIn
         req.opts.NumCtx = req.origNumCtx * p
         if !envconfig.SchedSpread() {
             for _, g := range sgl {
-                if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+                if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
                     slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
                     *numParallel = p
                     return []discover.GpuInfo{g}
@@ -727,7 +727,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuIn
     // Now try all the GPUs
     for _, p := range numParallelToTry {
         req.opts.NumCtx = req.origNumCtx * p
-        if ok, estimatedVRAM = llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+        if ok, estimatedVRAM = llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
             slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM))
             *numParallel = p
             return sgl
@@ -750,7 +750,7 @@ func pickBestPartialFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.Gp
     var bestEstimate uint64
     var bestFit int
     for i, gl := range byLibrary {
-        _, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
+        _, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, *numParallel)
         if estimatedVRAM > bestEstimate {
             bestEstimate = estimatedVRAM
             bestFit = i
@@ -825,7 +825,7 @@ func (s *Scheduler) expireRunner(model *Model) {
 // If not, pick a runner to unload, else return nil and the request can be loaded
 func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList) *runnerRef {
     slog.Debug("evaluating if CPU model load will fit in available system memory")
-    estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts)
+    estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts, req.opts.NumCtx/req.origNumCtx)
     if estimate.TotalSize <= gpus[0].FreeMemory {
         slog.Debug("cpu inference mode, model fits in available system memory", "model", format.HumanBytes2(estimate.TotalSize), "available", format.HumanBytes2(gpus[0].FreeMemory))
         return nil