patch embeddings
This commit is contained in:
parent
c1f9bcb4dd
commit
1704072746
@ -1,7 +1,6 @@
|
||||
package qwen25vl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
@ -122,45 +121,32 @@ func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, patchSi
|
||||
shape := pixelValues.Shape()
|
||||
numChannels := 3
|
||||
temporalPatchSize := 2
|
||||
numPatches := shape[1] // TODO: check this
|
||||
// patch size from args
|
||||
embedDim := 1280
|
||||
numPatches := shape[1] / temporalPatchSize
|
||||
|
||||
// Split the input tensor into two temporal slices and process each separately
|
||||
// First temporal slice (frame 0)
|
||||
slice0 := pixelValues.View(ctx, 0, patchSize*patchSize*numChannels, 0, numPatches, 0, 1).Contiguous(ctx)
|
||||
reshaped0 := slice0.Reshape(ctx,
|
||||
patchSize, // height
|
||||
patchSize, // width
|
||||
numChannels, // channels
|
||||
numPatches) // batch
|
||||
// Reshape the input tensor to match the expected dimensions
|
||||
pixelValues = pixelValues.Reshape(ctx, patchSize*patchSize, temporalPatchSize, numChannels, numPatches)
|
||||
|
||||
// Second temporal slice (frame 1)
|
||||
slice1 := pixelValues.View(ctx, 0, patchSize*patchSize*numChannels, 0, numPatches, 1, 1).Contiguous(ctx)
|
||||
reshaped1 := slice1.Reshape(ctx,
|
||||
patchSize, // height
|
||||
patchSize, // width
|
||||
numChannels, // channels
|
||||
numPatches) // batch
|
||||
// Permute the tensor to bring the temporal dimension to the front
|
||||
pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// Apply the appropriate convolution to each temporal slice
|
||||
// PatchConv0 corresponds to weights for temporal frame 0
|
||||
// PatchConv1 corresponds to weights for temporal frame 1
|
||||
s0, s1 := patchSize, patchSize // Use full stride as in original
|
||||
in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
|
||||
in0 = in0.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
|
||||
in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
|
||||
in1 = in1.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
|
||||
|
||||
s0, s1 := patchSize, patchSize // Use full stride
|
||||
p0, p1 := 0, 0 // padding
|
||||
d0, d1 := 1, 1 // dilation
|
||||
|
||||
output0 := pe.PatchConv0.Forward(ctx, reshaped0, s0, s1, p0, p1, d0, d1)
|
||||
output1 := pe.PatchConv1.Forward(ctx, reshaped1, s0, s1, p0, p1, d0, d1)
|
||||
out0 := pe.PatchConv0.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
|
||||
out1 := pe.PatchConv1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
|
||||
|
||||
// Add the outputs from the two temporal convolutions
|
||||
combined := output0.Add(ctx, output1)
|
||||
out := out0.Add(ctx, out1)
|
||||
|
||||
// Reshape to required output dimensions
|
||||
result := combined.Reshape(ctx, embedDim, numPatches)
|
||||
|
||||
fmt.Println(ml.Dump(ctx, result))
|
||||
|
||||
return result
|
||||
// Reshape the output tensor to match the expected dimensions
|
||||
return out.Reshape(ctx, embedDim, numPatches)
|
||||
}
|
||||
|
||||
// VisionPatchMerger implements patch merging for the Qwen vision model
|
||||
|
Loading…
x
Reference in New Issue
Block a user