reshape cos and sin

This commit is contained in:
Bruce MacDonald 2025-04-28 11:14:12 -07:00
parent 04936b719f
commit 8d901825f0

View File

@ -179,6 +179,8 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
positionEmbedding := m.positionalEmbedding(ctx, grid)
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
// Apply encoder layers
for _, layer := range m.Layers {