diff --git a/kvcache/causal.go b/kvcache/causal.go index 34d5337cf..020298005 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -21,9 +21,10 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e type Causal struct { DType ml.DType Capacity int32 - causal bool windowSize int32 + opts CausalOptions + // config controls mostly backend-specific optimizations config *ml.CacheConfig @@ -79,7 +80,6 @@ type cellRange struct { func NewCausalCache(shift shiftFn) *Causal { return &Causal{ - causal: true, windowSize: math.MaxInt32, shiftFn: shift, ctxs: make(map[int]ml.Context), @@ -90,7 +90,6 @@ func NewCausalCache(shift shiftFn) *Causal { func NewSWACache(windowSize int32, shift shiftFn) *Causal { return &Causal{ - causal: true, windowSize: windowSize, shiftFn: shift, ctxs: make(map[int]ml.Context), @@ -235,9 +234,10 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { mask := make([]float32, batchSize*length) for i := range c.curBatchSize { + enabled := !slices.Contains(c.opts.Except, c.curPositions[i]) for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) || - (c.causal && c.cells[j].pos > c.curPositions[i]) || + (enabled && c.cells[j].pos > c.curPositions[i]) || c.cells[j].pos < c.curPositions[i]-c.windowSize { mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1)) } @@ -404,15 +404,19 @@ func (c *Causal) SetLayer(layer int) { c.curLayer = layer } +type CausalOptions struct { + // Except lists the positions for which the causal mask is not applied. + Except []int32 +} + -// SetCausal enables or disables causal mask generation for subsequent calls to Get. -// This state carries over to future forward passes. The default value is true. +// SetCausal sets the options that control causal mask generation for subsequent calls to Get. +// This state carries over to future forward passes. By default no positions are excepted. // // ctx may be set to nil if this is called from outside of a forward pass, for // example, when initializing the cache. 
-func (c *Causal) SetCausal(ctx ml.Context, causal bool) { - if c.causal != causal { - c.causal = causal - +func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) { + if !slices.Equal(c.opts.Except, opts.Except) { + c.opts = opts if ctx != nil { var err error c.curMask, err = c.buildMask(ctx) diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index bf7f6b4c0..de8070d91 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -183,8 +183,12 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1)) if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok { - causal.SetCausal(ctx, false) - defer causal.SetCausal(ctx, true) + except := make([]int32, visionOutputs.Dim(1)) + for i := 0; i < visionOutputs.Dim(1); i++ { + except[i] = int32(offset + i) + } + + causal.SetCausal(ctx, kvcache.CausalOptions{Except: except}) } }