Update model_vision.go

This commit is contained in:
Bruce MacDonald 2025-04-28 09:41:04 -07:00
parent 0f0136d419
commit 04936b719f

View File

@ -1,7 +1,6 @@
package qwen25vl package qwen25vl
import ( import (
"fmt"
"math" "math"
"github.com/ollama/ollama/fs" "github.com/ollama/ollama/fs"
@ -11,6 +10,16 @@ import (
var batchSize int = 1 var batchSize int = 1
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
return x2.Neg(ctx).Concat(ctx, x1, 0)
}
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
}
// VisionSelfAttention implements self-attention for the Qwen vision model // VisionSelfAttention implements self-attention for the Qwen vision model
type VisionSelfAttention struct { type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"` Query *nn.Linear `gguf:"attn_q"`
@ -20,7 +29,7 @@ type VisionSelfAttention struct {
} }
// Forward computes self-attention for the vision model // Forward computes self-attention for the vision model
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor { func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
query := sa.Query.Forward(ctx, hiddenStates) query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates) key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates) value := sa.Value.Forward(ctx, hiddenStates)
@ -29,28 +38,8 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, p
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize) key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize) value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
config := ml.RoPEConfig{ query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
Dim: uint32(opts.headDim / 2), key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
Type: ml.RopeTypeMRoPE,
Base: opts.ropeTheta,
Scale: 1.0,
YarnConfig: ml.DefaultYarnConfig(128000),
}
query = query.RoPEMulti(
ctx,
positionIDs,
nil,
[4]int{opts.headDim / 4, opts.headDim / 4, opts.headDim / 4, opts.headDim / 4},
config,
)
key = key.RoPEMulti(
ctx,
positionIDs,
nil,
[4]int{opts.headDim / 4, opts.headDim / 4, opts.headDim / 4, opts.headDim / 4},
config,
)
// Scale factor for scaled dot-product attention // Scale factor for scaled dot-product attention
scale := 1.0 / math.Sqrt(float64(opts.headDim)) scale := 1.0 / math.Sqrt(float64(opts.headDim))
@ -87,10 +76,10 @@ type VisionEncoderLayer struct {
} }
// Forward computes an encoder layer for the vision model // Forward computes an encoder layer for the vision model
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor { func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenStates residual := hiddenStates
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps) hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, positionIDs, opts) hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, opts)
hiddenStates = hiddenStates.Add(ctx, residual) hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates residual = hiddenStates
@ -188,22 +177,22 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
m.patchSize, // patch size, e.g., 14 m.patchSize, // patch size, e.g., 14
) )
// TODO: working here positionEmbedding := m.positionalEmbedding(ctx, grid)
m.rotaryEmbedding(ctx, grid) cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
// // Apply encoder layers // Apply encoder layers
// for _, layer := range m.Layers { for _, layer := range m.Layers {
// hiddenStates = layer.Forward(ctx, hiddenStates, positionIDs, m.VisionModelOptions) hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
// } }
// hiddenStates = m.PostLayerNorm.Forward(ctx, hiddenStates, m.eps) // hiddenStates = m.PostLayerNorm.Forward(ctx, hiddenStates, m.eps)
return hiddenStates return hiddenStates
} }
// rotaryEmbedding generates rotary position embeddings for attention mechanisms // positionalEmbedding generates rotary position embeddings for attention mechanisms
// This implements rotary embeddings using spatial merging patterns for grid-based // This implements rotary embeddings using spatial merging patterns for grid-based
// vision transformers // vision transformers
func (m *VisionModel) rotaryEmbedding(ctx ml.Context, grid *Grid) (ml.Tensor, ml.Tensor) { func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
// Configuration parameters // Configuration parameters
dim := 80 / 2 // Head dimension divided by 2 dim := 80 / 2 // Head dimension divided by 2
freq := dim / 2 // Frequency dimension (half of head dimension) freq := dim / 2 // Frequency dimension (half of head dimension)
@ -246,14 +235,11 @@ func (m *VisionModel) rotaryEmbedding(ctx ml.Context, grid *Grid) (ml.Tensor, ml
pos = pos.Reshape(ctx, 2*merge*merge*grid.Width/merge*grid.Height/merge) pos = pos.Reshape(ctx, 2*merge*merge*grid.Width/merge*grid.Height/merge)
// Use position indices to look up corresponding frequency values // Use position indices to look up corresponding frequency values
out := freqs.Rows(ctx, pos) positionalEmbedding := freqs.Rows(ctx, pos)
out = out.Reshape(ctx, out.Dim(0)*2, out.Dim(1)/2) positionalEmbedding = positionalEmbedding.Reshape(ctx, positionalEmbedding.Dim(0)*2, positionalEmbedding.Dim(1)/2)
positionalEmbedding = positionalEmbedding.Concat(ctx, positionalEmbedding, 0)
fmt.Println("out", out.Shape()) return positionalEmbedding
fmt.Println(ml.Dump(ctx, out))
// TODO: return cos and sin tensors for rotary embedding
return nil, nil
} }
// newVisionModel creates a new instance of the Qwen vision model // newVisionModel creates a new instance of the Qwen vision model