fix patch merger

This commit is contained in:
Bruce MacDonald 2025-04-29 15:43:10 -07:00
parent fb3c16f2a2
commit fcfad744ff

View File

@ -148,19 +148,16 @@ type VisionPatchMerger struct {
// Forward computes patch merging for the vision model // Forward computes patch merging for the vision model
func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor { func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
// Get dimensions normalized := pm.LNQ.Forward(ctx, visionOutputs, eps)
hiddenSize := visionOutputs.Dim(0)
numPositions := visionOutputs.Dim(1)
batchSize := visionOutputs.Dim(2)
reshaped := pm.LNQ.Forward(ctx, visionOutputs, 1e6).Reshape(ctx, hiddenSize*4, numPositions/4, batchSize) spatialMergeSize := 2 // This should come from config?
hiddenSize := visionOutputs.Dim(0) * (spatialMergeSize * spatialMergeSize)
// Apply first linear layer (mm_0_w, mm_0_b) // Reshape the normalized output to view the hidden size dimension
// Similar to .view(-1, self.hidden_size) in PyTorch
reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(spatialMergeSize*spatialMergeSize), batchSize)
hidden := pm.MLP0.Forward(ctx, reshaped) hidden := pm.MLP0.Forward(ctx, reshaped)
activated := hidden.GELU(ctx) activated := hidden.GELU(ctx)
// Apply second linear layer (mm_1_w, mm_1_b)
output := pm.MLP2.Forward(ctx, activated) output := pm.MLP2.Forward(ctx, activated)
return output return output