fix patch merger
This commit is contained in:
parent
fb3c16f2a2
commit
fcfad744ff
@ -148,19 +148,16 @@ type VisionPatchMerger struct {
|
||||
|
||||
// Forward computes patch merging for the vision model
|
||||
func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
|
||||
// Get dimensions
|
||||
hiddenSize := visionOutputs.Dim(0)
|
||||
numPositions := visionOutputs.Dim(1)
|
||||
batchSize := visionOutputs.Dim(2)
|
||||
normalized := pm.LNQ.Forward(ctx, visionOutputs, eps)
|
||||
|
||||
reshaped := pm.LNQ.Forward(ctx, visionOutputs, 1e6).Reshape(ctx, hiddenSize*4, numPositions/4, batchSize)
|
||||
spatialMergeSize := 2 // This should come from config?
|
||||
hiddenSize := visionOutputs.Dim(0) * (spatialMergeSize * spatialMergeSize)
|
||||
|
||||
// Apply first linear layer (mm_0_w, mm_0_b)
|
||||
// Reshape the normalized output to view the hidden size dimension
|
||||
// Similar to .view(-1, self.hidden_size) in PyTorch
|
||||
reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(spatialMergeSize*spatialMergeSize), batchSize)
|
||||
hidden := pm.MLP0.Forward(ctx, reshaped)
|
||||
|
||||
activated := hidden.GELU(ctx)
|
||||
|
||||
// Apply second linear layer (mm_1_w, mm_1_b)
|
||||
output := pm.MLP2.Forward(ctx, activated)
|
||||
|
||||
return output
|
||||
|
Loading…
x
Reference in New Issue
Block a user