diff --git a/ml/nn/attention.go b/ml/nn/attention.go
index e33ad08dc..a3f43a1ea 100644
--- a/ml/nn/attention.go
+++ b/ml/nn/attention.go
@@ -7,18 +7,6 @@ import (
 	"github.com/ollama/ollama/ml"
 )
 
-type AttentionOption func(*attentionOptions)
-
-type attentionOptions struct {
-	mask ml.Tensor
-}
-
-func WithMask(mask ml.Tensor) AttentionOption {
-	return func(opts *attentionOptions) {
-		opts.mask = mask
-	}
-}
-
 // Attention implements scaled dot-product attention for transformer models:
 // Attention(Q, K, V) = softmax(QK^T/√d_k)V
 //
@@ -33,12 +21,7 @@ func WithMask(mask ml.Tensor) AttentionOption {
 // Returns:
 //
 //	Attention output with shape [d_v, heads, seq_len_q]
-func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache, opts ...AttentionOption) ml.Tensor {
-	options := &attentionOptions{}
-	for _, opt := range opts {
-		opt(options)
-	}
-
+func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
 	if key != nil && value != nil {
 		if query.Dim(0) != key.Dim(0) {
 			panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
@@ -63,9 +46,6 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
 	if cache != nil {
 		key, value, mask = cache.Get(ctx)
 	}
-	if options.mask != nil {
-		mask = options.mask
-	}
 
 	// Only use the fast SDPA implementation if we have a cache, since that's what
 	// will do any expected backend-specific transformations for us
diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go
index 87a1f5a27..52e4b7ce0 100644
--- a/model/models/qwen25vl/model_vision.go
+++ b/model/models/qwen25vl/model_vision.go
@@ -80,7 +80,18 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
 		mask = blockDiagonalMask(ctx, query.Dim(2), bounds, opts.numHeads)
 	}
 
-	attention := nn.Attention(ctx, query, key, value, scale, nil, nn.WithMask(mask))
+	// Scaled dot-product attention
+	query = query.Permute(ctx, 0, 2, 1, 3)
+	key = key.Permute(ctx, 0, 2, 1, 3)
+	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+	kq := key.MulmatFullPrec(ctx, query)
+	kq = kq.Scale(ctx, scale)
+	if mask != nil {
+		kq = kq.Add(ctx, mask)
+	}
+	kq = kq.Softmax(ctx)
+	kqv := value.Mulmat(ctx, kq)
+	attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
 
 	return sa.Output.Forward(ctx, attention)
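
Note (not part of the patch): the ops inlined into VisionSelfAttention.Forward compute softmax(QK^T * scale + mask) V per head, the same formula stated in the nn.Attention doc comment; only the tensor layout (the Permute/Contiguous calls and the key-times-query argument order) is specific to the ml backend. The sketch below is a minimal standalone illustration of that arithmetic using plain Go slices for a single head. The attention helper, its slice-of-slices shapes, and the toy inputs in main are assumptions made for the example only; they are not part of the ollama ml API or of this change.

// Illustrative only: mirrors the math of the inlined tensor ops
// (QK^T -> scale -> add mask -> softmax -> multiply by V) for one head.
// Shapes follow the doc comment: q is [seq_q][d_k], k is [seq_k][d_k],
// v is [seq_k][d_v], mask is [seq_q][seq_k]; scale is typically 1/sqrt(d_k).
package main

import (
	"fmt"
	"math"
)

func attention(q, k, v, mask [][]float64, scale float64) [][]float64 {
	seqQ, seqK, dv := len(q), len(k), len(v[0])
	out := make([][]float64, seqQ)
	for i := 0; i < seqQ; i++ {
		// scores[j] = scale * dot(q[i], k[j]) + mask[i][j]
		scores := make([]float64, seqK)
		for j := 0; j < seqK; j++ {
			for d := range q[i] {
				scores[j] += q[i][d] * k[j][d]
			}
			scores[j] = scores[j]*scale + mask[i][j]
		}
		// numerically stable softmax over the scores
		maxS := math.Inf(-1)
		for _, s := range scores {
			maxS = math.Max(maxS, s)
		}
		var sum float64
		for j, s := range scores {
			scores[j] = math.Exp(s - maxS)
			sum += scores[j]
		}
		// output row i is the softmax-weighted sum of value rows
		out[i] = make([]float64, dv)
		for j := range v {
			w := scores[j] / sum
			for d := 0; d < dv; d++ {
				out[i][d] += w * v[j][d]
			}
		}
	}
	return out
}

func main() {
	// Toy inputs: 2 query/key positions, d_k = d_v = 2, no masking.
	q := [][]float64{{1, 0}, {0, 1}}
	k := [][]float64{{1, 0}, {0, 1}}
	v := [][]float64{{1, 2}, {3, 4}}
	mask := [][]float64{{0, 0}, {0, 0}}
	fmt.Println(attention(q, k, v, mask, 1/math.Sqrt(2)))
}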