block attention
This commit is contained in:
parent
104f802df1
commit
ff1f74534b
@ -7,6 +7,18 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
type AttentionOption func(*attentionOptions)
|
||||
|
||||
type attentionOptions struct {
|
||||
mask ml.Tensor
|
||||
}
|
||||
|
||||
func WithMask(mask ml.Tensor) AttentionOption {
|
||||
return func(opts *attentionOptions) {
|
||||
opts.mask = mask
|
||||
}
|
||||
}
|
||||
|
||||
// Attention implements scaled dot-product attention for transformer models:
|
||||
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
|
||||
//
|
||||
@ -21,7 +33,12 @@ import (
|
||||
// Returns:
|
||||
//
|
||||
// Attention output with shape [d_v, heads, seq_len_q]
|
||||
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache, opts ...AttentionOption) ml.Tensor {
|
||||
options := &attentionOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
if key != nil && value != nil {
|
||||
if query.Dim(0) != key.Dim(0) {
|
||||
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
||||
@ -46,6 +63,9 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
|
||||
if cache != nil {
|
||||
key, value, mask = cache.Get(ctx)
|
||||
}
|
||||
if options.mask != nil {
|
||||
mask = options.mask
|
||||
}
|
||||
|
||||
// Only use the fast SDPA implementation if we have a cache, since that's what
|
||||
// will do any expected backend-specific transformations for us
|
||||
|
@ -22,6 +22,37 @@ func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Te
|
||||
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
|
||||
}
|
||||
|
||||
func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor {
|
||||
// Create a flat slice for the mask (all -inf initially to block all attention)
|
||||
flat := make([]float32, seqLength*seqLength)
|
||||
for i := range flat {
|
||||
flat[i] = float32(math.Inf(-1)) // Negative infinity to block attention
|
||||
}
|
||||
|
||||
// Fill in the mask with zeros for tokens that CAN attend to each other
|
||||
for i := 1; i < len(bounds); i++ {
|
||||
start := bounds[i-1]
|
||||
end := bounds[i]
|
||||
|
||||
// Enable attention within this sequence block by setting values to 0
|
||||
for row := start; row < end; row++ {
|
||||
for col := start; col < end; col++ {
|
||||
idx := row*seqLength + col
|
||||
flat[idx] = 0.0 // 0 allows attention, -inf blocks it
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mask, err := ctx.Input().FromFloatSlice(flat, seqLength, seqLength)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// Reshape to match [seqLength, seqLength, 1] for broadcasting
|
||||
mask = mask.Reshape(ctx, seqLength, seqLength, 1)
|
||||
|
||||
return mask
|
||||
}
|
||||
|
||||
type VisionSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
@ -29,7 +60,7 @@ type VisionSelfAttention struct {
|
||||
Output *nn.Linear `gguf:"attn_out"`
|
||||
}
|
||||
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, bounds []int, opts *VisionModelOptions) ml.Tensor {
|
||||
query := sa.Query.Forward(ctx, hiddenStates)
|
||||
key := sa.Key.Forward(ctx, hiddenStates)
|
||||
value := sa.Value.Forward(ctx, hiddenStates)
|
||||
@ -43,8 +74,9 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
|
||||
|
||||
// Scale factor for scaled dot-product attention
|
||||
scale := 1.0 / math.Sqrt(float64(opts.headDim))
|
||||
mask := blockDiagonalMask(ctx, query.Dim(2), bounds, opts.numHeads)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, scale, nil)
|
||||
attention := nn.Attention(ctx, query, key, value, scale, nil, nn.WithMask(mask))
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
@ -73,10 +105,10 @@ type VisionEncoderLayer struct {
|
||||
MLP *VisionMLP
|
||||
}
|
||||
|
||||
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, cuSeqLens []int, opts *VisionModelOptions) ml.Tensor {
|
||||
residual := hiddenStates
|
||||
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, opts)
|
||||
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, cuSeqLens, opts)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
residual = hiddenStates
|
||||
@ -171,7 +203,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
||||
|
||||
positionEmbedding := m.positionalEmbedding(ctx, grid)
|
||||
|
||||
windowIndex := m.windowIndex(ctx, grid)
|
||||
windowIndex, bounds := m.windowIndex(ctx, grid)
|
||||
|
||||
spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
|
||||
|
||||
@ -190,13 +222,21 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
||||
|
||||
// Apply encoder layers
|
||||
for _, layer := range m.Layers {
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, bounds, m.VisionModelOptions)
|
||||
}
|
||||
|
||||
return m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
|
||||
}
|
||||
|
||||
func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
|
||||
// windowIndex divides the grid into windows and returns:
|
||||
// 1. A tensor containing flattened indices of all grid points organized by windows
|
||||
// 2. A slice of boundaries that mark where each window's data begins and ends
|
||||
// in the flattened representation, scaled by spatialMergeSize squared
|
||||
//
|
||||
// The boundaries slice always starts with 0 and contains cumulative ending
|
||||
// positions for each window, allowing downstream processing to identify
|
||||
// window boundaries in the tensor data.
|
||||
func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) {
|
||||
vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
|
||||
|
||||
llmGridH := grid.Height / m.spatialMergeSize
|
||||
@ -209,6 +249,10 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
|
||||
// Initialize index_new slice
|
||||
var index []int32
|
||||
|
||||
// Initialize bounds with the first element as 0
|
||||
bounds := []int{0}
|
||||
totalSeqLen := 0
|
||||
|
||||
// Process each window without padding
|
||||
for wh := range numWindowsH {
|
||||
for ww := range numWindowsW {
|
||||
@ -218,12 +262,19 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
|
||||
hEnd := min(hStart+vitMergerWindowSize, llmGridH)
|
||||
wEnd := min(wStart+vitMergerWindowSize, llmGridW)
|
||||
|
||||
// Calculate sequence length for this window
|
||||
seqLen := (hEnd - hStart) * (wEnd - wStart)
|
||||
|
||||
// Collect indices for this window
|
||||
for h := hStart; h < hEnd; h++ {
|
||||
for w := wStart; w < wEnd; w++ {
|
||||
index = append(index, int32(h*llmGridW+w))
|
||||
}
|
||||
}
|
||||
|
||||
// Update total sequence length and append to cuWindowSeqlens
|
||||
totalSeqLen += seqLen
|
||||
bounds = append(bounds, totalSeqLen*(m.spatialMergeSize*m.spatialMergeSize)+bounds[0])
|
||||
}
|
||||
}
|
||||
|
||||
@ -232,7 +283,7 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return t
|
||||
return t, bounds
|
||||
}
|
||||
|
||||
// positionalEmbedding generates rotary position embeddings for attention mechanisms
|
||||
|
Loading…
x
Reference in New Issue
Block a user