From f0ad49ea17d587cce7f4b2c6a6ccb3139ec083c8 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Wed, 23 Apr 2025 16:20:40 -0700
Subject: [PATCH] memory

---
 fs/ggml/ggml.go | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go
index 947295e36..427a43aec 100644
--- a/fs/ggml/ggml.go
+++ b/fs/ggml/ggml.go
@@ -430,7 +430,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 	}
 
 	switch f.KV().Architecture() {
-	case "llama":
+	case "llama", "llama4":
 		fullOffload = max(
 			4*batch*(1+4*embedding+context*(1+heads)),
 			4*batch*(embedding+vocab),
@@ -444,7 +444,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 
 		if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
 			// mixtral 8x22b
-			ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
+			ff := uint64(f.KV().Uint("feed_forward_length"))
 			partialOffload = max(
 				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
 				4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
@@ -640,6 +640,9 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
 		graphSize = 4 * (imageSize*imageSize*numChannels +
 			embeddingLength*patchSize +
 			numPatches*numPatches*headCount)
+	case "llama4":
+		// vision graph is computed independently in the same schedule
+		// and is negligible compared to the worst case text graph
 	}
 
 	return weights, graphSize