diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go
index 762ec614e..93984fec9 100644
--- a/model/models/qwen25vl/model_vision.go
+++ b/model/models/qwen25vl/model_vision.go
@@ -117,13 +117,9 @@ type PatchEmbedding struct {
 	PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
 }
 
-func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, patchSize int) ml.Tensor {
-	shape := pixelValues.Shape()
-	numChannels := 3
-	temporalPatchSize := 2
-	numPatches := shape[1] // TODO: check this
-	// patch size from args
-	embedDim := 1280
+func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, numChannels, embedDim, patchSize int) ml.Tensor {
+	temporalPatchSize := 2 // we have two temporal convolutions
+	numPatches := pixelValues.Shape()[1]
 
 	// Reshape the input tensor to match the expected dimensions
 	pixelValues = pixelValues.Reshape(ctx, patchSize*patchSize, temporalPatchSize, numChannels, numPatches)
@@ -131,6 +127,7 @@ func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, patchSi
 	// Permute the tensor to bring the temporal dimension to the front
 	pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
 
+	// Split the tensor into two parts for the two temporal convolutions
 	in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
 	in0 = in0.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
 	in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
@@ -187,7 +184,13 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 	numPatches := numPatchesH * numPatchesW
 
 	// Extract patch embeddings
-	hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize)
+	hiddenStates := m.PatchEmbedding.Forward(
+		ctx,
+		pixelValues,   // processed image tensor
+		m.numChannels, // number of channels, e.g., 3 for RGB
+		m.hiddenSize,  // embedding size
+		m.patchSize,   // patch size, e.g., 14
+	)
 
 	// Create position IDs - for Qwen2VL mRoPE we need 4 values per position
 	// The format needed is specified in the C++ code as "mrope expecting 4 position ids per token"