clean up vision model forward pass

Bruce MacDonald 2025-04-29 16:54:48 -07:00
parent fcfad744ff
commit eed0ac2948
3 changed files with 64 additions and 83 deletions


@@ -12,16 +12,17 @@ type qwen25VLModel struct {
qwen2Model
VisionModel struct {
Depth uint32 `json:"depth"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
InChannels uint32 `json:"in_chans"`
NumHeads uint32 `json:"num_heads"`
PatchSize uint32 `json:"patch_size"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
SpatialPatchSize uint32 `json:"spatial_patch_size"`
WindowSize uint32 `json:"window_size"`
RopeTheta float32 `json:"rope_theta"`
Depth uint32 `json:"depth"`
HiddenSize uint32 `json:"hidden_size"`
NumHeads uint32 `json:"num_heads"`
InChannels uint32 `json:"in_chans"`
PatchSize uint32 `json:"patch_size"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
SpatialPatchSize uint32 `json:"spatial_patch_size"`
WindowSize uint32 `json:"window_size"`
RMSNormEps float32 `json:"layer_norm_epsilon"`
RopeTheta float32 `json:"rope_theta"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
} `json:"vision_config"`
}
@@ -39,13 +40,15 @@ func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
kv["qwen25vl.vision.block_count"] = q.VisionModel.Depth
kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize
kv["qwen25vl.vision.feed_forward_length"] = q.VisionModel.IntermediateSize
kv["qwen25vl.vision.attention.head_count"] = q.VisionModel.NumHeads
kv["qwen25vl.vision.num_channels"] = q.VisionModel.InChannels
kv["qwen25vl.vision.patch_size"] = q.VisionModel.PatchSize
kv["qwen25vl.vision.spatial_merge_size"] = q.VisionModel.SpatialMergeSize
kv["qwen25vl.vision.spatial_patch_size"] = q.VisionModel.SpatialPatchSize
kv["qwen25vl.vision.window_size"] = q.VisionModel.WindowSize
kv["qwen25vl.vision.attention.layer_norm_epsilon"] = cmp.Or(q.VisionModel.RMSNormEps, 1e-6)
kv["qwen25vl.vision.rope.freq_base"] = cmp.Or(q.VisionModel.RopeTheta, 1e5)
kv["qwen25vl.vision.temporal_patch_size"] = q.VisionModel.TemporalPatchSize
return kv
}
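For context on the cmp.Or defaults above: cmp.Or from the Go standard library (Go 1.22+) returns the first of its arguments that is not the zero value, so an epsilon or rope theta that is absent from the source config (and therefore 0) falls back to 1e-6 or 1e5 respectively. A minimal standalone sketch of that behavior, separate from the files in this diff:

package main

import (
	"cmp"
	"fmt"
)

func main() {
	var eps float32 // zero value stands in for "field missing from the config"
	fmt.Println(cmp.Or(eps, 1e-6))           // 1e-06: falls back to the default
	fmt.Println(cmp.Or(float32(1e-5), 1e-6)) // 1e-05: an explicit value is kept
}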


@@ -118,6 +118,4 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func init() {
model.Register("qwen25vl", New)
model.Register("qwen2", New)
model.Register("qwen2vl", New)
}
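The init function above registers this package's constructor with the model registry by architecture name. model.Register itself is not shown in this diff; the following is purely a hypothetical sketch of the usual name-to-constructor registry pattern behind an API of this shape, not ollama's actual implementation (the Model and constructor types here are invented stand-ins):

package model

import "fmt"

type Model interface{}

type constructor func() (Model, error)

var registry = map[string]constructor{}

// Register panics on duplicate names so a conflicting init() fails fast at startup.
func Register(name string, fn constructor) {
	if _, ok := registry[name]; ok {
		panic(fmt.Sprintf("model %q registered twice", name))
	}
	registry[name] = fn
}

// Lookup returns the constructor registered under an architecture name.
func Lookup(name string) (constructor, error) {
	fn, ok := registry[name]
	if !ok {
		return nil, fmt.Errorf("unsupported architecture %q", name)
	}
	return fn, nil
}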


@@ -8,6 +8,7 @@ import (
"github.com/ollama/ollama/ml/nn"
)
// We only support batch size of 1
var batchSize int = 1
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
@@ -20,7 +21,6 @@ func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Te
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
}
// VisionSelfAttention implements self-attention for the Qwen vision model
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
@@ -28,7 +28,6 @@ type VisionSelfAttention struct {
Output *nn.Linear `gguf:"attn_out"`
}
// Forward computes self-attention for the vision model
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
@@ -50,16 +49,15 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
return sa.Output.Forward(ctx, attention)
}
// VisionMLP implements the MLP for the Qwen vision model
// VisionMLP implements the multi-layer perceptron
type VisionMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
// Forward computes the MLP for the vision model
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Using GEGLU activation: (Gate * Up) * GELU(Gate)
// Using activation as specified in config (likely GELU or SiLU/Swish)
gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
upOutput := mlp.Up.Forward(ctx, hiddenStates)
hiddenStates = gateOutput.GELU(ctx).Mul(ctx, upOutput)
@@ -67,7 +65,6 @@ func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Visi
return mlp.Down.Forward(ctx, hiddenStates)
}
// VisionEncoderLayer implements an encoder layer for the Qwen vision model
type VisionEncoderLayer struct {
Norm1 *nn.RMSNorm `gguf:"ln1"`
SelfAttention *VisionSelfAttention
@@ -75,7 +72,6 @@ type VisionEncoderLayer struct {
MLP *VisionMLP
}
// Forward computes an encoder layer for the vision model
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
@@ -88,21 +84,18 @@ func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.T
return hiddenStates.Add(ctx, residual)
}
// VisionModelOptions contains configuration options for the Qwen vision model
// VisionModelOptions contains configuration options
type VisionModelOptions struct {
hiddenSize int
numHeads int
headDim int
intermediateSize int
imageSize int
patchSize int
numChannels int
eps float32
ropeTheta float32
outHiddenSize int
spatialMergeSize int
spatialPatchSize int
windowSize int
hiddenSize int
numHeads int
headDim int
patchSize int
numChannels int
eps float32
ropeTheta float32
spatialMergeSize int
windowSize int
temporalPatchSize int
}
type PatchEmbedding struct {
@@ -110,25 +103,24 @@ type PatchEmbedding struct {
PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
}
func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, numChannels, embedDim, patchSize int) ml.Tensor {
temporalPatchSize := 2 // we have two temporal convolutions
func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, opts *VisionModelOptions) ml.Tensor {
numPatches := pixelValues.Shape()[1]
// Reshape the input tensor to match the expected dimensions
pixelValues = pixelValues.Reshape(ctx, patchSize*patchSize, temporalPatchSize, numChannels, numPatches)
pixelValues = pixelValues.Reshape(ctx, opts.patchSize*opts.patchSize, opts.temporalPatchSize, opts.numChannels, numPatches)
// Permute the tensor to bring the temporal dimension to the front
pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Split the tensor into two parts for the two temporal convolutions
// Split the tensor into parts for the temporal convolutions
in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in0 = in0.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
in0 = in0.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in1 = in1.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
in1 = in1.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
s0, s1 := patchSize, patchSize // Use full stride
p0, p1 := 0, 0 // padding
d0, d1 := 1, 1 // dilation
s0, s1 := opts.patchSize, opts.patchSize // Use full stride
p0, p1 := 0, 0 // padding
d0, d1 := 1, 1 // dilation
out0 := pe.PatchConv0.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
out1 := pe.PatchConv1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
@@ -136,7 +128,7 @@ func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, numChan
out := out0.Add(ctx, out1)
// Reshape the output tensor to match the expected dimensions
return out.Reshape(ctx, embedDim, numPatches)
return out.Reshape(ctx, opts.hiddenSize, numPatches)
}
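As a worked shape example for PatchEmbedding.Forward above (a standalone sketch using the defaults that appear in newVisionModel below, not code from this diff): each patch arrives fully flattened, the first Reshape and Permute split it into two temporal slices of patchSize x patchSize x numChannels, each slice goes through its own strided convolution, and the summed result is reshaped to one hiddenSize-wide embedding per patch.

package main

import "fmt"

func main() {
	// Defaults taken from newVisionModel further down in this file.
	patchSize, temporalPatchSize, numChannels, hiddenSize := 14, 2, 3, 1280
	numPatches := 256 // example value only; it depends on the image grid

	// Dim 0 of pixelValues before the first Reshape: one fully flattened patch.
	fmt.Println(patchSize * patchSize * temporalPatchSize * numChannels) // 1176

	// Shape of each temporal slice handed to PatchConv0 / PatchConv1.
	fmt.Println(patchSize, patchSize, numChannels, numPatches) // 14 14 3 256

	// Shape of the summed conv outputs after the final Reshape.
	fmt.Println(hiddenSize, numPatches) // 1280 256
}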
// VisionPatchMerger implements patch merging for the Qwen vision model
@@ -147,17 +139,16 @@ type VisionPatchMerger struct {
}
// Forward computes patch merging for the vision model
func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
normalized := pm.LNQ.Forward(ctx, visionOutputs, eps)
func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
normalized := pm.LNQ.Forward(ctx, visionOutputs, opts.eps)
spatialMergeSize := 2 // This should come from config?
hiddenSize := visionOutputs.Dim(0) * (spatialMergeSize * spatialMergeSize)
hiddenSize := visionOutputs.Dim(0) * (opts.spatialMergeSize * opts.spatialMergeSize)
// Reshape the normalized output to view the hidden size dimension
// Similar to .view(-1, self.hidden_size) in PyTorch
reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(spatialMergeSize*spatialMergeSize), batchSize)
reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize), batchSize)
hidden := pm.MLP0.Forward(ctx, reshaped)
activated := hidden.GELU(ctx)
output := pm.MLP2.Forward(ctx, activated)
return output
@@ -175,13 +166,7 @@ type VisionModel struct {
// Forward computes the vision model for an input tensor
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
// Extract patch embeddings
hiddenStates := m.PatchEmbedding.Forward(
ctx,
pixelValues, // processed image tensor
m.numChannels, // number of channels, e.g., 3 for RGB
m.hiddenSize, // embedding size
m.patchSize, // patch size, e.g., 14
)
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)
positionEmbedding := m.positionalEmbedding(ctx, grid)
@@ -207,7 +192,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
}
return m.PatchMerger.Forward(ctx, hiddenStates, m.eps)
return m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
}
func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
@@ -250,18 +235,13 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
}
// positionalEmbedding generates rotary position embeddings for attention mechanisms
// This implements rotary embeddings using spatial merging patterns for grid-based
// vision transformers
func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
// Configuration parameters
dim := 80 / 2 // Head dimension divided by 2
freq := dim / 2 // Frequency dimension (half of head dimension)
theta := 10000.0 // Base for frequency scaling
merge := 2 // Spatial merge size for rearranging coordinates
dim := m.headDim / 2
freq := dim / 2
theta := float64(m.ropeTheta)
merge := m.spatialMergeSize
// Create frequency patterns for position encoding
// These are scaled position values based on frequency
// In PyTorch: Similar to inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
maxGridSize := max(grid.Height, grid.Width)
freqVals := make([]float32, freq*maxGridSize)
for i := range maxGridSize {
@@ -288,7 +268,6 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
}
// Reshape and permute positions to match spatial merging pattern
// This rearranges positions to group spatially related coordinates
pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
pos = pos.Reshape(ctx, 2, merge, merge, grid.Width/merge*grid.Height/merge)
@@ -305,26 +284,27 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
func newVisionModel(c fs.Config) *VisionModel {
patchSize := int(c.Uint("vision.patch_size", 14))
hiddenSize := int(c.Uint("vision.embedding_length", 1280))
ropeTheta := c.Float("vision.rope.freq_base", 10000.0) // not set
outHiddenSize := int(c.Uint("vision.out_embedding_length", 0)) // not set
numHeads := int(c.Uint("vision.attention.head_count", 16))
numChannels := int(c.Uint("vision.num_channels", 3))
eps := c.Float("vision.attention.layer_norm_epsilon", 1e-6)
ropeTheta := c.Float("vision.rope.freq_base", 10000.0)
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
windowSize := int(c.Uint("vision.window_size", 112))
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)),
VisionModelOptions: &VisionModelOptions{
hiddenSize: hiddenSize,
numHeads: numHeads,
headDim: hiddenSize / numHeads,
intermediateSize: int(c.Uint("vision.feed_forward_length", 0)),
imageSize: int(c.Uint("vision.image_size", 560)),
patchSize: patchSize,
numChannels: int(c.Uint("vision.num_channels", 3)), // not set
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
ropeTheta: ropeTheta,
outHiddenSize: outHiddenSize,
spatialMergeSize: int(c.Uint("vision.spatial_merge_size", 2)),
spatialPatchSize: int(c.Uint("vision.spatial_patch_size", 2)),
windowSize: int(c.Uint("vision.window_size", 112)),
hiddenSize: hiddenSize,
numHeads: numHeads,
headDim: hiddenSize / numHeads,
patchSize: patchSize,
numChannels: numChannels,
eps: eps,
ropeTheta: ropeTheta,
spatialMergeSize: spatialMergeSize,
windowSize: windowSize,
temporalPatchSize: temporalPatchSize,
},
}
}
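To make the positionalEmbedding constants concrete: with the defaults above, headDim is 1280/16 = 80, so dim is 40, freq is 20 frequency bands per axis, and theta defaults to 10000. The standalone sketch below is only an illustration of the standard inverse-frequency table the earlier comment refers to, inv_freq = 1 / theta^(2i/dim) scaled by position; it is not code from this diff, but it shows the kind of values that fill freqVals.

package main

import (
	"fmt"
	"math"
)

// ropeFreqVals builds a (positions x freq) table of position-scaled inverse
// frequencies, i.e. pos * (1.0 / theta^(2*i/dim)) for each band i.
func ropeFreqVals(headDim, maxPos int, theta float64) []float32 {
	dim := headDim / 2 // rotary dimension
	freq := dim / 2    // number of frequency bands
	out := make([]float32, maxPos*freq)
	for pos := 0; pos < maxPos; pos++ {
		for i := 0; i < freq; i++ {
			invFreq := 1.0 / math.Pow(theta, float64(2*i)/float64(dim))
			out[pos*freq+i] = float32(float64(pos) * invFreq)
		}
	}
	return out
}

func main() {
	vals := ropeFreqVals(80, 3, 10000.0) // headDim 80, a 3-position axis, theta 10000
	fmt.Println(vals[0:3])   // position 0: [0 0 0]
	fmt.Println(vals[20:23]) // position 1: 1, then ~0.631, ~0.398 (decaying with i)
}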