ml: Abstract attention out of model definitions
There are two benefits to doing this:

- Provides a library function that models can use, reducing code for each model implementation.
- Enables a single place to drop in optimized implementations of attention based on the backend or other factors. One such implementation is provided for GGML.

On CUDA this improves the token generation rate by about 3%. It does not have a significant effect on Metal.

Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
parent 2192a28eed
commit f53f4198c3
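For orientation before the diffs: after this change a model's attention block reduces to a single call into the new nn.Attention helper. A minimal sketch of the intended call pattern, assuming hypothetical tensor names (q, k, v, mask) already shaped as the helper expects; the real call sites appear in the model diffs below:

    import (
        "math"

        "github.com/ollama/ollama/ml"
        "github.com/ollama/ollama/ml/nn"
    )

    // Illustrative only: inside a model's attention forward pass.
    // q: [headDim, seqLen, numHeads], k: [headDim, kvSeqLen, kvHeads],
    // v: [kvSeqLen, headDim, kvHeads], mask: optional additive mask (may be nil).
    func attend(ctx ml.Context, q, k, v, mask ml.Tensor, headDim int) ml.Tensor {
        scaleFactor := 1.0 / math.Sqrt(float64(headDim))
        // nn.Attention picks the fused backend path when available,
        // otherwise it falls back to the generic Mulmat/Scale/Add/Softmax sequence.
        return nn.Attention(ctx, q, k, v, mask, scaleFactor)
    }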
@@ -111,6 +111,26 @@ type Tensor interface {
 	Copy(ctx Context, t2 Tensor) Tensor
 }
 
+// ScaledDotProductAttention implements a fused attention
+// operation equivalent to following code on a tensor named
+// query:
+//
+// kq := key.MulmatFullPrec(ctx, query)
+//
+// kq = kq.Scale(ctx, scale)
+//
+// if mask != nil {
+//	kq = kq.Add(ctx, mask)
+// }
+//
+// kq = kq.Softmax(ctx)
+//
+// kqv := value.Mulmat(ctx, kq)
+// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+type ScaledDotProductAttention interface {
+	ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
+}
+
 type number interface {
 	~int | ~int8 | ~int16 | ~int32 | ~int64 |
 	~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
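This interface is an optional capability in the usual Go style: a backend tensor type opts in by implementing the method, and generic code probes for it with a type assertion, falling back to the unfused sequence otherwise. A minimal sketch of that probe (a fragment that assumes ctx, query, key, value, mask, and scale are in scope; the real check lives in ml/nn/attention.go, shown below):

    // Hypothetical caller-side probe; mirrors what nn.Attention does.
    if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
        // Backend provides a fused kernel (e.g. the GGML implementation in the next hunk).
        return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
    }
    // Otherwise: fall back to MulmatFullPrec → Scale → Add(mask) → Softmax → Mulmat.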
@@ -651,6 +651,21 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
 	}
 }
 
+func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
+	var kqMask *C.struct_ggml_tensor
+	if mask != nil {
+		kqMask = mask.(*Tensor).t
+	}
+
+	kq := key.MulmatFullPrec(ctx, t)
+	kq = &Tensor{
+		t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
+	}
+
+	kqv := value.Mulmat(ctx, kq)
+	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+}
+
 func (b *Backend) SystemInfo() string {
 	var compiler string
 	switch C.get_compiler() {
ml/nn/attention.go (new file, 59 lines)
@@ -0,0 +1,59 @@
+package nn
+
+import (
+	"fmt"
+
+	"github.com/ollama/ollama/ml"
+)
+
+// Attention implements scaled dot-product attention for transformer models:
+// Attention(Q, K, V) = softmax(QK^T/√d_k)V
+//
+// Parameters:
+//   - ctx: Context for tensor operations
+//   - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
+//   - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
+//   - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
+//   - mask: Optional attention mask that is added to the attention score. If
+//     provided, should broadcast to [seq_len_k, seq_len_q, heads]
+//   - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
+//
+// Returns:
+//
+//	Attention output with shape [d_v, heads, seq_len_q]
+func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
+	if query.Dim(0) != key.Dim(0) {
+		panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
+	}
+
+	if mask != nil && query.Dim(1) != mask.Dim(1) {
+		panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
+	}
+
+	if key.Dim(1) != value.Dim(0) {
+		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
+	}
+
+	if mask != nil && key.Dim(1) != mask.Dim(0) {
+		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
+	}
+
+	if key.Dim(2) != value.Dim(2) {
+		panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
+	}
+
+	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
+		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
+	} else {
+		kq := key.MulmatFullPrec(ctx, query)
+
+		kq = kq.Scale(ctx, scale)
+		if mask != nil {
+			kq = kq.Add(ctx, mask)
+		}
+		kq = kq.Softmax(ctx)
+
+		kqv := value.Mulmat(ctx, kq)
+		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	}
+}
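To make the documented shape convention concrete, a small usage sketch (the dimensions are invented for illustration and assume math, ml, and nn are imported; grouped-query attention simply uses fewer kv_heads than heads):

    // Illustrative shapes only: headDim (d_k = d_v) = 64, seq_len_q = 1 (single decode step),
    // seq_len_k = 512 cached positions, heads = 32, kv_heads = 8.
    //   query: [64, 1, 32]
    //   key:   [64, 512, 8]
    //   value: [512, 64, 8]
    //   mask:  broadcastable to [512, 1, 32], or nil
    scale := 1.0 / math.Sqrt(64.0)
    out := nn.Attention(ctx, query, key, value, mask, scale)
    // out has shape [64, 32, 1] and is typically reshaped to
    // [hiddenSize, batchSize] before the output projection, as in the model diffs below.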
@@ -86,13 +86,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
-	kq := k.MulmatFullPrec(ctx, q)
-	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	kq = kq.Add(ctx, mask)
-	kq = kq.Softmax(ctx)
-
-	kqv := v.Mulmat(ctx, kq)
-	kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
+	kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
 	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return sa.Output.Forward(ctx, kqv)
@@ -38,13 +38,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
-	scores := key.MulmatFullPrec(ctx, query)
-	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	scores = scores.Add(ctx, mask)
-	scores = scores.Softmax(ctx)
-
-	attention := value.Mulmat(ctx, scores)
-	attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
+	attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
 	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return sa.Output.Forward(ctx, attention)
@@ -112,7 +107,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
 	query = ca.QueryNorm.Forward(ctx, query, opts.eps)
 
-	var key, value ml.Tensor
+	var key, value, mask ml.Tensor
 	if crossAttentionStates != nil {
 		numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
 
@@ -125,19 +120,15 @@
 
 		cache.Put(ctx, key, value)
 	} else {
-		key, value, _ = cache.Get(ctx)
+		key, value, mask = cache.Get(ctx)
 	}
 
 	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
-	scores := key.Mulmat(ctx, query)
-	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	scores = scores.Softmax(ctx)
-
-	attention := value.Mulmat(ctx, scores)
-	attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
+	attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
 	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return ca.Output.Forward(ctx, attention)