move sdpa to model forward pass

commit b68af0370f (parent ca981c8a49)
@@ -7,18 +7,6 @@ import (
 	"github.com/ollama/ollama/ml"
 )
 
-type AttentionOption func(*attentionOptions)
-
-type attentionOptions struct {
-	mask ml.Tensor
-}
-
-func WithMask(mask ml.Tensor) AttentionOption {
-	return func(opts *attentionOptions) {
-		opts.mask = mask
-	}
-}
-
 // Attention implements scaled dot-product attention for transformer models:
 // Attention(Q, K, V) = softmax(QK^T/√d_k)V
 //
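The deleted block is the standard Go functional-options pattern: WithMask returns a closure that mutates a private options struct, and Attention folds the variadic options over zero-value defaults. For reference, a minimal, self-contained sketch of the same pattern (a plain string stands in for ml.Tensor; this is illustrative, not the ollama code):

package main

import "fmt"

// AttentionOption mirrors the deleted type: an option is a function that
// mutates the private options struct.
type AttentionOption func(*attentionOptions)

type attentionOptions struct {
	mask string // stand-in for ml.Tensor
}

// WithMask returns a closure capturing the caller's mask.
func WithMask(mask string) AttentionOption {
	return func(opts *attentionOptions) { opts.mask = mask }
}

// Attention applies caller-supplied options over defaults before doing work.
func Attention(opts ...AttentionOption) {
	options := &attentionOptions{}
	for _, opt := range opts {
		opt(options)
	}
	fmt.Println("mask:", options.mask)
}

func main() {
	Attention(WithMask("block-diagonal")) // prints: mask: block-diagonal
}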
@@ -33,12 +21,7 @@ func WithMask(mask ml.Tensor) AttentionOption {
 // Returns:
 //
 //	Attention output with shape [d_v, heads, seq_len_q]
-func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache, opts ...AttentionOption) ml.Tensor {
-	options := &attentionOptions{}
-	for _, opt := range opts {
-		opt(options)
-	}
-
+func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
 	if key != nil && value != nil {
 		if query.Dim(0) != key.Dim(0) {
 			panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
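With the variadic options gone, the exported signature is fixed and a mask can only arrive via the cache; callers that need a custom mask (like the vision model below) now implement SDPA themselves. A hedged sketch of a post-change call site (layer and variable names are assumed; only the Attention signature and the d_k check come from the diff):

// Inside some text-model layer's Forward (names assumed for illustration).
// query, key, value are ml.Tensor; dimension 0 of query and key must both
// be d_k, or Attention panics as shown above.
scale := 1.0 / math.Sqrt(float64(headDim))
hiddenStates := nn.Attention(ctx, query, key, value, scale, cache)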
@@ -63,9 +46,6 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
 	if cache != nil {
 		key, value, mask = cache.Get(ctx)
 	}
-	if options.mask != nil {
-		mask = options.mask
-	}
 
 	// Only use the fast SDPA implementation if we have a cache, since that's what
 	// will do any expected backend-specific transformations for us
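The surviving comment states the dispatch rule: the backend's fused SDPA path is only used when a cache supplied key/value, because cache.Get is what performs any backend-specific layout transformations the fused kernel expects. Sketched as control flow (an assumption about the surrounding function body, which this hunk does not show):

if cache != nil {
	key, value, mask = cache.Get(ctx)
	// fast path: backend-fused SDPA is safe here, since the cache has
	// already applied the expected backend-specific transformations
} else {
	// slow path: explicit permute/matmul/softmax, which is exactly what
	// the vision forward pass below now does inline
}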
@@ -80,7 +80,18 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
 		mask = blockDiagonalMask(ctx, query.Dim(2), bounds, opts.numHeads)
 	}
 
-	attention := nn.Attention(ctx, query, key, value, scale, nil, nn.WithMask(mask))
+	// Scaled dot-product attention
+	query = query.Permute(ctx, 0, 2, 1, 3)
+	key = key.Permute(ctx, 0, 2, 1, 3)
+	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+	kq := key.MulmatFullPrec(ctx, query)
+	kq = kq.Scale(ctx, scale)
+	if mask != nil {
+		kq = kq.Add(ctx, mask)
+	}
+	kq = kq.Softmax(ctx)
+	kqv := value.Mulmat(ctx, kq)
+	attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
 
 	return sa.Output.Forward(ctx, attention)
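These twelve added lines are a literal inlining of softmax(QK^T/√d_k)V: permute query and key so the head axis is outermost, multiply key against query at full precision, scale by 1/√d_k, add the additive mask before the softmax (blockDiagonalMask presumably confines attention within each image's block of patches), softmax over the key axis, then take the attention-weighted sum of the values. A self-contained single-head reference in plain Go, useful for checking the arithmetic (illustrative only; it is not the tensor code above):

package main

import (
	"fmt"
	"math"
)

// sdpa computes softmax(Q Kᵀ / √d_k + mask) V for one attention head.
// q is [seqQ][dk], k is [seqK][dk], v is [seqK][dv]; mask is [seqQ][seqK]
// additive logits (math.Inf(-1) blocks a position) or nil.
func sdpa(q, k, v, mask [][]float64) [][]float64 {
	dk := float64(len(q[0]))
	out := make([][]float64, len(q))
	for i := range q {
		// Scaled dot products of query i against every key, plus the mask.
		logits := make([]float64, len(k))
		for j := range k {
			for d := range q[i] {
				logits[j] += q[i][d] * k[j][d]
			}
			logits[j] /= math.Sqrt(dk)
			if mask != nil {
				logits[j] += mask[i][j]
			}
		}
		// Numerically stable softmax over the key axis.
		maxLogit := math.Inf(-1)
		for _, l := range logits {
			maxLogit = math.Max(maxLogit, l)
		}
		var sum float64
		for j, l := range logits {
			logits[j] = math.Exp(l - maxLogit)
			sum += logits[j]
		}
		// Attention-weighted sum of the values.
		out[i] = make([]float64, len(v[0]))
		for j := range k {
			w := logits[j] / sum
			for d := range v[j] {
				out[i][d] += w * v[j][d]
			}
		}
	}
	return out
}

func main() {
	q := [][]float64{{1, 0}}
	k := [][]float64{{1, 0}, {0, 1}}
	v := [][]float64{{1, 0}, {0, 1}}
	fmt.Println(sdpa(q, k, v, nil)) // output weights favor the first key
}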