move sdpa to model forward pass

parent ca981c8a49
commit b68af0370f
@@ -7,18 +7,6 @@ import (
 	"github.com/ollama/ollama/ml"
 )
 
-type AttentionOption func(*attentionOptions)
-
-type attentionOptions struct {
-	mask ml.Tensor
-}
-
-func WithMask(mask ml.Tensor) AttentionOption {
-	return func(opts *attentionOptions) {
-		opts.mask = mask
-	}
-}
-
 // Attention implements scaled dot-product attention for transformer models:
 // Attention(Q, K, V) = softmax(QK^T/√d_k)V
 //
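For reference, the formula in the doc comment above, Attention(Q, K, V) = softmax(QK^T/√d_k)V, can be spelled out on plain slices. A minimal single-head sketch, illustrative only and not the library's ml.Tensor implementation:

package main

import (
	"fmt"
	"math"
)

// attention computes softmax(QK^T/√d_k)V for a single head.
// q is [n][dk], k is [m][dk], v is [m][dv]; the result is [n][dv].
func attention(q, k, v [][]float64, dk int) [][]float64 {
	scale := 1 / math.Sqrt(float64(dk))
	out := make([][]float64, len(q))
	for i := range q {
		// scores[j] = (q_i · k_j) / √d_k
		scores := make([]float64, len(k))
		maxScore := math.Inf(-1)
		for j := range k {
			for t := 0; t < dk; t++ {
				scores[j] += q[i][t] * k[j][t]
			}
			scores[j] *= scale
			maxScore = math.Max(maxScore, scores[j])
		}
		// softmax over the scores (max-subtracted for numerical stability)
		var sum float64
		for j := range scores {
			scores[j] = math.Exp(scores[j] - maxScore)
			sum += scores[j]
		}
		// out_i = Σ_j softmax(scores)_j · v_j
		out[i] = make([]float64, len(v[0]))
		for j := range v {
			for t := range v[j] {
				out[i][t] += scores[j] / sum * v[j][t]
			}
		}
	}
	return out
}

func main() {
	q := [][]float64{{1, 0}}
	k := [][]float64{{1, 0}, {0, 1}}
	v := [][]float64{{1, 2}, {3, 4}}
	fmt.Println(attention(q, k, v, 2)) // a weighted mix of the two value rows
}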
@@ -33,12 +21,7 @@ func WithMask(mask ml.Tensor) AttentionOption {
 // Returns:
 //
 //	Attention output with shape [d_v, heads, seq_len_q]
-func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache, opts ...AttentionOption) ml.Tensor {
-	options := &attentionOptions{}
-	for _, opt := range opts {
-		opt(options)
-	}
-
+func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
 	if key != nil && value != nil {
 		if query.Dim(0) != key.Dim(0) {
 			panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
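At call sites, the signature change looks roughly like this (a hypothetical caller; the names mirror this diff):

// Before this commit, a cache-less caller could inject a mask via the
// now-removed option:
//	attn := nn.Attention(ctx, query, key, value, scale, nil, nn.WithMask(mask))
// After it, the variadic options are gone; cache-backed callers are unchanged,
// and cache-less callers inline SDPA themselves (see the last hunk below):
attn := nn.Attention(ctx, query, key, value, scale, cache)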
@@ -63,9 +46,6 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
 	if cache != nil {
 		key, value, mask = cache.Get(ctx)
 	}
-	if options.mask != nil {
-		mask = options.mask
-	}
 
 	// Only use the fast SDPA implementation if we have a cache, since that's what
 	// will do any expected backend-specific transformations for us
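For context on the mask that cache.Get returns: as the last hunk shows (kq = kq.Add(ctx, mask)), it is additive, applied to the raw scores before the softmax, so a disallowed position carries -Inf and ends up with zero weight. A small self-contained illustration of that semantics (values are made up):

package main

import (
	"fmt"
	"math"
)

func main() {
	// Scores for one query over three key positions; the additive mask
	// disables the last position by pushing its score to -Inf.
	scores := []float64{0.5, 1.2, 0.0}
	mask := []float64{0, 0, math.Inf(-1)}
	weights := make([]float64, len(scores))
	var sum float64
	for j := range scores {
		weights[j] = math.Exp(scores[j] + mask[j]) // exp(-Inf) == 0
		sum += weights[j]
	}
	for j := range weights {
		weights[j] /= sum
	}
	fmt.Println(weights) // the masked position gets weight 0
}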
@@ -80,7 +80,18 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
 		mask = blockDiagonalMask(ctx, query.Dim(2), bounds, opts.numHeads)
 	}
 
-	attention := nn.Attention(ctx, query, key, value, scale, nil, nn.WithMask(mask))
+	// Scaled dot-product attention
+	query = query.Permute(ctx, 0, 2, 1, 3)
+	key = key.Permute(ctx, 0, 2, 1, 3)
+	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+	kq := key.MulmatFullPrec(ctx, query)
+	kq = kq.Scale(ctx, scale)
+	if mask != nil {
+		kq = kq.Add(ctx, mask)
+	}
+	kq = kq.Softmax(ctx)
+	kqv := value.Mulmat(ctx, kq)
+	attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
 
 	return sa.Output.Forward(ctx, attention)
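Mapping the inlined ops back to the formula in nn.Attention's doc comment (an assumed reading of this hunk; the permutes only rearrange the head and sequence axes so the matmuls line up):

kq := key.MulmatFullPrec(ctx, query) // QK^T, computed in full precision
kq = kq.Scale(ctx, scale)            // QK^T/√d_k (scale is assumed to be 1/√d_k)
kq = kq.Add(ctx, mask)               // additive mask on the raw scores (skipped above when mask == nil)
kq = kq.Softmax(ctx)                 // softmax(QK^T/√d_k)
kqv := value.Mulmat(ctx, kq)         // softmax(QK^T/√d_k)·V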