move sdpa to model forward pass

This commit is contained in:
Bruce MacDonald 2025-05-01 11:51:32 -07:00
parent ca981c8a49
commit b68af0370f
2 changed files with 13 additions and 22 deletions

View File

@ -7,18 +7,6 @@ import (
"github.com/ollama/ollama/ml"
)
type AttentionOption func(*attentionOptions)
type attentionOptions struct {
mask ml.Tensor
}
func WithMask(mask ml.Tensor) AttentionOption {
return func(opts *attentionOptions) {
opts.mask = mask
}
}
// Attention implements scaled dot-product attention for transformer models:
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
@ -33,12 +21,7 @@ func WithMask(mask ml.Tensor) AttentionOption {
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache, opts ...AttentionOption) ml.Tensor {
options := &attentionOptions{}
for _, opt := range opts {
opt(options)
}
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
@ -63,9 +46,6 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
if cache != nil {
key, value, mask = cache.Get(ctx)
}
if options.mask != nil {
mask = options.mask
}
// Only use the fast SDPA implementation if we have a cache, since that's what
// will do any expected backend-specific transformations for us

View File

@ -80,7 +80,18 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
mask = blockDiagonalMask(ctx, query.Dim(2), bounds, opts.numHeads)
}
attention := nn.Attention(ctx, query, key, value, scale, nil, nn.WithMask(mask))
// Scaled dot-product attention
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scale)
if mask != nil {
kq = kq.Add(ctx, mask)
}
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)