Get patch embedding values from config
This commit is contained in:
parent
1704072746
commit
3fe090f447
@ -117,13 +117,9 @@ type PatchEmbedding struct {
|
||||
PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
|
||||
}
|
||||
|
||||
func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, patchSize int) ml.Tensor {
|
||||
shape := pixelValues.Shape()
|
||||
numChannels := 3
|
||||
temporalPatchSize := 2
|
||||
numPatches := shape[1] // TODO: check this
|
||||
// patch size from args
|
||||
embedDim := 1280
|
||||
func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, numChannels, embedDim, patchSize int) ml.Tensor {
|
||||
temporalPatchSize := 2 // we have two temporal convolutions
|
||||
numPatches := pixelValues.Shape()[1]
|
||||
|
||||
// Reshape the input tensor to match the expected dimensions
|
||||
pixelValues = pixelValues.Reshape(ctx, patchSize*patchSize, temporalPatchSize, numChannels, numPatches)
|
||||
@ -131,6 +127,7 @@ func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, patchSi
|
||||
// Permute the tensor to bring the temporal dimension to the front
|
||||
pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// Split the tensor into two parts for the two temporal convolutions
|
||||
in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
|
||||
in0 = in0.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
|
||||
in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
|
||||
@ -187,7 +184,13 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
||||
numPatches := numPatchesH * numPatchesW
|
||||
|
||||
// Extract patch embeddings
|
||||
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize)
|
||||
hiddenStates := m.PatchEmbedding.Forward(
|
||||
ctx,
|
||||
pixelValues, // processed image tensor
|
||||
m.numChannels, // number of channels, e.g., 3 for RGB
|
||||
m.hiddenSize, // embedding size
|
||||
m.patchSize, // patch size, e.g., 14
|
||||
)
|
||||
|
||||
// Create position IDs - for Qwen2VL mRoPE we need 4 values per position
|
||||
// The format needed is specified in the C++ code as "mrope expecting 4 position ids per token"
|
||||
|
Loading…
x
Reference in New Issue
Block a user