From f66216e3990b73869341c58ac9561b26c468c558 Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Mon, 24 Mar 2025 13:39:07 -0700
Subject: [PATCH] ggml: Support heterogeneous KV cache layer sizes in memory
 estimation

Gemma3 uses sliding window attention for its context on 5 out of every 6
layers, significantly reducing memory usage but leading to uneven usage
across layers, which makes allocation to the correct GPU difficult. We
currently estimate very conservatively by assuming all layers are
consistent at the max size.

Llama3.2-vision is also inconsistent between self-attention and
cross-attention layers. At the moment, we calculate the correct total
size and then average it across layers. In some cases, this may lead to
crashes if a large layer is placed on a GPU sized by the average.

This allows memory estimation to calculate the KV cache size per layer
and take it into account when placing layers onto GPUs. We already do
this for weights that vary per-tensor, so this is a logical extension.

Fixes #9730
Fixes #9890
---
 fs/ggml/ggml.go    | 39 +++++++++++++++++++++++++++------------
 llm/memory.go      | 24 +++++++++++++++---------
 llm/memory_test.go |  4 ++--
 llm/server.go      |  2 +-
 server/sched.go    |  8 ++++----
 5 files changed, 49 insertions(+), 28 deletions(-)

diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go
index 0be69e82d..c88583fb8 100644
--- a/fs/ggml/ggml.go
+++ b/fs/ggml/ggml.go
@@ -413,7 +413,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
 	}, offset, nil
 }
 
-func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
+func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
 	embedding := f.KV().EmbeddingLength()
 	heads := f.KV().HeadCount()
 	headsKV := f.KV().HeadCountKV()
@@ -426,7 +426,10 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
 	layers := f.Tensors().GroupLayers()
 
 	bytesPerElement := kvCacheBytesPerElement(kvCacheType)
