block attention

This commit is contained in:
Bruce MacDonald 2025-04-30 16:43:37 -07:00
parent 104f802df1
commit ff1f74534b
2 changed files with 80 additions and 9 deletions

View File

@ -7,6 +7,18 @@ import (
"github.com/ollama/ollama/ml"
)
type AttentionOption func(*attentionOptions)
type attentionOptions struct {
mask ml.Tensor
}
func WithMask(mask ml.Tensor) AttentionOption {
return func(opts *attentionOptions) {
opts.mask = mask
}
}
// Attention implements scaled dot-product attention for transformer models:
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
@ -21,7 +33,12 @@ import (
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache, opts ...AttentionOption) ml.Tensor {
options := &attentionOptions{}
for _, opt := range opts {
opt(options)
}
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
@ -46,6 +63,9 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
if cache != nil {
key, value, mask = cache.Get(ctx)
}
if options.mask != nil {
mask = options.mask
}
// Only use the fast SDPA implementation if we have a cache, since that's what
// will do any expected backend-specific transformations for us

View File

@ -22,6 +22,37 @@ func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Te
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
}
func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor {
// Create a flat slice for the mask (all -inf initially to block all attention)
flat := make([]float32, seqLength*seqLength)
for i := range flat {
flat[i] = float32(math.Inf(-1)) // Negative infinity to block attention
}
// Fill in the mask with zeros for tokens that CAN attend to each other
for i := 1; i < len(bounds); i++ {
start := bounds[i-1]
end := bounds[i]
// Enable attention within this sequence block by setting values to 0
for row := start; row < end; row++ {
for col := start; col < end; col++ {
idx := row*seqLength + col
flat[idx] = 0.0 // 0 allows attention, -inf blocks it
}
}
}
mask, err := ctx.Input().FromFloatSlice(flat, seqLength, seqLength)
if err != nil {
panic(err)
}
// Reshape to match [seqLength, seqLength, 1] for broadcasting
mask = mask.Reshape(ctx, seqLength, seqLength, 1)
return mask
}
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
@ -29,7 +60,7 @@ type VisionSelfAttention struct {
Output *nn.Linear `gguf:"attn_out"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, bounds []int, opts *VisionModelOptions) ml.Tensor {
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
@ -43,8 +74,9 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
// Scale factor for scaled dot-product attention
scale := 1.0 / math.Sqrt(float64(opts.headDim))
mask := blockDiagonalMask(ctx, query.Dim(2), bounds, opts.numHeads)
attention := nn.Attention(ctx, query, key, value, scale, nil)
attention := nn.Attention(ctx, query, key, value, scale, nil, nn.WithMask(mask))
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
@ -73,10 +105,10 @@ type VisionEncoderLayer struct {
MLP *VisionMLP
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, cuSeqLens []int, opts *VisionModelOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, opts)
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, cuSeqLens, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
@ -171,7 +203,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
positionEmbedding := m.positionalEmbedding(ctx, grid)
windowIndex := m.windowIndex(ctx, grid)
windowIndex, bounds := m.windowIndex(ctx, grid)
spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
@ -190,13 +222,21 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
// Apply encoder layers
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, bounds, m.VisionModelOptions)
}
return m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
}
func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
// windowIndex divides the grid into windows and returns:
// 1. A tensor containing flattened indices of all grid points organized by windows
// 2. A slice of boundaries that mark where each window's data begins and ends
// in the flattened representation, scaled by spatialMergeSize squared
//
// The boundaries slice always starts with 0 and contains cumulative ending
// positions for each window, allowing downstream processing to identify
// window boundaries in the tensor data.
func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) {
vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
llmGridH := grid.Height / m.spatialMergeSize
@ -209,6 +249,10 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
// Initialize index_new slice
var index []int32
// Initialize bounds with the first element as 0
bounds := []int{0}
totalSeqLen := 0
// Process each window without padding
for wh := range numWindowsH {
for ww := range numWindowsW {
@ -218,12 +262,19 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
hEnd := min(hStart+vitMergerWindowSize, llmGridH)
wEnd := min(wStart+vitMergerWindowSize, llmGridW)
// Calculate sequence length for this window
seqLen := (hEnd - hStart) * (wEnd - wStart)
// Collect indices for this window
for h := hStart; h < hEnd; h++ {
for w := wStart; w < wEnd; w++ {
index = append(index, int32(h*llmGridW+w))
}
}
// Update total sequence length and append to cuWindowSeqlens
totalSeqLen += seqLen
bounds = append(bounds, totalSeqLen*(m.spatialMergeSize*m.spatialMergeSize)+bounds[0])
}
}
@ -232,7 +283,7 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
panic(err)
}
return t
return t, bounds
}
// positionalEmbedding generates rotary position embeddings for attention mechanisms