fix patch merger
This commit is contained in:
parent
fb3c16f2a2
commit
fcfad744ff
@ -148,19 +148,16 @@ type VisionPatchMerger struct {
|
|||||||
|
|
||||||
// Forward computes patch merging for the vision model
|
// Forward computes patch merging for the vision model
|
||||||
func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
|
func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
|
||||||
// Get dimensions
|
normalized := pm.LNQ.Forward(ctx, visionOutputs, eps)
|
||||||
hiddenSize := visionOutputs.Dim(0)
|
|
||||||
numPositions := visionOutputs.Dim(1)
|
|
||||||
batchSize := visionOutputs.Dim(2)
|
|
||||||
|
|
||||||
reshaped := pm.LNQ.Forward(ctx, visionOutputs, 1e6).Reshape(ctx, hiddenSize*4, numPositions/4, batchSize)
|
spatialMergeSize := 2 // This should come from config?
|
||||||
|
hiddenSize := visionOutputs.Dim(0) * (spatialMergeSize * spatialMergeSize)
|
||||||
|
|
||||||
// Apply first linear layer (mm_0_w, mm_0_b)
|
// Reshape the normalized output to view the hidden size dimension
|
||||||
|
// Similar to .view(-1, self.hidden_size) in PyTorch
|
||||||
|
reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(spatialMergeSize*spatialMergeSize), batchSize)
|
||||||
hidden := pm.MLP0.Forward(ctx, reshaped)
|
hidden := pm.MLP0.Forward(ctx, reshaped)
|
||||||
|
|
||||||
activated := hidden.GELU(ctx)
|
activated := hidden.GELU(ctx)
|
||||||
|
|
||||||
// Apply second linear layer (mm_1_w, mm_1_b)
|
|
||||||
output := pm.MLP2.Forward(ctx, activated)
|
output := pm.MLP2.Forward(ctx, activated)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
Loading…
x
Reference in New Issue
Block a user