diff --git a/kvcache/cache.go b/kvcache/cache.go
index 5d8b2f9b5..2541f7c16 100644
--- a/kvcache/cache.go
+++ b/kvcache/cache.go
@@ -29,6 +29,17 @@ type Cache interface {
 	// cache implementation used.
 	Put(ctx ml.Context, key, value ml.Tensor)
 
+	// SetConfig controls optimizations (mostly backend-specific) that may transform
+	// the output of the cache to work better with specific kernels. If not called,
+	// the backend settings will be used. This works well when calling Attention.
+	//
+	// The config can be overridden by models, especially if they require vanilla
+	// output when implementing their own version of attention. To do this, pass
+	// an empty ml.CacheConfig.
+	//
+	// Most models will not need to use this.
+	SetConfig(ml.CacheConfig)
+
 	// ** cache management **
 
 	// Init sets up runtime parameters
diff --git a/kvcache/causal.go b/kvcache/causal.go
index 69068439e..1d4daf809 100644
--- a/kvcache/causal.go
+++ b/kvcache/causal.go
@@ -22,6 +22,9 @@ type Causal struct {
 	Capacity   int32
 	windowSize int32
 
+	// config controls mostly backend-specific optimizations
+	config *ml.CacheConfig
+
 	// ** current forward pass **
 
 	// the active layer for Get and Put
@@ -75,14 +78,34 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
 }
 
 func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+	if c.config == nil {
+		var config ml.CacheConfig
+		if cc, ok := backend.(ml.BackendCacheConfig); ok {
+			config = cc.CacheConfig()
+		}
+		c.config = &config
+	}
+
+	if c.config.CachePadding == 0 {
+		c.config.CachePadding = 1
+	}
+
 	c.DType = dtype
-	c.Capacity = capacity
-	c.cells = make([]cacheCell, capacity)
+	c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
+	c.cells = make([]cacheCell, c.Capacity)
 	c.cellRanges = make(map[int]cellRange)
 	c.backend = backend
 	c.cacheCtx = backend.NewContext()
 }
 
+func (c *Causal) SetConfig(config ml.CacheConfig) {
+	if c.config != nil {
+		panic("config cannot be changed after being previously set, either by the model or backend")
+	}
+
+	c.config = &config
+}
+
 func (c *Causal) Close() {
 	c.cacheCtx.Close()
 }
@@ -157,36 +180,73 @@ func (c *Causal) findStartLoc() (int, error) {
 	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
 }
 
+func roundDown(length, pad int) int {
+	return (length / pad) * pad
+}
+
+func roundUp(length, pad int) int {
+	return ((length + pad - 1) / pad) * pad
+}
+
 // Builds a mask of history x batch indicating whether for each token in the batch the
 // token in the history should apply. This is based on both the sequence and causality (the
 // position of the history is not ahead of the token in the batch).
 func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
-	// TODO(jessegross): This does not do padding, which is required for flash attention
-	len := c.curCellRange.max - c.curCellRange.min + 1
-	mask := make([]float32, c.curBatchSize*len)
+	// TODO(jessegross): This does not do mask padding, which is required for flash attention
+	// Align and pad the cache range as required by the backend
+	c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
+	c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
+
+	length := c.curCellRange.max - c.curCellRange.min + 1
+	mask := make([]float32, c.curBatchSize*length)
 
 	for i := range c.curBatchSize {
 		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
 			if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
 				c.cells[j].pos < positions[i]-c.windowSize {
-				mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
+				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
 			}
 		}
 	}
 
-	return ctx.FromFloatSlice(mask, len, c.curBatchSize)
+	return ctx.FromFloatSlice(mask, length, c.curBatchSize)
 }
 
-func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
-	for _, obj := range objs {
-		if obj == nil {
+func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
+	for i := range c.keys {
+		if c.keys[i] == nil {
 			continue
 		}
 
-		srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
-		dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
+		key := c.keys[i]
 
-		ctx.Forward(srcView.Copy(ctx, dstView))
+		kHeadDim := key.Dim(0)
+		numKVHeads := key.Dim(1)
+		rowSize := key.Stride(2)
+
+		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
+		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
+
+		value := c.values[i]
+		var vSrcView, vDstView ml.Tensor
+		if c.config.PermutedV {
+			vHeadDim := value.Dim(1)
+			elemSize := value.Stride(0)
+
+			vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
+			vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
+		} else {
+			vHeadDim := value.Dim(0)
+			rowSize := value.Stride(2)
+
+			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
+			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
+		}
+
+		ctx.Forward(
+			kSrcView.Copy(ctx, kDstView),
+			vSrcView.Copy(ctx, vDstView),
+		)
 	}
 }
 
@@ -238,8 +298,7 @@ func (c *Causal) defrag() {
 					pendingLen++
 					break
 				} else {
-					moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
-					moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
+					c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
 					moves++
 				}
 			}
@@ -263,8 +322,7 @@
 	}
 
 	if pendingLen > 0 {
-		moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
-		moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
+		c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
 		moves++
 	}
 
@@ -305,35 +363,73 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
 	key := c.keys[c.curLayer]
 	value := c.values[c.curLayer]
 
-	key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
-		key.Dim(0), key.Stride(1),
-		key.Dim(1), key.Stride(2),
-		c.curMask.Dim(0),
+	kHeadDim := key.Dim(0)
+	numKVHeads := key.Dim(1)
+	rowSize := key.Stride(2)
+	cachedSize := c.curMask.Dim(0)
+
+	key = key.View(ctx, rowSize*c.curCellRange.min,
+		kHeadDim, key.Stride(1),
+		numKVHeads, key.Stride(2),
+		cachedSize,
 	)
 
-	value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
-		value.Dim(0), value.Stride(1),
-		value.Dim(1), value.Stride(2),
-		c.curMask.Dim(0),
-	)
+	if c.config.PermutedV {
+		vHeadDim := value.Dim(1)
+		elemSize := value.Stride(0)
+
+		value = value.View(ctx, elemSize*c.curCellRange.min,
+			cachedSize, value.Stride(1),
+			vHeadDim, value.Stride(2),
+			numKVHeads,
+		)
+	} else {
+		vHeadDim := value.Dim(0)
+		rowSize := value.Stride(2)
+
+		value = value.View(ctx, rowSize*c.curCellRange.min,
+			vHeadDim, value.Stride(1),
+			numKVHeads, value.Stride(2),
+			cachedSize,
+		)
+	}
 
 	return key, value, c.curMask
 }
 
 func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
-	if c.curBatchSize != key.Dim(2) {
-		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
+	kHeadDim := key.Dim(0)
+	vHeadDim := value.Dim(0)
+	numKVHeads := key.Dim(1)
+	batchSize := key.Dim(2)
+
+	if c.curBatchSize != batchSize {
+		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
 	}
 
 	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
-		c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
-		c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
+		c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
+
+		if c.config.PermutedV {
+			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
+		} else {
+			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
+		}
 	}
 
-	ctx.Forward(
-		key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
-		value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
-	)
+	rowSize := c.keys[c.curLayer].Stride(2)
+	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
+
+	if c.config.PermutedV {
+		elemSize := c.values[c.curLayer].Stride(0)
+
+		value = value.Permute(ctx, 1, 2, 0, 3)
+		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
+	} else {
+		rowSize := c.values[c.curLayer].Stride(2)
+
+		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
+	}
 }
 
 func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
@@ -389,9 +485,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
 			continue
 		}
 
-		key = key.View(ctx, key.Stride(2)*seqRange.min,
-			key.Dim(0), key.Stride(1),
-			key.Dim(1), key.Stride(2),
+		kHeadDim := key.Dim(0)
+		numKVHeads := key.Dim(1)
+		rowSize := key.Stride(2)
+
+		key = key.View(ctx, rowSize*seqRange.min,
+			kHeadDim, key.Stride(1),
+			numKVHeads, key.Stride(2),
 			size,
 		)
diff --git a/kvcache/encoder.go b/kvcache/encoder.go
index b85b1046a..c55da2b4a 100644
--- a/kvcache/encoder.go
+++ b/kvcache/encoder.go
@@ -1,6 +1,8 @@
 package kvcache
 
 import (
+	"fmt"
+
 	"github.com/ollama/ollama/ml"
 )
 
@@ -11,6 +13,9 @@ import (
 //
 // Not currently safe for multiple sequences
 type EncoderCache struct {
+	// config controls mostly backend-specific optimizations
+	config *ml.CacheConfig
+
 	// ** current forward pass **
 
 	// the active layer for Get and Put
@@ -40,9 +45,29 @@ func NewEncoderCache() *EncoderCache {
 }
 
 func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+	if c.config == nil {
+		var config ml.CacheConfig
+		if cc, ok := backend.(ml.BackendCacheConfig); ok {
+			config = cc.CacheConfig()
+		}
+		c.config = &config
+	}
+
+	if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
+		panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
+	}
+
 	c.cacheCtx = backend.NewContext()
 }
 
+func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
+	if c.config != nil {
+		panic("config cannot be changed after being previously set, either by the model or backend")
+	}
+
+	c.config = &config
+}
+
 func (c *EncoderCache) Close() {
 	c.cacheCtx.Close()
 }
@@ -75,6 +100,10 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
 	c.encoderPos = c.curPos
 	c.encoderCached = true
 
+	if c.config.PermutedV {
+		value = value.Permute(ctx, 1, 2, 0, 3)
+	}
+
 	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
 		c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
 		c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go
index 2d4c1089a..76956a88a 100644
--- a/kvcache/wrapper.go
+++ b/kvcache/wrapper.go
@@ -28,6 +28,12 @@ func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
 	}
 }
 
+func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
+	for _, cache := range c.caches {
+		cache.SetConfig(config)
+	}
+}
+
 func (c *WrapperCache) Close() {
 	for _, cache := range c.caches {
 		cache.Close()
diff --git a/ml/backend.go b/ml/backend.go
index 07bc75b64..ccab915c7 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -27,6 +27,27 @@ type Backend interface {
 	SystemInfo() string
 }
 
+// BackendCacheConfig should be implemented by backends that need special output
+// from the cache to meet specific requirements. It is frequently implemented in
+// conjunction with ScaledDotProductAttention.
+type BackendCacheConfig interface {
+	CacheConfig() CacheConfig
+}
+
+// CacheConfig controls optimizations (mostly backend-specific) that may transform
+// the output of the cache to work better with specific kernels.
+type CacheConfig struct {
+	// CachePadding specifies the multiple for the number of tokens of cache history
+	// that will be returned from cache Get for k, v and mask. The capacity of the
+	// cache itself will also be increased to a multiple of this size if needed.
+	CachePadding int
+
+	// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
+	// and returns the permuted version via Get. This uses the cache copy operation
+	// to avoid a Contiguous call on the permuted tensor.
+	PermutedV bool
+}
+
 // BackendParams controls how the backend loads and executes models
 type BackendParams struct {
 	// NumThreads sets the number of threads to use if running on the CPU
@@ -116,6 +137,10 @@ type Tensor interface {
 // operation equivalent to following code on a tensor named
 // query:
 //
+//	query = query.Permute(ctx, 0, 2, 1, 3)
+//	key = key.Permute(ctx, 0, 2, 1, 3)
+//	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+//
 //	kq := key.MulmatFullPrec(ctx, query)
 //
 //	kq = kq.Scale(ctx, scale)
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 7f91990c3..bddaad463 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -247,6 +247,10 @@ func (b *Backend) NewContext() ml.Context {
 	}
 }
 
+func (b *Backend) CacheConfig() ml.CacheConfig {
+	return ml.CacheConfig{CachePadding: 32, PermutedV: true}
+}
+
 type Context struct {
 	b   *Backend
 	ctx *C.struct_ggml_context
@@ -661,7 +665,10 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
 		kqMask = mask.(*Tensor).t
 	}
 
-	kq := key.MulmatFullPrec(ctx, t)
+	query := t.Permute(ctx, 0, 2, 1, 3)
+	key = key.Permute(ctx, 0, 2, 1, 3)
+
+	kq := key.MulmatFullPrec(ctx, query)
 	kq = &Tensor{
 		t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
 	}
diff --git a/ml/nn/attention.go b/ml/nn/attention.go
index 4f0c9fa14..a3f43a1ea 100644
--- a/ml/nn/attention.go
+++ b/ml/nn/attention.go
@@ -3,6 +3,7 @@ package nn
 import (
 	"fmt"
 
+	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 )
 
@@ -11,40 +12,50 @@ import (
 //
 // Parameters:
 //   - ctx: Context for tensor operations
-//   - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
-//   - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
-//   - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
-//   - mask: Optional attention mask that is added to the attention score. If
-//     provided, should broadcast to [seq_len_k, seq_len_q, heads]
+//   - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
+//   - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
+//   - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
 //   - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
+//   - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
 //
 // Returns:
 //
 //	Attention output with shape [d_v, heads, seq_len_q]
-func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
-	if query.Dim(0) != key.Dim(0) {
-		panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
+func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
+	if key != nil && value != nil {
+		if query.Dim(0) != key.Dim(0) {
+			panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
+		}
+
+		if key.Dim(1) != value.Dim(1) {
+			panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
+		}
+
+		if key.Dim(2) != value.Dim(2) {
+			panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
+		}
+
+		if cache != nil {
+			cache.Put(ctx, key, value)
+		}
+	} else if cache == nil {
+		panic("key & value tensors must be provided if cache is nil")
 	}
 
-	if mask != nil && query.Dim(1) != mask.Dim(1) {
-		panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
+	var mask ml.Tensor
+	if cache != nil {
+		key, value, mask = cache.Get(ctx)
 	}
 
-	if key.Dim(1) != value.Dim(0) {
-		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
-	}
-
-	if mask != nil && key.Dim(1) != mask.Dim(0) {
-		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
-	}
-
-	if key.Dim(2) != value.Dim(2) {
-		panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
-	}
-
-	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
+	// Only use the fast SDPA implementation if we have a cache, since that's what
+	// will do any expected backend-specific transformations for us
+	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
 		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
 	} else {
+		query = query.Permute(ctx, 0, 2, 1, 3)
+		key = key.Permute(ctx, 0, 2, 1, 3)
+		value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
 		kq := key.MulmatFullPrec(ctx, query)
 		kq = kq.Scale(ctx, scale)
 
diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index 6106af867..9bf6f4979 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -81,15 +81,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
-	cache.Put(ctx, k, v)
-	k, v, mask := cache.Get(ctx)
-
-	q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
 	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-	kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
+	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
 	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return sa.Output.Forward(ctx, kqv)
diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go
index 9b35a2628..743f4c32d 100644
--- a/model/models/mllama/model.go
+++ b/model/models/mllama/model.go
@@ -43,7 +43,9 @@ func New(c ml.Config) (model.Model, error) {
 		TextModel:      newTextModel(c),
 	}
 
-	m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))
+	encoderCache := kvcache.NewEncoderCache()
+	encoderCache.SetConfig(ml.CacheConfig{})
+	m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
 
 	return &m, nil
 }
diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go
index 003bf9cbf..e294b4c71 100644
--- a/model/models/mllama/model_text.go
+++ b/model/models/mllama/model_text.go
@@ -31,22 +31,15 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
-	cache.Put(ctx, key, value)
-	key, value, mask := cache.Get(ctx)
-
-	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
 	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-	attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
+	attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
 	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return sa.Output.Forward(ctx, attention)
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	// This will only get called for layers in the cache, which are just the self attention layers
+	// This will only get called for layers in the causal cache, which are just the self attention layers
 	return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
 }
 
@@ -107,7 +100,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
 	query = ca.QueryNorm.Forward(ctx, query, opts.eps)
 
-	var key, value, mask ml.Tensor
+	var key, value ml.Tensor
 	if crossAttentionStates != nil {
 		numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
 
@@ -119,16 +112,23 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
 		value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
 
 		cache.Put(ctx, key, value)
-	} else {
-		key, value, mask = cache.Get(ctx)
 	}
 
-	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+	key, value, _ = cache.Get(ctx)
 
 	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-	attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
+
+	query = query.Permute(ctx, 0, 2, 1, 3)
+	key = key.Permute(ctx, 0, 2, 1, 3)
+	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
+	kq := key.MulmatFullPrec(ctx, query)
+
+	kq = kq.Scale(ctx, scaleFactor)
+	kq = kq.Softmax(ctx)
+
+	kqv := value.Mulmat(ctx, kq)
+	attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return ca.Output.Forward(ctx, attention)
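
For reference, a minimal sketch (not part of the patch) of how a model's self-attention layer calls the new nn.Attention signature, based on the llama changes above. The type and field names are illustrative only, and RoPE/position handling is omitted; the point is that the cache now handles Put/Get, the mask, and any backend-specific layout requested through ml.CacheConfig, so the model no longer permutes k/v or fetches the mask itself.

package llama // illustrative placement only

import (
	"math"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

// Forward mirrors the shape of the llama SelfAttention.Forward change in the diff above.
// (RoPE and position IDs are omitted here for brevity; see the real model for the full version.)
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	headDim := opts.hiddenSize / opts.numHeads

	q := sa.Query.Forward(ctx, hiddenState).Reshape(ctx, headDim, opts.numHeads, batchSize)
	k := sa.Key.Forward(ctx, hiddenState).Reshape(ctx, headDim, opts.numKVHeads, batchSize)
	v := sa.Value.Forward(ctx, hiddenState).Reshape(ctx, headDim, opts.numKVHeads, batchSize)

	// nn.Attention stores k/v in the cache, reads back the full history plus the
	// causal mask, and uses the fused SDPA path when the backend provides one.
	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)

	return sa.Output.Forward(ctx, kqv)
}

Models that implement attention by hand (like the mllama cross-attention above) can opt out of backend-specific cache output by calling SetConfig with an empty ml.CacheConfig before the cache is initialized, exactly as mllama's New does.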
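
A quick illustration of the CachePadding arithmetic that buildMask now applies: with the ggml backend reporting CachePadding: 32, the active cell range is widened to multiples of 32 so the k, v, and mask returned by Get have a padded history length. The concrete min/max values below are made up for the example.

package main

import "fmt"

// roundDown and roundUp as defined in kvcache/causal.go above.
func roundDown(length, pad int) int { return (length / pad) * pad }
func roundUp(length, pad int) int   { return ((length + pad - 1) / pad) * pad }

func main() {
	pad := 32          // ml.CacheConfig{CachePadding: 32} from the ggml backend
	min, max := 37, 70 // hypothetical curCellRange for a batch

	min = roundDown(min, pad)     // 32
	max = roundUp(max+1, pad) - 1 // 95

	fmt.Println(min, max, max-min+1) // 32 95 64: the history length is a multiple of 32
}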