diff --git a/kvcache/causal.go b/kvcache/causal.go index aacaf540f..fb4f0f743 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -119,10 +119,10 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity } var cacheSize int - if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch { + if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) { cacheSize = maxSequences * capacity } else { - cacheSize = maxSequences * (int(c.windowSize) + maxBatch) + cacheSize = (maxSequences * int(c.windowSize)) + maxBatch } cacheSize = roundUp(cacheSize, c.config.CachePadding) c.cells = make([]cacheCell, cacheSize)