move mask

This commit is contained in:
Bruce MacDonald 2025-05-09 14:35:10 -07:00
parent 57279f89a2
commit 1a2c413225

View File

@ -232,6 +232,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads)
// Apply encoder layers
for i, layer := range m.Layers {
if slices.Contains(m.fullAttnBlocks, i) {
@ -242,7 +243,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
hiddenStates,
cos,
sin,
blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads),
mask,
m.VisionModelOptions,
)
}