reshape cos and sin
This commit is contained in:
parent
04936b719f
commit
8d901825f0
@ -179,6 +179,8 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
||||
|
||||
positionEmbedding := m.positionalEmbedding(ctx, grid)
|
||||
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
|
||||
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
|
||||
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
|
||||
|
||||
// Apply encoder layers
|
||||
for _, layer := range m.Layers {
|
||||
|
Loading…
x
Reference in New Issue
Block a user