-	kv = uint64(float64(context*f.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+	kv = make([]uint64, f.KV().BlockCount())
+	for i := range kv {
+		kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+	}
 
 	switch f.KV().Architecture() {
 	case "llama":
@@ -460,16 +463,14 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
 	case "mllama":
 		var visionTokens, tiles uint64 = 1601, 4
 
-		if crossAttentionLayers, ok := f.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
-			kv = headsKV *
-				(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
-				(2* // sizeof(float16)
-					(f.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
-					context +
-					4* // sizeof(float32)
-						uint64(crossAttentionLayers.size)* // num cross attention layers
-						visionTokens*
-						tiles)
+		crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
+		for i := range kv {
+			if slices.Contains(crossAttentionLayers, uint32(i)) {
+				kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
+					4 * // sizeof(float32)
+					visionTokens *
+					tiles
+			}
 		}
 
 		fullOffload = max(
@@ -505,6 +506,20 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
 			4*embeddingHeadsK*context*8+
 			embedding*embeddingHeadsK*heads*9/16,
 		)
+
+		// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
+		// engine. Gemma3 always uses the Ollama engine.
+		if f.KV().Architecture() == "gemma3" {
+			const gemma3GlobalCacheCount = 6
+			slidingWindow := (uint64(numParallel) * uint64(f.KV().Uint("attention.sliding_window"))) + batch
+			for i := range kv {
+				// Every 6th layer is a global layer, which is the full context size that has already been set. The other
+				// layers are the smaller local (sliding) layers.
+				if (i+1)%gemma3GlobalCacheCount != 0 {
+					kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+				}
+			}
+		}
 	case "command-r":
 		fullOffload = max(
 			4*batch*(embedding+vocab),
diff --git a/llm/memory.go b/llm/memory.go
index 86694d06d..85a0fabd3 100644
--- a/llm/memory.go
+++ b/llm/memory.go
@@ -15,12 +15,12 @@ import (
 )
 
 // This algorithm looks for a complete fit to determine if we need to unload other models
-func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
+func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
 	// Split up the GPUs by type and try them
 	var estimatedVRAM uint64
 	for _, gpus := range allGpus.ByLibrary() {
 		var layerCount int
-		estimate := EstimateGPULayers(gpus, f, projectors, opts)
+		estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
 		layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
 		if opts.NumGPU < 0 {
 			if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
@@ -71,7 +71,7 @@ type MemoryEstimate struct {
 
 // Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
 // The GPUs provided must all be the same Library
-func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {
+func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
 	// Graph size for a partial offload, applies to all GPUs
 	var graphPartialOffload uint64
 
@@ -137,13 +137,19 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 		}
 	}
 
-	kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), kvct)
+	kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)
 
-	// KV is proportional to the number of layers
-	layerSize += kv / f.KV().BlockCount()
+	if len(kv) > 0 {
+		layerSize += kv[0]
+	}
+
+	var kvTotal uint64
+	for _, kvLayer := range kv {
+		kvTotal += kvLayer
+	}
 
 	if graphPartialOffload == 0 {
-		graphPartialOffload = f.KV().GQA() * kv / 6
+		graphPartialOffload = f.KV().GQA() * kvTotal / 6
 	}
 	if graphFullOffload == 0 {
 		graphFullOffload = graphPartialOffload
@@ -217,7 +223,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 		// Some models have inconsistent layer sizes
 		if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
 			layerSize = blk.Size()
-			layerSize += kv / f.KV().BlockCount()
+			layerSize += kv[i]
 			memoryWeights += blk.Size()
 		}
 
@@ -315,7 +321,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 		layersRequested:   opts.NumGPU,
 		layersModel:       int(f.KV().BlockCount()) + 1,
 		availableList:     availableList,
-		kv:                kv,
+		kv:                kvTotal,
 		allocationsList:   allocationsList,
 		memoryWeights:     memoryWeights,
 		memoryLayerOutput: memoryLayerOutput,
diff --git a/llm/memory_test.go b/llm/memory_test.go
index 40cc01dff..213784a02 100644
--- a/llm/memory_test.go
+++ b/llm/memory_test.go
@@ -61,7 +61,7 @@ func TestEstimateGPULayers(t *testing.T) {
 	projectors := []string{}
 	opts := api.DefaultOptions()
 	t.Run("cpu", func(t *testing.T) {
-		estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
+		estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
 		assert.Equal(t, 0, estimate.Layers)
 		assert.Equal(t, uint64(0), estimate.Graph)
 	})
@@ -112,7 +112,7 @@ func TestEstimateGPULayers(t *testing.T) {
 			gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
 			gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
 			gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
-			estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
+			estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
 			assert.Equal(t, int(s.expect0+s.expect1), estimate.Layers, "scenario %d: %v", i, s)
 			assert.Equal(t, fmt.Sprintf("%d,%d", s.expect0, s.expect1), estimate.TensorSplit, "scenario %d: %v", i, s)
 			var layerSums uint64
diff --git a/llm/server.go b/llm/server.go
index adc11aaea..e6046db60 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -109,7 +109,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 		gpus = discover.GetCPUInfo()
 	}
 
-	estimate := EstimateGPULayers(gpus, f, projectors, opts)
+	estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
 	if len(gpus) > 1 || gpus[0].Library != "cpu" {
 		switch {
 		case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
diff --git a/server/sched.go b/server/sched.go
index b4600dbf7..9126c2969 100644
--- a/server/sched.go
+++ b/server/sched.go
@@ -711,7 +711,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuIn
 		req.opts.NumCtx = req.origNumCtx * p
 		if !envconfig.SchedSpread() {
 			for _, g := range sgl {
-				if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+				if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
 					slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
 					*numParallel = p
 					return []discover.GpuInfo{g}
@@ -727,7 +727,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuIn
 	// Now try all the GPUs
 	for _, p := range numParallelToTry {
 		req.opts.NumCtx = req.origNumCtx * p
-		if ok, estimatedVRAM = llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+		if ok, estimatedVRAM = llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
 			slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM))
 			*numParallel = p
 			return sgl
@@ -750,7 +750,7 @@ func pickBestPartialFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.Gp
 	var bestEstimate uint64
 	var bestFit int
 	for i, gl := range byLibrary {
-		_, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
+		_, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, *numParallel)
 		if estimatedVRAM > bestEstimate {
 			bestEstimate = estimatedVRAM
 			bestFit = i
@@ -825,7 +825,7 @@ func (s *Scheduler) expireRunner(model *Model) {
 // If not, pick a runner to unload, else return nil and the request can be loaded
 func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList) *runnerRef {
 	slog.Debug("evaluating if CPU model load will fit in available system memory")
-	estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts)
+	estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts, req.opts.NumCtx/req.origNumCtx)
 	if estimate.TotalSize <= gpus[0].FreeMemory {
 		slog.Debug("cpu inference mode, model fits in available system memory", "model", format.HumanBytes2(estimate.TotalSize), "available", format.HumanBytes2(gpus[0].FreeMemory))
 		return nil