diff --git a/kvcache/causal.go b/kvcache/causal.go index 79fa24e87..e5216d588 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -321,7 +321,8 @@ func (c *Causal) defrag() { ctx := c.backend.NewContext() // For every move, 6 tensors are required per layer (2 views and a - // copy for each of k and v). + // copy for each of k and v). We also need to refer to the original + // k and v cache tensors - once per layer, not per move. layers := 0 for _, key := range c.keys { if key == nil { @@ -330,7 +331,7 @@ func (c *Causal) defrag() { layers++ } - maxMoves := ctx.MaxGraphNodes() / (6 * layers) + maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers) moves := 0 var pendingSrc, pendingDst, pendingLen int