patch embeddings
This commit is contained in:
parent
c1f9bcb4dd
commit
1704072746
@ -1,7 +1,6 @@
|
|||||||
package qwen25vl
|
package qwen25vl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"math"
|
"math"
|
||||||
|
|
||||||
"github.com/ollama/ollama/fs"
|
"github.com/ollama/ollama/fs"
|
||||||
@ -122,45 +121,32 @@ func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, patchSi
|
|||||||
shape := pixelValues.Shape()
|
shape := pixelValues.Shape()
|
||||||
numChannels := 3
|
numChannels := 3
|
||||||
temporalPatchSize := 2
|
temporalPatchSize := 2
|
||||||
|
numPatches := shape[1] // TODO: check this
|
||||||
|
// patch size from args
|
||||||
embedDim := 1280
|
embedDim := 1280
|
||||||
numPatches := shape[1] / temporalPatchSize
|
|
||||||
|
|
||||||
// Split the input tensor into two temporal slices and process each separately
|
// Reshape the input tensor to match the expected dimensions
|
||||||
// First temporal slice (frame 0)
|
pixelValues = pixelValues.Reshape(ctx, patchSize*patchSize, temporalPatchSize, numChannels, numPatches)
|
||||||
slice0 := pixelValues.View(ctx, 0, patchSize*patchSize*numChannels, 0, numPatches, 0, 1).Contiguous(ctx)
|
|
||||||
reshaped0 := slice0.Reshape(ctx,
|
|
||||||
patchSize, // height
|
|
||||||
patchSize, // width
|
|
||||||
numChannels, // channels
|
|
||||||
numPatches) // batch
|
|
||||||
|
|
||||||
// Second temporal slice (frame 1)
|
// Permute the tensor to bring the temporal dimension to the front
|
||||||
slice1 := pixelValues.View(ctx, 0, patchSize*patchSize*numChannels, 0, numPatches, 1, 1).Contiguous(ctx)
|
pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||||
reshaped1 := slice1.Reshape(ctx,
|
|
||||||
patchSize, // height
|
|
||||||
patchSize, // width
|
|
||||||
numChannels, // channels
|
|
||||||
numPatches) // batch
|
|
||||||
|
|
||||||
// Apply the appropriate convolution to each temporal slice
|
in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
|
||||||
// PatchConv0 corresponds to weights for temporal frame 0
|
in0 = in0.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
|
||||||
// PatchConv1 corresponds to weights for temporal frame 1
|
in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
|
||||||
s0, s1 := patchSize, patchSize // Use full stride as in original
|
in1 = in1.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
|
||||||
|
|
||||||
|
s0, s1 := patchSize, patchSize // Use full stride
|
||||||
p0, p1 := 0, 0 // padding
|
p0, p1 := 0, 0 // padding
|
||||||
d0, d1 := 1, 1 // dilation
|
d0, d1 := 1, 1 // dilation
|
||||||
|
out0 := pe.PatchConv0.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
|
||||||
output0 := pe.PatchConv0.Forward(ctx, reshaped0, s0, s1, p0, p1, d0, d1)
|
out1 := pe.PatchConv1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
|
||||||
output1 := pe.PatchConv1.Forward(ctx, reshaped1, s0, s1, p0, p1, d0, d1)
|
|
||||||
|
|
||||||
// Add the outputs from the two temporal convolutions
|
// Add the outputs from the two temporal convolutions
|
||||||
combined := output0.Add(ctx, output1)
|
out := out0.Add(ctx, out1)
|
||||||
|
|
||||||
// Reshape to required output dimensions
|
// Reshape the output tensor to match the expected dimensions
|
||||||
result := combined.Reshape(ctx, embedDim, numPatches)
|
return out.Reshape(ctx, embedDim, numPatches)
|
||||||
|
|
||||||
fmt.Println(ml.Dump(ctx, result))
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// VisionPatchMerger implements patch merging for the Qwen vision model
|
// VisionPatchMerger implements patch merging for the Qwen vision model
|
||||||
|
Loading…
x
Reference in New Issue
Block a user