From 57279f89a2b03ece0fbb1de751ea3dbc4274cf84 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Fri, 9 May 2025 13:58:47 -0700
Subject: [PATCH] calculate block mask once, rather than in attention

Construct the block-diagonal attention mask at the call site in
VisionModel.Forward and thread the mask tensor down through
VisionEncoderLayer.Forward and VisionSelfAttention.Forward, instead of
passing the raw bounds slice down and rebuilding the mask inside the
attention forward pass.
---
 model/models/qwen25vl/model_vision.go | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go
index 1574c0bc7..485644963 100644
--- a/model/models/qwen25vl/model_vision.go
+++ b/model/models/qwen25vl/model_vision.go
@@ -61,7 +61,7 @@ type VisionSelfAttention struct {
 	Output *nn.Linear `gguf:"attn_out"`
 }
 
-func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, bounds []int, opts *VisionModelOptions) ml.Tensor {
+func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	query := sa.Query.Forward(ctx, hiddenStates)
 	key := sa.Key.Forward(ctx, hiddenStates)
 	value := sa.Value.Forward(ctx, hiddenStates)
@@ -75,10 +75,6 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
 
 	// Scale factor for scaled dot-product attention
 	scale := 1.0 / math.Sqrt(float64(opts.headDim))
-	var mask ml.Tensor
-	if bounds != nil {
-		mask = blockDiagonalMask(ctx, query.Dim(2), bounds, opts.numHeads)
-	}
 
 	// Scaled dot-product attention
 	query = query.Permute(ctx, 0, 2, 1, 3)
@@ -120,10 +116,10 @@ type VisionEncoderLayer struct {
 	MLP *VisionMLP
 }
 
-func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, cuSeqLens []int, opts *VisionModelOptions) ml.Tensor {
+func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	residual := hiddenStates
 	hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
-	hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, cuSeqLens, opts)
+	hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, mask, opts)
 	hiddenStates = hiddenStates.Add(ctx, residual)
 
 	residual = hiddenStates
@@ -241,7 +237,14 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
 		if slices.Contains(m.fullAttnBlocks, i) {
 			hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, nil, m.VisionModelOptions)
 		} else {
-			hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, bounds, m.VisionModelOptions)
+			hiddenStates = layer.Forward(
+				ctx,
+				hiddenStates,
+				cos,
+				sin,
+				blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads),
+				m.VisionModelOptions,
+			)
 		}
 	}
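
Note for reviewers: blockDiagonalMask is defined elsewhere in
model_vision.go and is not part of this diff. Below is a minimal,
self-contained sketch of what a block-diagonal additive mask amounts to.
It is plain Go, not Ollama's ml.Tensor API; the buildBlockDiagonalMask
name and the assumption that bounds holds cumulative block end offsets
(matching the old cuSeqLens parameter name) are illustrative only, and
the per-head broadcast done by the real helper is omitted.

package main

import (
	"fmt"
	"math"
)

// buildBlockDiagonalMask is a hypothetical sketch of the mask this patch
// hoists out of attention. bounds holds cumulative block end offsets, e.g.
// [4, 9] for two blocks of 4 and 5 tokens in a 9-token sequence. Positions
// in the same block get an additive score of 0; cross-block positions get
// -Inf, so softmax(scores + mask) confines attention within each block.
func buildBlockDiagonalMask(seqLen int, bounds []int) [][]float32 {
	negInf := float32(math.Inf(-1))

	// Start fully masked: no position may attend to any other.
	mask := make([][]float32, seqLen)
	for i := range mask {
		mask[i] = make([]float32, seqLen)
		for j := range mask[i] {
			mask[i][j] = negInf
		}
	}

	// Unmask the square block covering tokens [start, end) for each block.
	start := 0
	for _, end := range bounds {
		for i := start; i < end; i++ {
			for j := start; j < end; j++ {
				mask[i][j] = 0
			}
		}
		start = end
	}
	return mask
}

func main() {
	mask := buildBlockDiagonalMask(9, []int{4, 9})
	fmt.Println(mask[0][3], mask[0][4]) // 0 -Inf: token 0 sees token 3, not token 4
}

Hoisting the mask to the call site also makes the two paths explicit
there: full-attention layers pass nil, windowed layers pass a freshly
built mask tensor.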