move mask
This commit is contained in:
parent
57279f89a2
commit
1a2c413225
@ -232,6 +232,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
||||
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
|
||||
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
|
||||
|
||||
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads)
|
||||
// Apply encoder layers
|
||||
for i, layer := range m.Layers {
|
||||
if slices.Contains(m.fullAttnBlocks, i) {
|
||||
@ -242,7 +243,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
||||
hiddenStates,
|
||||
cos,
|
||||
sin,
|
||||
blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads),
|
||||
mask,
|
||||
m.VisionModelOptions,
|
||||
)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user