full attention layers

Bruce MacDonald 2025-05-01 10:59:46 -07:00
parent bde6b46ce9
commit 359e1d5b19
2 changed files with 34 additions and 15 deletions

View File

@@ -12,17 +12,18 @@ type qwen25VLModel struct {
 	qwen2Model
 
 	VisionModel struct {
-		Depth             uint32  `json:"depth"`
-		HiddenSize        uint32  `json:"hidden_size"`
-		NumHeads          uint32  `json:"num_heads"`
-		InChannels        uint32  `json:"in_chans"`
-		PatchSize         uint32  `json:"patch_size"`
-		SpatialMergeSize  uint32  `json:"spatial_merge_size"`
-		SpatialPatchSize  uint32  `json:"spatial_patch_size"`
-		WindowSize        uint32  `json:"window_size"`
-		RMSNormEps        float32 `json:"layer_norm_epsilon"`
-		RopeTheta         float32 `json:"rope_theta"`
-		TemporalPatchSize uint32  `json:"temporal_patch_size"`
+		Depth               uint32   `json:"depth"`
+		HiddenSize          uint32   `json:"hidden_size"`
+		NumHeads            uint32   `json:"num_heads"`
+		InChannels          uint32   `json:"in_chans"`
+		PatchSize           uint32   `json:"patch_size"`
+		SpatialMergeSize    uint32   `json:"spatial_merge_size"`
+		SpatialPatchSize    uint32   `json:"spatial_patch_size"`
+		WindowSize          uint32   `json:"window_size"`
+		RMSNormEps          float32  `json:"layer_norm_epsilon"`
+		RopeTheta           float32  `json:"rope_theta"`
+		FullAttentionBlocks []uint32 `json:"fullatt_block_indexes"`
+		TemporalPatchSize   uint32   `json:"temporal_patch_size"`
 	} `json:"vision_config"`
 }
 
@@ -48,6 +49,7 @@ func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
 	kv["qwen25vl.vision.window_size"] = q.VisionModel.WindowSize
 	kv["qwen25vl.vision.attention.layer_norm_epsilon"] = cmp.Or(q.VisionModel.RMSNormEps, 1e-6)
 	kv["qwen25vl.vision.rope.freq_base"] = cmp.Or(q.VisionModel.RopeTheta, 1e5)
+	kv["qwen25vl.vision.fullatt_block_indexes"] = q.VisionModel.FullAttentionBlocks
 	kv["qwen25vl.vision.temporal_patch_size"] = q.VisionModel.TemporalPatchSize
 
 	return kv

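On the conversion side the change is purely plumbing: the fullatt_block_indexes array from the Hugging Face vision_config is decoded into FullAttentionBlocks and re-emitted under the qwen25vl.vision.fullatt_block_indexes KV key. A minimal standalone sketch of that mapping, not part of this commit: the struct tag and KV key come from the hunks above, while the sample JSON and the plain map standing in for ggml.KV are illustrative only.

package main

import (
	"encoding/json"
	"fmt"
)

// visionConfig mirrors only the field added in this commit.
type visionConfig struct {
	FullAttentionBlocks []uint32 `json:"fullatt_block_indexes"`
}

func main() {
	// Fragment of a Qwen2.5-VL vision_config; 7, 15, 23, 31 are the defaults
	// that also appear in newVisionModel below.
	raw := []byte(`{"fullatt_block_indexes": [7, 15, 23, 31]}`)

	var vc visionConfig
	if err := json.Unmarshal(raw, &vc); err != nil {
		panic(err)
	}

	// Stand-in for the ggml.KV map filled in by (*qwen25VLModel).KV.
	kv := map[string]any{
		"qwen25vl.vision.fullatt_block_indexes": vc.FullAttentionBlocks,
	}
	fmt.Println(kv) // map[qwen25vl.vision.fullatt_block_indexes:[7 15 23 31]]
}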
View File

@@ -3,6 +3,7 @@ package qwen25vl
 import (
 	"fmt"
 	"math"
+	"slices"
 
 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/ml"
@@ -74,7 +75,10 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
 	// Scale factor for scaled dot-product attention
 	scale := 1.0 / math.Sqrt(float64(opts.headDim))
 
-	mask := blockDiagonalMask(ctx, query.Dim(2), bounds, opts.numHeads)
+	var mask ml.Tensor
+	if bounds != nil {
+		mask = blockDiagonalMask(ctx, query.Dim(2), bounds, opts.numHeads)
+	}
 
 	attention := nn.Attention(ctx, query, key, value, scale, nil, nn.WithMask(mask))
 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
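The effect of leaving mask nil: when bounds are supplied, blockDiagonalMask restricts each patch to attending within its own window, and a nil mask means unrestricted (full) attention across all patches. blockDiagonalMask itself is not shown in this diff, so the following is only a conceptual sketch of a block-diagonal additive mask on plain float64 slices, with made-up window bounds; it is not the repository's implementation.

package main

import (
	"fmt"
	"math"
)

// blockDiagonalValues illustrates the idea of a block-diagonal additive mask:
// 0 where query and key tokens fall in the same window, -Inf otherwise.
// bounds holds cumulative window boundaries, e.g. [0, 4, 8] = two 4-token windows.
func blockDiagonalValues(n int, bounds []int) [][]float64 {
	window := make([]int, n) // window index for each token
	w := 0
	for i := 0; i < n; i++ {
		for w+1 < len(bounds) && i >= bounds[w+1] {
			w++
		}
		window[i] = w
	}

	mask := make([][]float64, n)
	for i := range mask {
		mask[i] = make([]float64, n)
		for j := range mask[i] {
			if window[i] != window[j] {
				mask[i][j] = math.Inf(-1) // masked out: different window
			}
		}
	}
	return mask
}

func main() {
	m := blockDiagonalValues(8, []int{0, 4, 8})
	fmt.Println(m[0][3], m[0][4]) // 0 -Inf: same window attends, other window is masked
}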
@@ -128,6 +132,7 @@ type VisionModelOptions struct {
 	ropeTheta         float32
 	spatialMergeSize  int
 	windowSize        int
+	fullAttnBlocks    []int
 	temporalPatchSize int
 }
 
@@ -221,8 +226,12 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
 	sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
 
 	// Apply encoder layers
-	for _, layer := range m.Layers {
-		hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, bounds, m.VisionModelOptions)
+	for i, layer := range m.Layers {
+		if slices.Contains(m.fullAttnBlocks, i) {
+			hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, nil, m.VisionModelOptions)
+		} else {
+			hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, bounds, m.VisionModelOptions)
+		}
 	}
 
 	return m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
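The per-layer dispatch is just a membership test on the layer index: blocks listed in fullAttnBlocks get nil bounds (so no mask is built and attention is global), every other block keeps the window bounds. A standalone sketch of that selection, not part of this commit, using the 32-block depth and the 7, 15, 23, 31 defaults that appear in newVisionModel below.

package main

import (
	"fmt"
	"slices"
)

func main() {
	fullAttnBlocks := []int{7, 15, 23, 31} // default fullatt_block_indexes
	blockCount := 32                       // default vision.block_count

	for i := 0; i < blockCount; i++ {
		if slices.Contains(fullAttnBlocks, i) {
			fmt.Printf("layer %2d: full attention (bounds = nil, no mask)\n", i)
		} else {
			fmt.Printf("layer %2d: windowed attention (block-diagonal mask over bounds)\n", i)
		}
	}
}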
@@ -342,9 +351,10 @@ func newVisionModel(c fs.Config) *VisionModel {
 	ropeTheta := c.Float("vision.rope.freq_base", 10000.0)
 	spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
 	windowSize := int(c.Uint("vision.window_size", 112))
+	fullAttnBlocks := c.Ints("qwen25vl.vision.fullatt_block_indexes", []int32{7, 15, 23, 31})
 	temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
 
-	return &VisionModel{
+	model := &VisionModel{
 		Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)),
 		VisionModelOptions: &VisionModelOptions{
 			hiddenSize: hiddenSize,
@@ -359,4 +369,11 @@ func newVisionModel(c fs.Config) *VisionModel {
 			temporalPatchSize: temporalPatchSize,
 		},
 	}
+
+	for i := range fullAttnBlocks {
+		// full attention block indexes have to be converted to int for use with the slices package
+		model.fullAttnBlocks = append(model.fullAttnBlocks, int(fullAttnBlocks[i]))
+	}
+
+	return model
 }
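The closing loop exists only because of a type mismatch: c.Ints returns []int32 here (implied by its []int32 default), while slices.Contains in Forward compares against the int loop index, so each configured index is widened to int before being stored. A self-contained equivalent of that conversion; the helper name is ours, not the repository's.

package main

import "fmt"

// int32sToInts widens a []int32 into a []int so its values can be compared
// against ordinary int loop indexes, e.g. with slices.Contains.
func int32sToInts(xs []int32) []int {
	out := make([]int, 0, len(xs))
	for _, x := range xs {
		out = append(out, int(x))
	}
	return out
}

func main() {
	fmt.Println(int32sToInts([]int32{7, 15, 23, 31})) // [7 15 23 31]
}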