get patch embedding vals from config

This commit is contained in:
Bruce MacDonald 2025-04-21 12:04:46 -07:00
parent 1704072746
commit 3fe090f447

View File

@ -117,13 +117,9 @@ type PatchEmbedding struct {
PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
}
func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, patchSize int) ml.Tensor {
shape := pixelValues.Shape()
numChannels := 3
temporalPatchSize := 2
numPatches := shape[1] // TODO: check this
// patch size from args
embedDim := 1280
func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, numChannels, embedDim, patchSize int) ml.Tensor {
temporalPatchSize := 2 // we have two temporal convolutions
numPatches := pixelValues.Shape()[1]
// Reshape the input tensor to match the expected dimensions
pixelValues = pixelValues.Reshape(ctx, patchSize*patchSize, temporalPatchSize, numChannels, numPatches)
@ -131,6 +127,7 @@ func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, patchSi
// Permute the tensor to bring the temporal dimension to the front
pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Split the tensor into two parts for the two temporal convolutions
in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in0 = in0.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
@ -187,7 +184,13 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
numPatches := numPatchesH * numPatchesW
// Extract patch embeddings
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize)
hiddenStates := m.PatchEmbedding.Forward(
ctx,
pixelValues, // processed image tensor
m.numChannels, // number of channels, e.g., 3 for RGB
m.hiddenSize, // embedding size
m.patchSize, // patch size, e.g., 14
)
// Create position IDs - for Qwen2VL mRoPE we need 4 values per position
// The format needed is specified in the C++ code as "mrope expecting 4 position ids per token"