From eed0ac2948035f7a3f3f9f7786b0801fb53304e6 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Tue, 29 Apr 2025 16:54:48 -0700
Subject: [PATCH] clean up vision model forward pass

---
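A few notes on the patterns this change leans on. The sketches below are
illustrative stand-alone programs, not code from this repository.

The converter now emits window_size, attention.layer_norm_epsilon, and
temporal_patch_size, and uses cmp.Or so that a field left at its zero
value in vision_config falls back to a default. A minimal sketch of that
defaulting pattern (the field values here are made up):

	package main

	import (
		"cmp"
		"fmt"
	)

	func main() {
		// An unset JSON field parses to its zero value, so cmp.Or
		// falls through to the documented default.
		var rmsNormEps float32 // not present in config -> 0
		ropeTheta := float32(1e6)

		fmt.Println(cmp.Or(rmsNormEps, 1e-6)) // 1e-06: default applied
		fmt.Println(cmp.Or(ropeTheta, 1e5))   // 1e+06: config value wins
	}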
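The vision MLP's forward pass keeps its gated activation: the gate
projection goes through GELU and scales the up projection before the
down projection. A toy scalar version of just that GEGLU step, assuming
the tanh approximation of GELU (the real path runs on ml.Tensor, and
the backend's GELU flavor may differ):

	package main

	import (
		"fmt"
		"math"
	)

	// gelu is the tanh approximation commonly used in transformer MLPs.
	func gelu(x float64) float64 {
		return 0.5 * x * (1 + math.Tanh(math.Sqrt(2/math.Pi)*(x+0.044715*x*x*x)))
	}

	func main() {
		gate := []float64{-1, 0.5, 2}
		up := []float64{1, 2, 3}
		out := make([]float64, len(gate))
		for i := range gate {
			// GEGLU: GELU(gate) * up, mirroring
			// gateOutput.GELU(ctx).Mul(ctx, upOutput) in the diff.
			out[i] = gelu(gate[i]) * up[i]
		}
		fmt.Println(out)
	}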
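VisionPatchMerger.Forward now takes its merge factor from
VisionModelOptions instead of a hardcoded 2. The reshape folds each
spatialMergeSize x spatialMergeSize block of patches into the channel
dimension, so the hidden size grows by merge squared while the patch
count shrinks by the same factor. A quick check of that bookkeeping
(the sizes are illustrative):

	package main

	import "fmt"

	func main() {
		hidden, patches, merge := 1280, 64, 2
		mergedHidden := hidden * merge * merge     // 5120
		mergedPatches := patches / (merge * merge) // 16
		// The reshape must conserve the total element count.
		fmt.Println(mergedHidden*mergedPatches == hidden*patches) // true
	}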
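positionalEmbedding likewise drops its hardcoded constants (dim := 80 / 2,
theta := 10000.0) in favor of headDim and ropeTheta from the options. The
table it fills follows the usual RoPE inverse-frequency schedule that the
deleted comment referenced; a scalar sketch of that schedule (the headDim
and theta values are illustrative):

	package main

	import (
		"fmt"
		"math"
	)

	func main() {
		headDim, theta := 80, 10000.0
		dim := headDim / 2 // rotary dimension
		freq := dim / 2    // number of distinct frequencies

		invFreq := make([]float64, freq)
		for j := range invFreq {
			// inv_freq[j] = 1 / theta^(2j/dim), the schedule the old
			// comment cited: inv_freq = 1 / theta**(arange(0, dim, 2)/dim)
			invFreq[j] = 1 / math.Pow(theta, float64(j)*2/float64(dim))
		}
		fmt.Println(invFreq[:3])
	}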
 convert/convert_qwen25vl.go           |  25 +++---
 model/models/qwen25vl/model.go        |   2 -
 model/models/qwen25vl/model_vision.go | 120 +++++++++++---------------
 3 files changed, 64 insertions(+), 83 deletions(-)

diff --git a/convert/convert_qwen25vl.go b/convert/convert_qwen25vl.go
index 2008f2d16..412520b0d 100644
--- a/convert/convert_qwen25vl.go
+++ b/convert/convert_qwen25vl.go
@@ -12,16 +12,17 @@ type qwen25VLModel struct {
 	qwen2Model
 
 	VisionModel struct {
-		Depth            uint32  `json:"depth"`
-		HiddenSize       uint32  `json:"hidden_size"`
-		IntermediateSize uint32  `json:"intermediate_size"`
-		InChannels       uint32  `json:"in_chans"`
-		NumHeads         uint32  `json:"num_heads"`
-		PatchSize        uint32  `json:"patch_size"`
-		SpatialMergeSize uint32  `json:"spatial_merge_size"`
-		SpatialPatchSize uint32  `json:"spatial_patch_size"`
-		WindowSize       uint32  `json:"window_size"`
-		RopeTheta        float32 `json:"rope_theta"`
+		Depth             uint32  `json:"depth"`
+		HiddenSize        uint32  `json:"hidden_size"`
+		NumHeads          uint32  `json:"num_heads"`
+		InChannels        uint32  `json:"in_chans"`
+		PatchSize         uint32  `json:"patch_size"`
+		SpatialMergeSize  uint32  `json:"spatial_merge_size"`
+		SpatialPatchSize  uint32  `json:"spatial_patch_size"`
+		WindowSize        uint32  `json:"window_size"`
+		RMSNormEps        float32 `json:"layer_norm_epsilon"`
+		RopeTheta         float32 `json:"rope_theta"`
+		TemporalPatchSize uint32  `json:"temporal_patch_size"`
 	} `json:"vision_config"`
 }
 
@@ -39,13 +40,15 @@ func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
 
 	kv["qwen25vl.vision.block_count"] = q.VisionModel.Depth
 	kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize
-	kv["qwen25vl.vision.feed_forward_length"] = q.VisionModel.IntermediateSize
 	kv["qwen25vl.vision.attention.head_count"] = q.VisionModel.NumHeads
 	kv["qwen25vl.vision.num_channels"] = q.VisionModel.InChannels
 	kv["qwen25vl.vision.patch_size"] = q.VisionModel.PatchSize
 	kv["qwen25vl.vision.spatial_merge_size"] = q.VisionModel.SpatialMergeSize
 	kv["qwen25vl.vision.spatial_patch_size"] = q.VisionModel.SpatialPatchSize
+	kv["qwen25vl.vision.window_size"] = q.VisionModel.WindowSize
+	kv["qwen25vl.vision.attention.layer_norm_epsilon"] = cmp.Or(q.VisionModel.RMSNormEps, 1e-6)
 	kv["qwen25vl.vision.rope.freq_base"] = cmp.Or(q.VisionModel.RopeTheta, 1e5)
+	kv["qwen25vl.vision.temporal_patch_size"] = q.VisionModel.TemporalPatchSize
 
 	return kv
 }
diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go
index 5812b08eb..2d938b707 100644
--- a/model/models/qwen25vl/model.go
+++ b/model/models/qwen25vl/model.go
@@ -118,6 +118,4 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
 func init() {
 	model.Register("qwen25vl", New)
-	model.Register("qwen2", New)
-	model.Register("qwen2vl", New)
 }
diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go
index b69ab143f..4d9cb48a9 100644
--- a/model/models/qwen25vl/model_vision.go
+++ b/model/models/qwen25vl/model_vision.go
@@ -8,6 +8,7 @@ import (
 	"github.com/ollama/ollama/ml/nn"
 )
 
+// We only support batch size of 1
 var batchSize int = 1
 
 func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
@@ -20,7 +21,6 @@ func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Te
 	return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
 }
 
-// VisionSelfAttention implements self-attention for the Qwen vision model
 type VisionSelfAttention struct {
 	Query  *nn.Linear `gguf:"attn_q"`
 	Key    *nn.Linear `gguf:"attn_k"`
@@ -28,7 +28,6 @@ type VisionSelfAttention struct {
 	Output *nn.Linear `gguf:"attn_out"`
 }
 
-// Forward computes self-attention for the vision model
 func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	query := sa.Query.Forward(ctx, hiddenStates)
 	key := sa.Key.Forward(ctx, hiddenStates)
@@ -50,16 +49,15 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
 	return sa.Output.Forward(ctx, attention)
 }
 
-// VisionMLP implements the MLP for the Qwen vision model
+// VisionMLP implements the multi-layer perceptron
 type VisionMLP struct {
 	Gate *nn.Linear `gguf:"ffn_gate"`
 	Up   *nn.Linear `gguf:"ffn_up"`
 	Down *nn.Linear `gguf:"ffn_down"`
 }
 
-// Forward computes the MLP for the vision model
 func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
-	// Using GEGLU activation: (Gate * Up) * GELU(Gate)
+	// GEGLU-style activation: GELU(gate) * up
 	gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
 	upOutput := mlp.Up.Forward(ctx, hiddenStates)
 	hiddenStates = gateOutput.GELU(ctx).Mul(ctx, upOutput)
@@ -67,7 +65,6 @@ func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Visi
 	return mlp.Down.Forward(ctx, hiddenStates)
 }
 
-// VisionEncoderLayer implements an encoder layer for the Qwen vision model
 type VisionEncoderLayer struct {
 	Norm1         *nn.RMSNorm `gguf:"ln1"`
 	SelfAttention *VisionSelfAttention
@@ -75,7 +72,6 @@ type VisionEncoderLayer struct {
 	MLP           *VisionMLP
 }
 
-// Forward computes an encoder layer for the vision model
 func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	residual := hiddenStates
 	hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
@@ -88,21 +84,18 @@ func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.T
 	return hiddenStates.Add(ctx, residual)
 }
 
-// VisionModelOptions contains configuration options for the Qwen vision model
+// VisionModelOptions contains configuration options
 type VisionModelOptions struct {
-	hiddenSize       int
-	numHeads         int
-	headDim          int
-	intermediateSize int
-	imageSize        int
-	patchSize        int
-	numChannels      int
-	eps              float32
-	ropeTheta        float32
-	outHiddenSize    int
-	spatialMergeSize int
-	spatialPatchSize int
-	windowSize       int
+	hiddenSize        int
+	numHeads          int
+	headDim           int
+	patchSize         int
+	numChannels       int
+	eps               float32
+	ropeTheta         float32
+	spatialMergeSize  int
+	windowSize        int
+	temporalPatchSize int
 }
 
 type PatchEmbedding struct {
@@ -110,25 +103,24 @@ type PatchEmbedding struct {
 	PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
 }
 
-func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, numChannels, embedDim, patchSize int) ml.Tensor {
-	temporalPatchSize := 2 // we have two temporal convolutions
+func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	numPatches := pixelValues.Shape()[1]
 
 	// Reshape the input tensor to match the expected dimensions
-	pixelValues = pixelValues.Reshape(ctx, patchSize*patchSize, temporalPatchSize, numChannels, numPatches)
+	pixelValues = pixelValues.Reshape(ctx, opts.patchSize*opts.patchSize, opts.temporalPatchSize, opts.numChannels, numPatches)
 
 	// Permute the tensor to bring the temporal dimension to the front
 	pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
 
-	// Split the tensor into two parts for the two temporal convolutions
+	// Split the tensor into parts for the temporal convolutions
 	in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
-	in0 = in0.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
+	in0 = in0.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
 	in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
-	in1 = in1.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
+	in1 = in1.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
 
-	s0, s1 := patchSize, patchSize // Use full stride
-	p0, p1 := 0, 0                 // padding
-	d0, d1 := 1, 1                 // dilation
+	s0, s1 := opts.patchSize, opts.patchSize // Use full stride
+	p0, p1 := 0, 0                           // padding
+	d0, d1 := 1, 1                           // dilation
 
 	out0 := pe.PatchConv0.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
 	out1 := pe.PatchConv1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
@@ -136,7 +128,7 @@ func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, numChan
 	out := out0.Add(ctx, out1)
 
 	// Reshape the output tensor to match the expected dimensions
-	return out.Reshape(ctx, embedDim, numPatches)
+	return out.Reshape(ctx, opts.hiddenSize, numPatches)
 }
 
 // VisionPatchMerger implements patch merging for the Qwen vision model
@@ -147,17 +139,16 @@ type VisionPatchMerger struct {
 }
 
 // Forward computes patch merging for the vision model
-func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
-	normalized := pm.LNQ.Forward(ctx, visionOutputs, eps)
+func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+	normalized := pm.LNQ.Forward(ctx, visionOutputs, opts.eps)
 
-	spatialMergeSize := 2 // This should come from config?
-	hiddenSize := visionOutputs.Dim(0) * (spatialMergeSize * spatialMergeSize)
+	hiddenSize := visionOutputs.Dim(0) * (opts.spatialMergeSize * opts.spatialMergeSize)
 
 	// Reshape the normalized output to view the hidden size dimension
-	// Similar to .view(-1, self.hidden_size) in PyTorch
-	reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(spatialMergeSize*spatialMergeSize), batchSize)
+	reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize), batchSize)
 
 	hidden := pm.MLP0.Forward(ctx, reshaped)
 	activated := hidden.GELU(ctx)
+
 	output := pm.MLP2.Forward(ctx, activated)
 	return output
 }
@@ -175,13 +166,7 @@ type VisionModel struct {
 // Forward computes the vision model for an input tensor
 func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
 	// Extract patch embeddings
-	hiddenStates := m.PatchEmbedding.Forward(
-		ctx,
-		pixelValues,   // processed image tensor
-		m.numChannels, // number of channels, e.g., 3 for RGB
-		m.hiddenSize,  // embedding size
-		m.patchSize,   // patch size, e.g., 14
-	)
+	hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)
 
 	positionEmbedding := m.positionalEmbedding(ctx, grid)
 
@@ -207,7 +192,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
 		hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
 	}
 
-	return m.PatchMerger.Forward(ctx, hiddenStates, m.eps)
+	return m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
 }
 
 func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
@@ -250,18 +235,13 @@
 }
 
 // positionalEmbedding generates rotary position embeddings for attention mechanisms
-// This implements rotary embeddings using spatial merging patterns for grid-based
-// vision transformers
 func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
-	// Configuration parameters
-	dim := 80 / 2    // Head dimension divided by 2
-	freq := dim / 2  // Frequency dimension (half of head dimension)
-	theta := 10000.0 // Base for frequency scaling
-	merge := 2       // Spatial merge size for rearranging coordinates
+	dim := m.headDim / 2
+	freq := dim / 2
+	theta := float64(m.ropeTheta)
+	merge := m.spatialMergeSize
 
 	// Create frequency patterns for position encoding
-	// These are scaled position values based on frequency
-	// In PyTorch: Similar to inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
 	maxGridSize := max(grid.Height, grid.Width)
 	freqVals := make([]float32, freq*maxGridSize)
 	for i := range maxGridSize {
@@ -288,7 +268,6 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
 	}
 
 	// Reshape and permute positions to match spatial merging pattern
-	// This rearranges positions to group spatially related coordinates
 	pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)
 	pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	pos = pos.Reshape(ctx, 2, merge, merge, grid.Width/merge*grid.Height/merge)
@@ -305,26 +284,27 @@
 func newVisionModel(c fs.Config) *VisionModel {
 	patchSize := int(c.Uint("vision.patch_size", 14))
 	hiddenSize := int(c.Uint("vision.embedding_length", 1280))
-	ropeTheta := c.Float("vision.rope.freq_base", 10000.0)         // not set
-	outHiddenSize := int(c.Uint("vision.out_embedding_length", 0)) // not set
 	numHeads := int(c.Uint("vision.attention.head_count", 16))
+	numChannels := int(c.Uint("vision.num_channels", 3))
+	eps := c.Float("vision.attention.layer_norm_epsilon", 1e-6)
+	ropeTheta := c.Float("vision.rope.freq_base", 10000.0)
+	spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
+	windowSize := int(c.Uint("vision.window_size", 112))
+	temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
 
 	return &VisionModel{
 		Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)),
 		VisionModelOptions: &VisionModelOptions{
-			hiddenSize:       hiddenSize,
-			numHeads:         numHeads,
-			headDim:          hiddenSize / numHeads,
-			intermediateSize: int(c.Uint("vision.feed_forward_length", 0)),
-			imageSize:        int(c.Uint("vision.image_size", 560)),
-			patchSize:        patchSize,
-			numChannels:      int(c.Uint("vision.num_channels", 3)), // not set
-			eps:              c.Float("vision.attention.layer_norm_epsilon", 1e-6),
-			ropeTheta:        ropeTheta,
-			outHiddenSize:    outHiddenSize,
-			spatialMergeSize: int(c.Uint("vision.spatial_merge_size", 2)),
-			spatialPatchSize: int(c.Uint("vision.spatial_patch_size", 2)),
-			windowSize:       int(c.Uint("vision.window_size", 112)),
+			hiddenSize:        hiddenSize,
+			numHeads:          numHeads,
+			headDim:           hiddenSize / numHeads,
+			patchSize:         patchSize,
+			numChannels:       numChannels,
+			eps:               eps,
+			ropeTheta:         ropeTheta,
+			spatialMergeSize:  spatialMergeSize,
+			windowSize:        windowSize,
+			temporalPatchSize: temporalPatchSize,
 		},
 	}
 }