patch embeddings

This commit is contained in:
Bruce MacDonald 2025-04-21 09:43:56 -07:00
parent c1f9bcb4dd
commit 1704072746

View File

@ -1,7 +1,6 @@
package qwen25vl
import (
"fmt"
"math"
"github.com/ollama/ollama/fs"
@ -122,45 +121,32 @@ func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, patchSi
shape := pixelValues.Shape()
numChannels := 3
temporalPatchSize := 2
numPatches := shape[1] // TODO: check this
// patch size from args
embedDim := 1280
numPatches := shape[1] / temporalPatchSize
// Split the input tensor into two temporal slices and process each separately
// First temporal slice (frame 0)
slice0 := pixelValues.View(ctx, 0, patchSize*patchSize*numChannels, 0, numPatches, 0, 1).Contiguous(ctx)
reshaped0 := slice0.Reshape(ctx,
patchSize, // height
patchSize, // width
numChannels, // channels
numPatches) // batch
// Reshape the input tensor to match the expected dimensions
pixelValues = pixelValues.Reshape(ctx, patchSize*patchSize, temporalPatchSize, numChannels, numPatches)
// Second temporal slice (frame 1)
slice1 := pixelValues.View(ctx, 0, patchSize*patchSize*numChannels, 0, numPatches, 1, 1).Contiguous(ctx)
reshaped1 := slice1.Reshape(ctx,
patchSize, // height
patchSize, // width
numChannels, // channels
numPatches) // batch
// Permute the tensor to bring the temporal dimension to the front
pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Apply the appropriate convolution to each temporal slice
// PatchConv0 corresponds to weights for temporal frame 0
// PatchConv1 corresponds to weights for temporal frame 1
s0, s1 := patchSize, patchSize // Use full stride as in original
in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in0 = in0.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in1 = in1.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
s0, s1 := patchSize, patchSize // Use full stride
p0, p1 := 0, 0 // padding
d0, d1 := 1, 1 // dilation
output0 := pe.PatchConv0.Forward(ctx, reshaped0, s0, s1, p0, p1, d0, d1)
output1 := pe.PatchConv1.Forward(ctx, reshaped1, s0, s1, p0, p1, d0, d1)
out0 := pe.PatchConv0.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
out1 := pe.PatchConv1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
// Add the outputs from the two temporal convolutions
combined := output0.Add(ctx, output1)
out := out0.Add(ctx, out1)
// Reshape to required output dimensions
result := combined.Reshape(ctx, embedDim, numPatches)
fmt.Println(ml.Dump(ctx, result))
return result
// Reshape the output tensor to match the expected dimensions
return out.Reshape(ctx, embedDim, numPatches)
}
// VisionPatchMerger implements patch merging for the Qwen vision model