move mask
This commit is contained in:
parent
57279f89a2
commit
1a2c413225
@ -232,6 +232,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
|||||||
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
|
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
|
||||||
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
|
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
|
||||||
|
|
||||||
|
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads)
|
||||||
// Apply encoder layers
|
// Apply encoder layers
|
||||||
for i, layer := range m.Layers {
|
for i, layer := range m.Layers {
|
||||||
if slices.Contains(m.fullAttnBlocks, i) {
|
if slices.Contains(m.fullAttnBlocks, i) {
|
||||||
@ -242,7 +243,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
|||||||
hiddenStates,
|
hiddenStates,
|
||||||
cos,
|
cos,
|
||||||
sin,
|
sin,
|
||||||
blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads),
|
mask,
|
||||||
m.VisionModelOptions,
|
m.VisionModelOptions,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user