clean up vision model forward pass

Bruce MacDonald 2025-04-29 16:54:48 -07:00
parent fcfad744ff
commit eed0ac2948
3 changed files with 64 additions and 83 deletions


@@ -14,14 +14,15 @@ type qwen25VLModel struct {
VisionModel struct {
Depth uint32 `json:"depth"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
InChannels uint32 `json:"in_chans"`
NumHeads uint32 `json:"num_heads"`
InChannels uint32 `json:"in_chans"`
PatchSize uint32 `json:"patch_size"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
SpatialPatchSize uint32 `json:"spatial_patch_size"`
WindowSize uint32 `json:"window_size"`
RMSNormEps float32 `json:"layer_norm_epsilon"`
RopeTheta float32 `json:"rope_theta"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
} `json:"vision_config"`
}
@@ -39,13 +40,15 @@ func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
kv["qwen25vl.vision.block_count"] = q.VisionModel.Depth
kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize
kv["qwen25vl.vision.feed_forward_length"] = q.VisionModel.IntermediateSize
kv["qwen25vl.vision.attention.head_count"] = q.VisionModel.NumHeads
kv["qwen25vl.vision.num_channels"] = q.VisionModel.InChannels
kv["qwen25vl.vision.patch_size"] = q.VisionModel.PatchSize
kv["qwen25vl.vision.spatial_merge_size"] = q.VisionModel.SpatialMergeSize
kv["qwen25vl.vision.spatial_patch_size"] = q.VisionModel.SpatialPatchSize
kv["qwen25vl.vision.window_size"] = q.VisionModel.WindowSize
kv["qwen25vl.vision.attention.layer_norm_epsilon"] = cmp.Or(q.VisionModel.RMSNormEps, 1e-6)
kv["qwen25vl.vision.rope.freq_base"] = cmp.Or(q.VisionModel.RopeTheta, 1e5)
kv["qwen25vl.vision.temporal_patch_size"] = q.VisionModel.TemporalPatchSize
return kv
}


@@ -118,6 +118,4 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func init() {
model.Register("qwen25vl", New)
model.Register("qwen2", New)
model.Register("qwen2vl", New)
}


@@ -8,6 +8,7 @@ import (
"github.com/ollama/ollama/ml/nn"
)
// We only support batch size of 1
var batchSize int = 1
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
@@ -20,7 +21,6 @@ func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Te
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
}
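(Editor's note: the body of rotateHalf is elided by the hunk above. For readers new to rotary embeddings, here is a slice-level sketch of the conventional rotate_half(x) = concat(-x[d/2:], x[:d/2]) that the cos/sin formula above relies on; this is an illustration with a hypothetical name, not the tensor implementation in this file.)

// rotateHalfSlice illustrates rotate_half on a plain slice:
// the first half of the result is the negated second half of x,
// the second half of the result is the first half of x.
func rotateHalfSlice(x []float32) []float32 {
	d := len(x)
	out := make([]float32, d)
	for i, v := range x[d/2:] {
		out[i] = -v // -x[d/2:]
	}
	copy(out[d/2:], x[:d/2]) // x[:d/2]
	return out
}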
// VisionSelfAttention implements self-attention for the Qwen vision model
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
@@ -28,7 +28,6 @@ type VisionSelfAttention struct {
Output *nn.Linear `gguf:"attn_out"`
}
// Forward computes self-attention for the vision model
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
@@ -50,16 +49,15 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
return sa.Output.Forward(ctx, attention)
}
// VisionMLP implements the MLP for the Qwen vision model
// VisionMLP implements the multi-layer perceptron
type VisionMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
// Forward computes the MLP for the vision model
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Using GEGLU activation: (Gate * Up) * GELU(Gate)
// Using activation as specified in config (likely GELU or SiLU/Swish)
gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
upOutput := mlp.Up.Forward(ctx, hiddenStates)
hiddenStates = gateOutput.GELU(ctx).Mul(ctx, upOutput)
@@ -67,7 +65,6 @@ func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Visi
return mlp.Down.Forward(ctx, hiddenStates)
}
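(Editor's note: VisionMLP.Forward above is the usual gated-MLP pattern, down(act(gate(x)) * up(x)), with act = GELU in this code. A minimal slice-level sketch with a hypothetical name, shown only as a reading aid:)

// gatedRow applies act(gate) * up elementwise, the core of the gated MLP above;
// the Down projection then maps the result back to the hidden size.
func gatedRow(act func(float32) float32, gate, up []float32) []float32 {
	out := make([]float32, len(gate))
	for i := range gate {
		out[i] = act(gate[i]) * up[i]
	}
	return out
}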
// VisionEncoderLayer implements an encoder layer for the Qwen vision model
type VisionEncoderLayer struct {
Norm1 *nn.RMSNorm `gguf:"ln1"`
SelfAttention *VisionSelfAttention
@@ -75,7 +72,6 @@ type VisionEncoderLayer struct {
MLP *VisionMLP
}
// Forward computes an encoder layer for the vision model
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
@@ -88,21 +84,18 @@ func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.T
return hiddenStates.Add(ctx, residual)
}
// VisionModelOptions contains configuration options for the Qwen vision model
// VisionModelOptions contains configuration options
type VisionModelOptions struct {
hiddenSize int
numHeads int
headDim int
intermediateSize int
imageSize int
patchSize int
numChannels int
eps float32
ropeTheta float32
outHiddenSize int
spatialMergeSize int
spatialPatchSize int
windowSize int
temporalPatchSize int
}
type PatchEmbedding struct {
@@ -110,23 +103,22 @@ type PatchEmbedding struct {
PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
}
func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, numChannels, embedDim, patchSize int) ml.Tensor {
temporalPatchSize := 2 // we have two temporal convolutions
func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, opts *VisionModelOptions) ml.Tensor {
numPatches := pixelValues.Shape()[1]
// Reshape the input tensor to match the expected dimensions
pixelValues = pixelValues.Reshape(ctx, patchSize*patchSize, temporalPatchSize, numChannels, numPatches)
pixelValues = pixelValues.Reshape(ctx, opts.patchSize*opts.patchSize, opts.temporalPatchSize, opts.numChannels, numPatches)
// Permute the tensor to bring the temporal dimension to the front
pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Split the tensor into two parts for the two temporal convolutions
// Split the tensor into parts for the temporal convolutions
in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in0 = in0.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
in0 = in0.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in1 = in1.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
in1 = in1.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
s0, s1 := patchSize, patchSize // Use full stride
s0, s1 := opts.patchSize, opts.patchSize // Use full stride
p0, p1 := 0, 0 // padding
d0, d1 := 1, 1 // dilation
out0 := pe.PatchConv0.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
@@ -136,7 +128,7 @@ func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, numChan
out := out0.Add(ctx, out1)
// Reshape the output tensor to match the expected dimensions
return out.Reshape(ctx, embedDim, numPatches)
return out.Reshape(ctx, opts.hiddenSize, numPatches)
}
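(Editor's note: a small shape check for the function above, assuming the defaults read by newVisionModel at the end of this file — patch_size 14, temporal_patch_size 2, num_channels 3, embedding_length 1280. Each input column then holds 14*14*2*3 = 1176 pixel values per patch and leaves the function as a 1280-wide embedding. The helper name is hypothetical.)

// patchVectorSizes reports, per patch, how many values go into
// PatchEmbedding.Forward and how many come out after the final reshape.
func patchVectorSizes(patchSize, temporalPatchSize, numChannels, hiddenSize int) (in, out int) {
	return patchSize * patchSize * temporalPatchSize * numChannels, hiddenSize // e.g. 1176, 1280
}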
// VisionPatchMerger implements patch merging for the Qwen vision model
@@ -147,17 +139,16 @@ type VisionPatchMerger struct {
}
// Forward computes patch merging for the vision model
func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
normalized := pm.LNQ.Forward(ctx, visionOutputs, eps)
func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
normalized := pm.LNQ.Forward(ctx, visionOutputs, opts.eps)
spatialMergeSize := 2 // This should come from config?
hiddenSize := visionOutputs.Dim(0) * (spatialMergeSize * spatialMergeSize)
hiddenSize := visionOutputs.Dim(0) * (opts.spatialMergeSize * opts.spatialMergeSize)
// Reshape the normalized output to view the hidden size dimension
// Similar to .view(-1, self.hidden_size) in PyTorch
reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(spatialMergeSize*spatialMergeSize), batchSize)
reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize), batchSize)
hidden := pm.MLP0.Forward(ctx, reshaped)
activated := hidden.GELU(ctx)
output := pm.MLP2.Forward(ctx, activated)
return output
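(Editor's note: a sketch of the dimension bookkeeping in the merger above, with a hypothetical helper name. With the defaults used later in this file — embedding_length 1280, spatial_merge_size 2 — every 2*2 = 4 neighbouring patch vectors are concatenated into one 1280*4 = 5120-wide vector before the two-layer MLP, so the token count drops by 4x.)

// mergedShape sketches the reshape in VisionPatchMerger.Forward: groups of
// spatialMergeSize*spatialMergeSize patch vectors become one longer vector.
func mergedShape(dim, numPatches, spatialMergeSize int) (mergedDim, mergedTokens int) {
	group := spatialMergeSize * spatialMergeSize
	return dim * group, numPatches / group
}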
@@ -175,13 +166,7 @@ type VisionModel struct {
// Forward computes the vision model for an input tensor
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
// Extract patch embeddings
hiddenStates := m.PatchEmbedding.Forward(
ctx,
pixelValues, // processed image tensor
m.numChannels, // number of channels, e.g., 3 for RGB
m.hiddenSize, // embedding size
m.patchSize, // patch size, e.g., 14
)
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)
positionEmbedding := m.positionalEmbedding(ctx, grid)
@@ -207,7 +192,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
}
return m.PatchMerger.Forward(ctx, hiddenStates, m.eps)
return m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
}
func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
@@ -250,18 +235,13 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
}
// positionalEmbedding generates rotary position embeddings for attention mechanisms
// This implements rotary embeddings using spatial merging patterns for grid-based
// vision transformers
func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
// Configuration parameters
dim := 80 / 2 // Head dimension divided by 2
freq := dim / 2 // Frequency dimension (half of head dimension)
theta := 10000.0 // Base for frequency scaling
merge := 2 // Spatial merge size for rearranging coordinates
dim := m.headDim / 2
freq := dim / 2
theta := float64(m.ropeTheta)
merge := m.spatialMergeSize
// Create frequency patterns for position encoding
// These are scaled position values based on frequency
// In PyTorch: Similar to inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
maxGridSize := max(grid.Height, grid.Width)
freqVals := make([]float32, freq*maxGridSize)
for i := range maxGridSize {
@@ -288,7 +268,6 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
}
// Reshape and permute positions to match spatial merging pattern
// This rearranges positions to group spatially related coordinates
pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
pos = pos.Reshape(ctx, 2, merge, merge, grid.Width/merge*grid.Height/merge)
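(Editor's note: the earlier hunk in positionalEmbedding references the standard rotary inverse-frequency table, inv_freq[j] = 1 / theta^(2j/dim). A minimal sketch of that formula follows, with a hypothetical helper name; with the values in this file (headDim 80, so dim = 40) it yields freq = 20 angles per spatial axis, and grid position p gets the angles p * invFreq[j]. The exact layout used to fill freqVals may differ.)

// invFreq builds the standard rotary inverse-frequency table.
// needs: import "math"
func invFreq(dim int, theta float64) []float64 {
	inv := make([]float64, dim/2)
	for j := range inv {
		inv[j] = 1 / math.Pow(theta, float64(2*j)/float64(dim))
	}
	return inv
}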
@@ -305,9 +284,13 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
func newVisionModel(c fs.Config) *VisionModel {
patchSize := int(c.Uint("vision.patch_size", 14))
hiddenSize := int(c.Uint("vision.embedding_length", 1280))
ropeTheta := c.Float("vision.rope.freq_base", 10000.0) // not set
outHiddenSize := int(c.Uint("vision.out_embedding_length", 0)) // not set
numHeads := int(c.Uint("vision.attention.head_count", 16))
numChannels := int(c.Uint("vision.num_channels", 3))
eps := c.Float("vision.attention.layer_norm_epsilon", 1e-6)
ropeTheta := c.Float("vision.rope.freq_base", 10000.0)
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
windowSize := int(c.Uint("vision.window_size", 112))
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)),
@@ -315,16 +298,13 @@ func newVisionModel(c fs.Config) *VisionModel {
hiddenSize: hiddenSize,
numHeads: numHeads,
headDim: hiddenSize / numHeads,
intermediateSize: int(c.Uint("vision.feed_forward_length", 0)),
imageSize: int(c.Uint("vision.image_size", 560)),
patchSize: patchSize,
numChannels: int(c.Uint("vision.num_channels", 3)), // not set
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
numChannels: numChannels,
eps: eps,
ropeTheta: ropeTheta,
outHiddenSize: outHiddenSize,
spatialMergeSize: int(c.Uint("vision.spatial_merge_size", 2)),
spatialPatchSize: int(c.Uint("vision.spatial_patch_size", 2)),
windowSize: int(c.Uint("vision.window_size", 112)),
spatialMergeSize: spatialMergeSize,
windowSize: windowSize,
temporalPatchSize: temporalPatchSize,
},
}
}