block attention
parent 104f802df1
commit ff1f74534b
@@ -7,6 +7,18 @@ import (
     "github.com/ollama/ollama/ml"
 )
 
+type AttentionOption func(*attentionOptions)
+
+type attentionOptions struct {
+    mask ml.Tensor
+}
+
+func WithMask(mask ml.Tensor) AttentionOption {
+    return func(opts *attentionOptions) {
+        opts.mask = mask
+    }
+}
+
 // Attention implements scaled dot-product attention for transformer models:
 // Attention(Q, K, V) = softmax(QK^T/√d_k)V
 //
@@ -21,7 +33,12 @@ import (
 // Returns:
 //
 //    Attention output with shape [d_v, heads, seq_len_q]
-func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
+func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache, opts ...AttentionOption) ml.Tensor {
+    options := &attentionOptions{}
+    for _, opt := range opts {
+        opt(options)
+    }
+
     if key != nil && value != nil {
         if query.Dim(0) != key.Dim(0) {
             panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
@@ -46,6 +63,9 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
     if cache != nil {
         key, value, mask = cache.Get(ctx)
     }
+    if options.mask != nil {
+        mask = options.mask
+    }
 
     // Only use the fast SDPA implementation if we have a cache, since that's what
     // will do any expected backend-specific transformations for us
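The three hunks above extend nn.Attention with a variadic opts parameter (Go's functional-options pattern) so a caller can supply an explicit mask that overrides whatever the KV cache returns; passing no options keeps the previous behaviour. A minimal, self-contained sketch of the same pattern, using illustrative names (options, Option, attention) rather than identifiers from this commit:

package main

import "fmt"

// Illustrative stand-ins for attentionOptions / AttentionOption / WithMask.
type options struct{ mask string }

type Option func(*options)

func WithMask(mask string) Option {
    return func(o *options) { o.mask = mask }
}

// attention mirrors the shape of the updated nn.Attention signature:
// defaults apply unless an Option overrides them.
func attention(opts ...Option) string {
    o := &options{mask: "mask from cache"}
    for _, opt := range opts {
        opt(o)
    }
    return "using " + o.mask
}

func main() {
    fmt.Println(attention())                                // using mask from cache
    fmt.Println(attention(WithMask("block-diagonal mask"))) // using block-diagonal mask
}

The remaining hunks are on the vision-model side, where the new option is actually used.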
@@ -22,6 +22,37 @@ func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
     return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
 }
 
+func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor {
+    // Create a flat slice for the mask (all -inf initially to block all attention)
+    flat := make([]float32, seqLength*seqLength)
+    for i := range flat {
+        flat[i] = float32(math.Inf(-1)) // Negative infinity to block attention
+    }
+
+    // Fill in the mask with zeros for tokens that CAN attend to each other
+    for i := 1; i < len(bounds); i++ {
+        start := bounds[i-1]
+        end := bounds[i]
+
+        // Enable attention within this sequence block by setting values to 0
+        for row := start; row < end; row++ {
+            for col := start; col < end; col++ {
+                idx := row*seqLength + col
+                flat[idx] = 0.0 // 0 allows attention, -inf blocks it
+            }
+        }
+    }
+
+    mask, err := ctx.Input().FromFloatSlice(flat, seqLength, seqLength)
+    if err != nil {
+        panic(err)
+    }
+    // Reshape to match [seqLength, seqLength, 1] for broadcasting
+    mask = mask.Reshape(ctx, seqLength, seqLength, 1)
+
+    return mask
+}
+
 type VisionSelfAttention struct {
     Query  *nn.Linear `gguf:"attn_q"`
     Key    *nn.Linear `gguf:"attn_k"`
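blockDiagonalMask builds an additive attention mask: entries are -Inf by default (attention blocked) and set to 0 inside each [bounds[i-1], bounds[i]) block, so tokens attend only within their own window. As a rough CPU-only illustration of the resulting layout, with made-up seqLength and bounds values and no ml.Context involved:

package main

import (
    "fmt"
    "math"
)

// buildBlockMask mirrors the flat-slice construction above, minus the
// tensor upload and reshape: -Inf blocks attention, 0 allows it.
func buildBlockMask(seqLength int, bounds []int) []float32 {
    flat := make([]float32, seqLength*seqLength)
    for i := range flat {
        flat[i] = float32(math.Inf(-1))
    }
    for i := 1; i < len(bounds); i++ {
        start, end := bounds[i-1], bounds[i]
        for row := start; row < end; row++ {
            for col := start; col < end; col++ {
                flat[row*seqLength+col] = 0
            }
        }
    }
    return flat
}

func main() {
    // Two hypothetical windows: tokens 0-1 and tokens 2-4.
    flat := buildBlockMask(5, []int{0, 2, 5})
    for row := 0; row < 5; row++ {
        fmt.Println(flat[row*5 : (row+1)*5])
    }
    // Rows 0-1 see only columns 0-1; rows 2-4 see only columns 2-4.
}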
@@ -29,7 +60,7 @@ type VisionSelfAttention struct {
     Output *nn.Linear `gguf:"attn_out"`
 }
 
-func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, bounds []int, opts *VisionModelOptions) ml.Tensor {
     query := sa.Query.Forward(ctx, hiddenStates)
     key := sa.Key.Forward(ctx, hiddenStates)
     value := sa.Value.Forward(ctx, hiddenStates)
@@ -43,8 +74,9 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 
     // Scale factor for scaled dot-product attention
     scale := 1.0 / math.Sqrt(float64(opts.headDim))
+    mask := blockDiagonalMask(ctx, query.Dim(2), bounds, opts.numHeads)
 
-    attention := nn.Attention(ctx, query, key, value, scale, nil)
+    attention := nn.Attention(ctx, query, key, value, scale, nil, nn.WithMask(mask))
     attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
 
     return sa.Output.Forward(ctx, attention)
@@ -73,10 +105,10 @@ type VisionEncoderLayer struct {
     MLP           *VisionMLP
 }
 
-func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, cuSeqLens []int, opts *VisionModelOptions) ml.Tensor {
     residual := hiddenStates
     hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
-    hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, opts)
+    hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, cuSeqLens, opts)
     hiddenStates = hiddenStates.Add(ctx, residual)
 
     residual = hiddenStates
@@ -171,7 +203,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
 
     positionEmbedding := m.positionalEmbedding(ctx, grid)
 
-    windowIndex := m.windowIndex(ctx, grid)
+    windowIndex, bounds := m.windowIndex(ctx, grid)
 
     spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
 
@@ -190,13 +222,21 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
 
     // Apply encoder layers
     for _, layer := range m.Layers {
-        hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
+        hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, bounds, m.VisionModelOptions)
     }
 
     return m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
 }
 
-func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
+// windowIndex divides the grid into windows and returns:
+// 1. A tensor containing flattened indices of all grid points organized by windows
+// 2. A slice of boundaries that mark where each window's data begins and ends
+//    in the flattened representation, scaled by spatialMergeSize squared
+//
+// The boundaries slice always starts with 0 and contains cumulative ending
+// positions for each window, allowing downstream processing to identify
+// window boundaries in the tensor data.
+func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) {
     vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
 
     llmGridH := grid.Height / m.spatialMergeSize
@@ -209,6 +249,10 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
     // Initialize index_new slice
     var index []int32
 
+    // Initialize bounds with the first element as 0
+    bounds := []int{0}
+    totalSeqLen := 0
+
     // Process each window without padding
     for wh := range numWindowsH {
         for ww := range numWindowsW {
@@ -218,12 +262,19 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
             hEnd := min(hStart+vitMergerWindowSize, llmGridH)
             wEnd := min(wStart+vitMergerWindowSize, llmGridW)
 
+            // Calculate sequence length for this window
+            seqLen := (hEnd - hStart) * (wEnd - wStart)
+
             // Collect indices for this window
             for h := hStart; h < hEnd; h++ {
                 for w := wStart; w < wEnd; w++ {
                     index = append(index, int32(h*llmGridW+w))
                 }
             }
+
+            // Update total sequence length and append to cuWindowSeqlens
+            totalSeqLen += seqLen
+            bounds = append(bounds, totalSeqLen*(m.spatialMergeSize*m.spatialMergeSize)+bounds[0])
         }
     }
 
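The bounds slice thus holds cumulative window end positions: each window contributes (hEnd-hStart)*(wEnd-wStart) merged positions, scaled by spatialMergeSize squared as described in the windowIndex doc comment. A standalone sketch that reproduces the accumulation for a hypothetical 4x4 merged grid split into 2x2 windows (all sizes invented for illustration; the built-in min requires Go 1.21+):

package main

import "fmt"

func main() {
    const (
        llmGridH, llmGridW  = 4, 4 // hypothetical merged-grid dimensions
        vitMergerWindowSize = 2
        spatialMergeSize    = 2
    )

    bounds := []int{0}
    totalSeqLen := 0
    for hStart := 0; hStart < llmGridH; hStart += vitMergerWindowSize {
        for wStart := 0; wStart < llmGridW; wStart += vitMergerWindowSize {
            hEnd := min(hStart+vitMergerWindowSize, llmGridH)
            wEnd := min(wStart+vitMergerWindowSize, llmGridW)
            seqLen := (hEnd - hStart) * (wEnd - wStart)

            totalSeqLen += seqLen
            bounds = append(bounds, totalSeqLen*(spatialMergeSize*spatialMergeSize)+bounds[0])
        }
    }

    // Four 2x2 windows of 4 merged positions each, times spatialMergeSize^2 = 4:
    fmt.Println(bounds) // [0 16 32 48 64]
}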
@@ -232,7 +283,7 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
         panic(err)
     }
 
-    return t
+    return t, bounds
 }
 
 // positionalEmbedding generates rotary position embeddings for attention mechanisms