diff --git a/ml/nn/attention.go b/ml/nn/attention.go
index a3f43a1ea..e33ad08dc 100644
--- a/ml/nn/attention.go
+++ b/ml/nn/attention.go
@@ -7,6 +7,18 @@ import (
 	"github.com/ollama/ollama/ml"
 )
 
+type AttentionOption func(*attentionOptions)
+
+type attentionOptions struct {
+	mask ml.Tensor
+}
+
+func WithMask(mask ml.Tensor) AttentionOption {
+	return func(opts *attentionOptions) {
+		opts.mask = mask
+	}
+}
+
 // Attention implements scaled dot-product attention for transformer models:
 // Attention(Q, K, V) = softmax(QK^T/√d_k)V
 //
@@ -21,7 +33,12 @@ import (
 // Returns:
 //
 //	Attention output with shape [d_v, heads, seq_len_q]
-func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
+func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache, opts ...AttentionOption) ml.Tensor {
+	options := &attentionOptions{}
+	for _, opt := range opts {
+		opt(options)
+	}
+
 	if key != nil && value != nil {
 		if query.Dim(0) != key.Dim(0) {
 			panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
@@ -46,6 +63,9 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
 	if cache != nil {
 		key, value, mask = cache.Get(ctx)
 	}
+	if options.mask != nil {
+		mask = options.mask
+	}
 
 	// Only use the fast SDPA implementation if we have a cache, since that's what
 	// will do any expected backend-specific transformations for us
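The variadic `opts` parameter keeps every existing `nn.Attention` call site source-compatible: with no options, `attentionOptions` stays at its zero value and the mask (if any) still comes from the cache. The standalone sketch below mirrors the functional-options idiom with a plain string standing in for `ml.Tensor`, so it runs without the `ml` package; it illustrates the pattern only and is not the real API.

```go
package main

import "fmt"

// Mirror of the options pattern added in attention.go.
// A string stands in for ml.Tensor so the sketch is self-contained.
type attentionOptions struct {
	mask string
}

type AttentionOption func(*attentionOptions)

func WithMask(mask string) AttentionOption {
	return func(opts *attentionOptions) { opts.mask = mask }
}

func Attention(opts ...AttentionOption) {
	options := &attentionOptions{} // zero value: no explicit mask
	for _, opt := range opts {
		opt(options)
	}
	if options.mask != "" {
		fmt.Println("using caller-supplied mask:", options.mask)
		return
	}
	fmt.Println("no option given: fall back to the cache-provided mask")
}

func main() {
	Attention()                           // old call sites compile unchanged
	Attention(WithMask("block-diagonal")) // the new qwen25vl vision path
}
```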
diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go
index 306de8780..d1f13eab5 100644
--- a/model/models/qwen25vl/model_vision.go
+++ b/model/models/qwen25vl/model_vision.go
@@ -22,6 +22,37 @@ func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Te
 	return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
 }
 
+func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor {
+	// Create a flat slice for the mask (all -inf initially to block all attention)
+	flat := make([]float32, seqLength*seqLength)
+	for i := range flat {
+		flat[i] = float32(math.Inf(-1)) // Negative infinity to block attention
+	}
+
+	// Fill in the mask with zeros for tokens that CAN attend to each other
+	for i := 1; i < len(bounds); i++ {
+		start := bounds[i-1]
+		end := bounds[i]
+
+		// Enable attention within this sequence block by setting values to 0
+		for row := start; row < end; row++ {
+			for col := start; col < end; col++ {
+				idx := row*seqLength + col
+				flat[idx] = 0.0 // 0 allows attention, -inf blocks it
+			}
+		}
+	}
+
+	mask, err := ctx.Input().FromFloatSlice(flat, seqLength, seqLength)
+	if err != nil {
+		panic(err)
+	}
+	// Reshape to match [seqLength, seqLength, 1] for broadcasting
+	mask = mask.Reshape(ctx, seqLength, seqLength, 1)
+
+	return mask
+}
+
 type VisionSelfAttention struct {
 	Query  *nn.Linear `gguf:"attn_q"`
 	Key    *nn.Linear `gguf:"attn_k"`
@@ -29,7 +60,7 @@ type VisionSelfAttention struct {
 	Output *nn.Linear `gguf:"attn_out"`
 }
 
-func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, bounds []int, opts *VisionModelOptions) ml.Tensor {
 	query := sa.Query.Forward(ctx, hiddenStates)
 	key := sa.Key.Forward(ctx, hiddenStates)
 	value := sa.Value.Forward(ctx, hiddenStates)
@@ -43,8 +74,9 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
 	// Scale factor for scaled dot-product attention
 	scale := 1.0 / math.Sqrt(float64(opts.headDim))
 
+	mask := blockDiagonalMask(ctx, query.Dim(2), bounds, opts.numHeads)
 
-	attention := nn.Attention(ctx, query, key, value, scale, nil)
+	attention := nn.Attention(ctx, query, key, value, scale, nil, nn.WithMask(mask))
 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
 
 	return sa.Output.Forward(ctx, attention)
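To make the mask semantics concrete: for a hypothetical 5-token sequence split into two windows (`bounds := []int{0, 2, 5}`, values invented for illustration), `blockDiagonalMask` yields zeros inside each diagonal block and `-Inf` everywhere else, so softmax assigns zero weight to attention across window boundaries. A dependency-free sketch of the same loop, printing the mask instead of uploading it to a tensor:

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	seqLength := 5
	bounds := []int{0, 2, 5} // two blocks: tokens 0-1 and tokens 2-4

	flat := make([]float32, seqLength*seqLength)
	for i := range flat {
		flat[i] = float32(math.Inf(-1)) // block everything by default
	}
	for i := 1; i < len(bounds); i++ {
		start, end := bounds[i-1], bounds[i]
		for row := start; row < end; row++ {
			for col := start; col < end; col++ {
				flat[row*seqLength+col] = 0 // allow attention inside the block
			}
		}
	}

	for row := 0; row < seqLength; row++ {
		fmt.Println(flat[row*seqLength : (row+1)*seqLength])
	}
	// Output:
	// [0 0 -Inf -Inf -Inf]
	// [0 0 -Inf -Inf -Inf]
	// [-Inf -Inf 0 0 0]
	// [-Inf -Inf 0 0 0]
	// [-Inf -Inf 0 0 0]
}
```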
@@ -73,10 +105,10 @@ type VisionEncoderLayer struct {
 	MLP           *VisionMLP
 }
 
-func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, cuSeqLens []int, opts *VisionModelOptions) ml.Tensor {
 	residual := hiddenStates
 	hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
-	hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, opts)
+	hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, cuSeqLens, opts)
 	hiddenStates = hiddenStates.Add(ctx, residual)
 
 	residual = hiddenStates
@@ -171,7 +203,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
 
 	positionEmbedding := m.positionalEmbedding(ctx, grid)
 
-	windowIndex := m.windowIndex(ctx, grid)
+	windowIndex, bounds := m.windowIndex(ctx, grid)
 
 	spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
 
@@ -190,13 +222,21 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
 
 	// Apply encoder layers
 	for _, layer := range m.Layers {
-		hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
+		hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, bounds, m.VisionModelOptions)
 	}
 
 	return m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
 }
 
-func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
+// windowIndex divides the grid into windows and returns:
+//  1. A tensor containing flattened indices of all grid points organized by windows
+//  2. A slice of boundaries that mark where each window's data begins and ends
+//     in the flattened representation, scaled by spatialMergeSize squared
+//
+// The boundaries slice always starts with 0 and contains cumulative ending
+// positions for each window, allowing downstream processing to identify
+// window boundaries in the tensor data.
+func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) {
 	vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
 
 	llmGridH := grid.Height / m.spatialMergeSize
@@ -209,6 +249,10 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
 	// Initialize index_new slice
 	var index []int32
 
+	// Initialize bounds with the first element as 0
+	bounds := []int{0}
+	totalSeqLen := 0
+
 	// Process each window without padding
 	for wh := range numWindowsH {
 		for ww := range numWindowsW {
@@ -218,12 +262,19 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
 			hEnd := min(hStart+vitMergerWindowSize, llmGridH)
 			wEnd := min(wStart+vitMergerWindowSize, llmGridW)
 
+			// Calculate sequence length for this window
+			seqLen := (hEnd - hStart) * (wEnd - wStart)
+
 			// Collect indices for this window
 			for h := hStart; h < hEnd; h++ {
 				for w := wStart; w < wEnd; w++ {
 					index = append(index, int32(h*llmGridW+w))
 				}
 			}
+
+			// Update total sequence length and append to bounds
+			totalSeqLen += seqLen
+			bounds = append(bounds, totalSeqLen*(m.spatialMergeSize*m.spatialMergeSize)+bounds[0])
 		}
 	}
 
@@ -232,7 +283,7 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
 		panic(err)
 	}
 
-	return t
+	return t, bounds
 }
 
 // positionalEmbedding generates rotary position embeddings for attention mechanisms
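To check the bounds arithmetic end to end, here is a standalone run of the same loop for an invented configuration (a 4×4 merged grid, 2×2 windows, `spatialMergeSize = 2`; none of these values come from a real model config). Each window covers 4 merged cells, i.e. 16 raw patches, so the boundaries land at multiples of 16:

```go
package main

import "fmt"

func main() {
	// Invented configuration: a 4x4 grid of merged cells split into
	// 2x2 windows, with spatialMergeSize = 2 (each merged cell is
	// 2x2 = 4 raw patches).
	llmGridH, llmGridW := 4, 4
	vitMergerWindowSize := 2
	spatialMergeSize := 2
	numWindowsH, numWindowsW := 2, 2

	bounds := []int{0}
	totalSeqLen := 0
	for wh := 0; wh < numWindowsH; wh++ {
		for ww := 0; ww < numWindowsW; ww++ {
			hStart, wStart := wh*vitMergerWindowSize, ww*vitMergerWindowSize
			hEnd := min(hStart+vitMergerWindowSize, llmGridH)
			wEnd := min(wStart+vitMergerWindowSize, llmGridW)

			// Merged cells in this window, then scaled to raw patches.
			totalSeqLen += (hEnd - hStart) * (wEnd - wStart)
			bounds = append(bounds, totalSeqLen*spatialMergeSize*spatialMergeSize+bounds[0])
		}
	}

	// Four windows of 16 raw patches each.
	fmt.Println(bounds) // [0 16 32 48 64]
}
```

These are the bounds handed to `blockDiagonalMask`, which zeroes four 16×16 blocks along the diagonal of the 64×64 attention mask so each patch attends only within its own window.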