calculate block mask once, rather than in attention

commit 57279f89a2
parent 9ceee25d8b
@@ -61,7 +61,7 @@ type VisionSelfAttention struct {
     Output *nn.Linear `gguf:"attn_out"`
 }
 
-func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, bounds []int, opts *VisionModelOptions) ml.Tensor {
+func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
     query := sa.Query.Forward(ctx, hiddenStates)
     key := sa.Key.Forward(ctx, hiddenStates)
     value := sa.Value.Forward(ctx, hiddenStates)
@@ -75,10 +75,6 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
 
     // Scale factor for scaled dot-product attention
     scale := 1.0 / math.Sqrt(float64(opts.headDim))
-    var mask ml.Tensor
-    if bounds != nil {
-        mask = blockDiagonalMask(ctx, query.Dim(2), bounds, opts.numHeads)
-    }
 
     // Scaled dot-product attention
     query = query.Permute(ctx, 0, 2, 1, 3)
@@ -120,10 +116,10 @@ type VisionEncoderLayer struct {
     MLP *VisionMLP
 }
 
-func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, cuSeqLens []int, opts *VisionModelOptions) ml.Tensor {
+func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
     residual := hiddenStates
     hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
-    hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, cuSeqLens, opts)
+    hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, mask, opts)
     hiddenStates = hiddenStates.Add(ctx, residual)
 
     residual = hiddenStates
@@ -241,7 +237,14 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
         if slices.Contains(m.fullAttnBlocks, i) {
             hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, nil, m.VisionModelOptions)
         } else {
-            hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, bounds, m.VisionModelOptions)
+            hiddenStates = layer.Forward(
+                ctx,
+                hiddenStates,
+                cos,
+                sin,
+                blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads),
+                m.VisionModelOptions,
+            )
         }
     }
 
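The diff calls a blockDiagonalMask helper whose implementation is not part of this change. As a rough, dependency-free sketch of what such a helper plausibly computes (an assumption, not code from the repository): an additive seqLen-by-seqLen mask that is 0 for pairs of positions inside the same window, delimited by consecutive entries of bounds, and -inf elsewhere. The real helper takes an ml.Context and a head count and presumably returns an ml.Tensor with the values repeated across heads; the sketch below only builds the raw values.

package main

import (
    "fmt"
    "math"
)

// blockDiagonalMaskValues is a hypothetical stand-in for blockDiagonalMask:
// it builds additive mask values for a block-diagonal attention pattern.
// bounds is assumed to hold cumulative window boundaries, e.g. [0, 4, 9, seqLen].
func blockDiagonalMaskValues(seqLen int, bounds []int) []float32 {
    negInf := float32(math.Inf(-1))
    mask := make([]float32, seqLen*seqLen)
    for i := range mask {
        mask[i] = negInf // attention disallowed by default
    }
    for w := 0; w+1 < len(bounds); w++ {
        start, end := bounds[w], bounds[w+1]
        for i := start; i < end; i++ {
            for j := start; j < end; j++ {
                mask[i*seqLen+j] = 0 // same window: attention allowed
            }
        }
    }
    return mask
}

func main() {
    // Two windows over 5 tokens: {0, 1} and {2, 3, 4}.
    m := blockDiagonalMaskValues(5, []int{0, 2, 5})
    for i := 0; i < 5; i++ {
        fmt.Println(m[i*5 : (i+1)*5])
    }
}

Per the commit message, this mask is now built where the encoder layers are invoked (VisionModel.Forward) and passed down as an ml.Tensor, so VisionSelfAttention.Forward no longer rebuilds it from the bounds slice on every call.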