clean up vision model forward pass

Bruce MacDonald 2025-04-29 16:54:48 -07:00
parent fcfad744ff
commit eed0ac2948
3 changed files with 64 additions and 83 deletions

View File

@@ -12,16 +12,17 @@ type qwen25VLModel struct {
 	qwen2Model

 	VisionModel struct {
 		Depth            uint32  `json:"depth"`
 		HiddenSize       uint32  `json:"hidden_size"`
-		IntermediateSize uint32  `json:"intermediate_size"`
-		InChannels       uint32  `json:"in_chans"`
 		NumHeads         uint32  `json:"num_heads"`
+		InChannels       uint32  `json:"in_chans"`
 		PatchSize        uint32  `json:"patch_size"`
 		SpatialMergeSize uint32  `json:"spatial_merge_size"`
 		SpatialPatchSize uint32  `json:"spatial_patch_size"`
 		WindowSize       uint32  `json:"window_size"`
+		RMSNormEps       float32 `json:"layer_norm_epsilon"`
 		RopeTheta        float32 `json:"rope_theta"`
+		TemporalPatchSize uint32 `json:"temporal_patch_size"`
 	} `json:"vision_config"`
 }
@@ -39,13 +40,15 @@ func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
 	kv["qwen25vl.vision.block_count"] = q.VisionModel.Depth
 	kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize
-	kv["qwen25vl.vision.feed_forward_length"] = q.VisionModel.IntermediateSize
 	kv["qwen25vl.vision.attention.head_count"] = q.VisionModel.NumHeads
 	kv["qwen25vl.vision.num_channels"] = q.VisionModel.InChannels
 	kv["qwen25vl.vision.patch_size"] = q.VisionModel.PatchSize
 	kv["qwen25vl.vision.spatial_merge_size"] = q.VisionModel.SpatialMergeSize
 	kv["qwen25vl.vision.spatial_patch_size"] = q.VisionModel.SpatialPatchSize
+	kv["qwen25vl.vision.window_size"] = q.VisionModel.WindowSize
+	kv["qwen25vl.vision.attention.layer_norm_epsilon"] = cmp.Or(q.VisionModel.RMSNormEps, 1e-6)
 	kv["qwen25vl.vision.rope.freq_base"] = cmp.Or(q.VisionModel.RopeTheta, 1e5)
+	kv["qwen25vl.vision.temporal_patch_size"] = q.VisionModel.TemporalPatchSize

 	return kv
 }
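The cmp.Or calls give these fields defaults when the source vision_config omits them: cmp.Or (standard library, Go 1.22+) returns its first non-zero argument. A minimal standalone sketch of that behavior, with made-up values:

    package main

    import (
        "cmp"
        "fmt"
    )

    func main() {
        var eps float32                          // zero value: field absent from the config
        fmt.Println(cmp.Or(eps, 1e-6))           // prints 1e-06, the fallback
        fmt.Println(cmp.Or(float32(1e-5), 1e-6)) // prints 1e-05, the configured value
    }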

View File

@@ -118,6 +118,4 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {

 func init() {
 	model.Register("qwen25vl", New)
-	model.Register("qwen2", New)
-	model.Register("qwen2vl", New)
 }

View File

@@ -8,6 +8,7 @@ import (
 	"github.com/ollama/ollama/ml/nn"
 )

+// We only support batch size of 1
 var batchSize int = 1

 func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
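rotateHalf and applyRotaryPositionalEmbedding (in the next hunk) implement the rotate-half form of rotary embeddings: out = t*cos + rotateHalf(t)*sin. A scalar sketch of the identity, not the tensor API, with illustrative values:

    package main

    import (
        "fmt"
        "math"
    )

    // For each pair (x1, x2), rotateHalf yields (-x2, x1), so
    // x*cos(a) + rotateHalf(x)*sin(a) is a plain 2D rotation by angle a.
    func rope(x1, x2, a float64) (float64, float64) {
        c, s := math.Cos(a), math.Sin(a)
        return x1*c - x2*s, x2*c + x1*s
    }

    func main() {
        fmt.Println(rope(1, 0, math.Pi/2)) // ~ (0, 1): the pair rotated 90 degrees
    }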
@@ -20,7 +21,6 @@ func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
 	return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
 }

-// VisionSelfAttention implements self-attention for the Qwen vision model
 type VisionSelfAttention struct {
 	Query  *nn.Linear `gguf:"attn_q"`
 	Key    *nn.Linear `gguf:"attn_k"`
@@ -28,7 +28,6 @@ type VisionSelfAttention struct {
 	Output *nn.Linear `gguf:"attn_out"`
 }

-// Forward computes self-attention for the vision model
 func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	query := sa.Query.Forward(ctx, hiddenStates)
 	key := sa.Key.Forward(ctx, hiddenStates)
@@ -50,16 +49,15 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	return sa.Output.Forward(ctx, attention)
 }

-// VisionMLP implements the MLP for the Qwen vision model
+// VisionMLP implements the multi-layer perceptron
 type VisionMLP struct {
 	Gate *nn.Linear `gguf:"ffn_gate"`
 	Up   *nn.Linear `gguf:"ffn_up"`
 	Down *nn.Linear `gguf:"ffn_down"`
 }

-// Forward computes the MLP for the vision model
 func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
-	// Using GEGLU activation: (Gate * Up) * GELU(Gate)
+	// Using activation as specified in config (likely GELU or SiLU/Swish)
 	gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
 	upOutput := mlp.Up.Forward(ctx, hiddenStates)
 	hiddenStates = gateOutput.GELU(ctx).Mul(ctx, upOutput)
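The body shown here, plus the Down projection in the next hunk, forms a GEGLU-style gated MLP: down(GELU(gate(x)) * up(x)). A scalar sketch with made-up weights, purely to show the data flow:

    package main

    import (
        "fmt"
        "math"
    )

    // tanh approximation of GELU
    func gelu(x float64) float64 {
        return 0.5 * x * (1 + math.Tanh(math.Sqrt(2/math.Pi)*(x+0.044715*x*x*x)))
    }

    func main() {
        x := 0.7
        gate := 1.3 * x      // gate projection (toy weight)
        up := 0.9 * x        // up projection (toy weight)
        y := gelu(gate) * up // gated activation
        fmt.Println(0.5 * y) // down projection (toy weight)
    }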
@@ -67,7 +65,6 @@ func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	return mlp.Down.Forward(ctx, hiddenStates)
 }

-// VisionEncoderLayer implements an encoder layer for the Qwen vision model
 type VisionEncoderLayer struct {
 	Norm1         *nn.RMSNorm `gguf:"ln1"`
 	SelfAttention *VisionSelfAttention
@@ -75,7 +72,6 @@ type VisionEncoderLayer struct {
 	MLP   *VisionMLP
 }

-// Forward computes an encoder layer for the vision model
 func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	residual := hiddenStates
 	hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
@@ -88,21 +84,18 @@ func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	return hiddenStates.Add(ctx, residual)
 }

-// VisionModelOptions contains configuration options for the Qwen vision model
+// VisionModelOptions contains configuration options
 type VisionModelOptions struct {
 	hiddenSize       int
 	numHeads         int
 	headDim          int
-	intermediateSize int
-	imageSize        int
 	patchSize        int
 	numChannels      int
 	eps              float32
 	ropeTheta        float32
-	outHiddenSize    int
 	spatialMergeSize int
-	spatialPatchSize int
 	windowSize       int
+	temporalPatchSize int
 }

 type PatchEmbedding struct {
@@ -110,25 +103,24 @@ type PatchEmbedding struct {
 	PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
 }

-func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, numChannels, embedDim, patchSize int) ml.Tensor {
-	temporalPatchSize := 2 // we have two temporal convolutions
+func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	numPatches := pixelValues.Shape()[1]

 	// Reshape the input tensor to match the expected dimensions
-	pixelValues = pixelValues.Reshape(ctx, patchSize*patchSize, temporalPatchSize, numChannels, numPatches)
+	pixelValues = pixelValues.Reshape(ctx, opts.patchSize*opts.patchSize, opts.temporalPatchSize, opts.numChannels, numPatches)

 	// Permute the tensor to bring the temporal dimension to the front
 	pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

-	// Split the tensor into two parts for the two temporal convolutions
+	// Split the tensor into parts for the temporal convolutions
 	in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
-	in0 = in0.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
+	in0 = in0.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
 	in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
-	in1 = in1.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
+	in1 = in1.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)

-	s0, s1 := patchSize, patchSize // Use full stride
+	s0, s1 := opts.patchSize, opts.patchSize // Use full stride
 	p0, p1 := 0, 0 // padding
 	d0, d1 := 1, 1 // dilation

 	out0 := pe.PatchConv0.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
 	out1 := pe.PatchConv1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
@@ -136,7 +128,7 @@ func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, numChannels, embedDim, patchSize int) ml.Tensor {
 	out := out0.Add(ctx, out1)

 	// Reshape the output tensor to match the expected dimensions
-	return out.Reshape(ctx, embedDim, numPatches)
+	return out.Reshape(ctx, opts.hiddenSize, numPatches)
 }

 // VisionPatchMerger implements patch merging for the Qwen vision model
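A rough shape walk-through of PatchEmbedding.Forward for a hypothetical 224x224 RGB image with patchSize 14 and temporalPatchSize 2 (all numbers illustrative, not taken from the diff):

    package main

    import "fmt"

    func main() {
        patchSize, temporalPatchSize, numChannels := 14, 2, 3
        gridH, gridW := 224/patchSize, 224/patchSize // 16 x 16 patch grid
        numPatches := gridH * gridW                  // 256
        flat := patchSize * patchSize * temporalPatchSize * numChannels
        fmt.Println(numPatches, flat) // 256 patches, 1176 values each
        // input         [flat, numPatches]
        // reshape    -> [patchSize*patchSize, temporalPatchSize, numChannels, numPatches]
        // permute    -> temporal axis first, then two views, one per temporal conv
        // conv + add -> [hiddenSize, numPatches] after the final reshape
    }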
@@ -147,17 +139,16 @@ type VisionPatchMerger struct {
 }

 // Forward computes patch merging for the vision model
-func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
-	normalized := pm.LNQ.Forward(ctx, visionOutputs, eps)
+func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+	normalized := pm.LNQ.Forward(ctx, visionOutputs, opts.eps)

-	spatialMergeSize := 2 // This should come from config?
-	hiddenSize := visionOutputs.Dim(0) * (spatialMergeSize * spatialMergeSize)
+	hiddenSize := visionOutputs.Dim(0) * (opts.spatialMergeSize * opts.spatialMergeSize)

 	// Reshape the normalized output to view the hidden size dimension
-	// Similar to .view(-1, self.hidden_size) in PyTorch
-	reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(spatialMergeSize*spatialMergeSize), batchSize)
+	reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize), batchSize)

 	hidden := pm.MLP0.Forward(ctx, reshaped)
 	activated := hidden.GELU(ctx)
 	output := pm.MLP2.Forward(ctx, activated)

 	return output
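The arithmetic in that merger, assuming visionOutputs.Dim(0) is 1280 and spatialMergeSize is 2 (illustrative defaults): each output token concatenates a 2x2 block of patch features, so the feature size grows 4x while the sequence length shrinks 4x.

    package main

    import "fmt"

    func main() {
        dim0, merge, seqLen := 1280, 2, 256 // assumed values, not from the diff
        hiddenSize := dim0 * merge * merge  // 5120: four patch vectors concatenated
        mergedLen := seqLen / (merge * merge)
        fmt.Println(hiddenSize, mergedLen) // 5120 64
    }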
@@ -175,13 +166,7 @@ type VisionModel struct {

 // Forward computes the vision model for an input tensor
 func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
 	// Extract patch embeddings
-	hiddenStates := m.PatchEmbedding.Forward(
-		ctx,
-		pixelValues,   // processed image tensor
-		m.numChannels, // number of channels, e.g., 3 for RGB
-		m.hiddenSize,  // embedding size
-		m.patchSize,   // patch size, e.g., 14
-	)
+	hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)

 	positionEmbedding := m.positionalEmbedding(ctx, grid)
@@ -207,7 +192,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
 		hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
 	}

-	return m.PatchMerger.Forward(ctx, hiddenStates, m.eps)
+	return m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
 }

 func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
@@ -250,18 +235,13 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
 }

 // positionalEmbedding generates rotary position embeddings for attention mechanisms
-// This implements rotary embeddings using spatial merging patterns for grid-based
-// vision transformers
 func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
-	// Configuration parameters
-	dim := 80 / 2    // Head dimension divided by 2
-	freq := dim / 2  // Frequency dimension (half of head dimension)
-	theta := 10000.0 // Base for frequency scaling
-	merge := 2       // Spatial merge size for rearranging coordinates
+	dim := m.headDim / 2
+	freq := dim / 2
+	theta := float64(m.ropeTheta)
+	merge := m.spatialMergeSize

 	// Create frequency patterns for position encoding
-	// These are scaled position values based on frequency
-	// In PyTorch: Similar to inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
 	maxGridSize := max(grid.Height, grid.Width)
 	freqVals := make([]float32, freq*maxGridSize)
 	for i := range maxGridSize {
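The removed comment pointed at the PyTorch formula inv_freq = 1.0 / (theta ** (arange(0, dim, 2) / dim)). A standalone sketch of that frequency table, with illustrative dim and theta (the real code derives them from headDim and ropeTheta):

    package main

    import (
        "fmt"
        "math"
    )

    func main() {
        dim, theta := 40.0, 10000.0 // stand-ins for headDim/2 and rope.freq_base
        freq := int(dim) / 2
        invFreq := make([]float64, freq)
        for i := range invFreq {
            invFreq[i] = 1.0 / math.Pow(theta, float64(2*i)/dim)
        }
        // position p then contributes angles p * invFreq[i] to the cos/sin tables
        fmt.Println(invFreq[0], invFreq[freq-1]) // 1 down to ~theta^-0.95
    }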
@@ -288,7 +268,6 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
 	}

 	// Reshape and permute positions to match spatial merging pattern
-	// This rearranges positions to group spatially related coordinates
 	pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)
 	pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	pos = pos.Reshape(ctx, 2, merge, merge, grid.Width/merge*grid.Height/merge)
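An index-level picture of what that reshape/permute sequence accomplishes, assuming a row-major 4x4 grid and merge 2: positions are reordered so each 2x2 spatial block becomes contiguous. A sketch:

    package main

    import "fmt"

    func main() {
        H, W, merge := 4, 4, 2 // illustrative grid, not from the diff
        var order []int
        for by := 0; by < H/merge; by++ { // walk 2x2 blocks...
            for bx := 0; bx < W/merge; bx++ {
                for dy := 0; dy < merge; dy++ { // ...emitting their cells together
                    for dx := 0; dx < merge; dx++ {
                        order = append(order, (by*merge+dy)*W+(bx*merge+dx))
                    }
                }
            }
        }
        fmt.Println(order) // [0 1 4 5 2 3 6 7 8 9 12 13 10 11 14 15]
    }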
@@ -305,26 +284,27 @@ func newVisionModel(c fs.Config) *VisionModel {
 	patchSize := int(c.Uint("vision.patch_size", 14))
 	hiddenSize := int(c.Uint("vision.embedding_length", 1280))
-	ropeTheta := c.Float("vision.rope.freq_base", 10000.0)         // not set
-	outHiddenSize := int(c.Uint("vision.out_embedding_length", 0)) // not set
 	numHeads := int(c.Uint("vision.attention.head_count", 16))
+	numChannels := int(c.Uint("vision.num_channels", 3))
+	eps := c.Float("vision.attention.layer_norm_epsilon", 1e-6)
+	ropeTheta := c.Float("vision.rope.freq_base", 10000.0)
+	spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
+	windowSize := int(c.Uint("vision.window_size", 112))
+	temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))

 	return &VisionModel{
 		Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)),
 		VisionModelOptions: &VisionModelOptions{
 			hiddenSize: hiddenSize,
 			numHeads:   numHeads,
 			headDim:    hiddenSize / numHeads,
-			intermediateSize: int(c.Uint("vision.feed_forward_length", 0)),
-			imageSize:        int(c.Uint("vision.image_size", 560)),
-			patchSize:        patchSize,
-			numChannels:      int(c.Uint("vision.num_channels", 3)), // not set
-			eps:              c.Float("vision.attention.layer_norm_epsilon", 1e-6),
-			ropeTheta:        ropeTheta,
-			outHiddenSize:    outHiddenSize,
-			spatialMergeSize: int(c.Uint("vision.spatial_merge_size", 2)),
-			spatialPatchSize: int(c.Uint("vision.spatial_patch_size", 2)),
-			windowSize:       int(c.Uint("vision.window_size", 112)),
+			patchSize:         patchSize,
+			numChannels:       numChannels,
+			eps:               eps,
+			ropeTheta:         ropeTheta,
+			spatialMergeSize:  spatialMergeSize,
+			windowSize:        windowSize,
+			temporalPatchSize: temporalPatchSize,
 		},
 	}
 }
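With the defaults above, headDim works out to 1280 / 16 = 80, so the new m.headDim / 2 in positionalEmbedding reproduces the dim := 80 / 2 that was previously hardcoded there while remaining correct for configs with other head counts.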