fix patch merger

This commit is contained in:
Bruce MacDonald 2025-04-29 15:43:10 -07:00
parent fb3c16f2a2
commit fcfad744ff

View File

@ -148,19 +148,16 @@ type VisionPatchMerger struct {
// Forward computes patch merging for the vision model
func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
// Get dimensions
hiddenSize := visionOutputs.Dim(0)
numPositions := visionOutputs.Dim(1)
batchSize := visionOutputs.Dim(2)
normalized := pm.LNQ.Forward(ctx, visionOutputs, eps)
reshaped := pm.LNQ.Forward(ctx, visionOutputs, 1e6).Reshape(ctx, hiddenSize*4, numPositions/4, batchSize)
spatialMergeSize := 2 // This should come from config?
hiddenSize := visionOutputs.Dim(0) * (spatialMergeSize * spatialMergeSize)
// Apply first linear layer (mm_0_w, mm_0_b)
// Reshape the normalized output to view the hidden size dimension
// Similar to .view(-1, self.hidden_size) in PyTorch
reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(spatialMergeSize*spatialMergeSize), batchSize)
hidden := pm.MLP0.Forward(ctx, reshaped)
activated := hidden.GELU(ctx)
// Apply second linear layer (mm_1_w, mm_1_b)
output := pm.MLP2.Forward(ctx, activated)
return output