build the block-diagonal mask in VisionModel.Forward and pass it to attention, rather than building it inside attention

This commit is contained in:
Bruce MacDonald 2025-05-09 13:58:47 -07:00
parent 9ceee25d8b
commit 57279f89a2

View File

@@ -61,7 +61,7 @@ type VisionSelfAttention struct {
Output *nn.Linear `gguf:"attn_out"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, bounds []int, opts *VisionModelOptions) ml.Tensor {
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
@@ -75,10 +75,6 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
// Scale factor for scaled dot-product attention
scale := 1.0 / math.Sqrt(float64(opts.headDim))
var mask ml.Tensor
if bounds != nil {
mask = blockDiagonalMask(ctx, query.Dim(2), bounds, opts.numHeads)
}
// Scaled dot-product attention
query = query.Permute(ctx, 0, 2, 1, 3)
@@ -120,10 +116,10 @@ type VisionEncoderLayer struct {
MLP *VisionMLP
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, cuSeqLens []int, opts *VisionModelOptions) ml.Tensor {
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, cuSeqLens, opts)
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, mask, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
@@ -241,7 +237,14 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
if slices.Contains(m.fullAttnBlocks, i) {
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, nil, m.VisionModelOptions)
} else {
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, bounds, m.VisionModelOptions)
hiddenStates = layer.Forward(
ctx,
hiddenStates,
cos,
sin,
blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads),
m.VisionModelOptions,
)
}
}