From 96510b935340842654b631a34f9d9a48c853bdfb Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Fri, 14 Feb 2025 14:55:30 -0800
Subject: [PATCH] model: document qwen2 forward pass

---
 model/models/qwen2/model.go | 156 +++++++++++++++++++++---------
 1 file changed, 91 insertions(+), 65 deletions(-)

diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go
index f22a1c305..de419d746 100644
--- a/model/models/qwen2/model.go
+++ b/model/models/qwen2/model.go
@@ -10,10 +10,15 @@ import (
 )
 
 type Options struct {
-	RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
-	ctxLen, hiddenSize, numHeads, numKVHeads int
-	eps, ropeBase, ropeScale float32
-	ropeDim uint32
+	RopeFactors    ml.Tensor `gguf:"rope_freqs.weight"`
+	contextLength  int
+	hiddenSize     int
+	numAttnHeads   int
+	numKVHeads     int
+	modelEpsilon   float32
+	ropeBaseFreq   float32
+	ropeFreqScale  float32
+	ropeDimensions uint32
 }
 
 type Model struct {
@@ -42,14 +47,14 @@ func New(c ml.Config) (model.Model, error) {
 		),
 		Layers: make([]Layer, c.Uint("block_count")),
 		Options: &Options{
-			hiddenSize: int(c.Uint("embedding_length")),
-			numHeads:   int(c.Uint("attention.head_count")),
-			numKVHeads: int(c.Uint("attention.head_count_kv")),
-			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ctxLen:     int(c.Uint("context_length")),
-			ropeBase:   c.Float("rope.freq_base"),
-			ropeScale:  c.Float("rope.freq_scale", 1),
-			ropeDim:    c.Uint("rope.dimension_count", 64),
+			hiddenSize:     int(c.Uint("embedding_length")),
+			numAttnHeads:   int(c.Uint("attention.head_count")),
+			numKVHeads:     int(c.Uint("attention.head_count_kv")),
+			modelEpsilon:   c.Float("attention.layer_norm_rms_epsilon"),
+			contextLength:  int(c.Uint("context_length")),
+			ropeBaseFreq:   c.Float("rope.freq_base"),
+			ropeFreqScale:  c.Float("rope.freq_scale", 1),
+			ropeDimensions: c.Uint("rope.dimension_count", 64),
 		},
 	}
 
@@ -58,21 +63,24 @@ func New(c ml.Config) (model.Model, error) {
 	return m, nil
 }
 
+// Shift applies rotary position embeddings to cached key tensors when the KV cache shifts them to new positions
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	return key.RoPE(
 		ctx,
 		ml.RopeConfig{
 			PositionIDs: shift,
 			RopeFactors: m.Options.RopeFactors,
-			RopeDim:     m.Options.ropeDim,
+			RopeDim:     m.Options.ropeDimensions,
 			RopeType:    ml.RopeTypeNeoX,
-			OrigCtxLen:  m.Options.ctxLen,
-			RopeBase:    m.Options.ropeBase,
-			RopeScale:   m.Options.ropeScale,
+			OrigCtxLen:  m.Options.contextLength,
+			RopeBase:    m.Options.ropeBaseFreq,
+			RopeScale:   m.Options.ropeFreqScale,
 		},
 	), nil
 }
 
+// SelfAttention implements the multi-head self-attention mechanism,
+// with separate query, key, value, and output projections
 type SelfAttention struct {
 	Query  *nn.Linear `gguf:"attn_q"`
 	Key    *nn.Linear `gguf:"attn_k"`
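The hunks above route rotary position embeddings through ml.RopeConfig with RopeTypeNeoX. As a rough illustration of what NeoX-style RoPE does to a single head vector, here is a minimal scalar sketch in plain Go. It is not the ml.Tensor implementation: it ignores RopeFactors, RopeScale, OrigCtxLen, and partial rotary dimensions, and the helper name applyRoPENeoX is invented for this example.

package main

import (
	"fmt"
	"math"
)

// applyRoPENeoX rotates one head vector in place using NeoX-style rotary
// position embeddings: element i is paired with element i+dim/2, and each
// pair is rotated by pos * base^(-2i/dim). Frequency scaling and the
// per-model rope factors are intentionally ignored in this sketch.
func applyRoPENeoX(x []float64, pos int, base float64) {
	half := len(x) / 2
	for i := 0; i < half; i++ {
		theta := float64(pos) * math.Pow(base, -2*float64(i)/float64(len(x)))
		sin, cos := math.Sincos(theta)
		x0, x1 := x[i], x[i+half]
		x[i] = x0*cos - x1*sin
		x[i+half] = x0*sin + x1*cos
	}
}

func main() {
	// A toy 8-dimensional head vector at position 3 with the common base 10000.
	v := []float64{1, 0, 1, 0, 0, 1, 0, 1}
	applyRoPENeoX(v, 3, 10000)
	fmt.Println(v)
}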
@@ -81,49 +89,59 @@
 }
 
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, inputPositions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+	// Derive the per-head dimension and the rotary embedding configuration
 	batchSize := hiddenState.Dim(1)
-	headDim := opts.hiddenSize / opts.numHeads
-	rc := ml.RopeConfig{
+	headDimension := opts.hiddenSize / opts.numAttnHeads
+	ropeConfig := ml.RopeConfig{
 		PositionIDs: inputPositions,
 		RopeFactors: nil,
-		RopeDim:     opts.ropeDim,
+		RopeDim:     opts.ropeDimensions,
 		RopeType:    ml.RopeTypeNeoX,
-		OrigCtxLen:  opts.ctxLen,
-		RopeBase:    opts.ropeBase,
-		RopeScale:   opts.ropeScale,
+		OrigCtxLen:  opts.contextLength,
+		RopeBase:    opts.ropeBaseFreq,
+		RopeScale:   opts.ropeFreqScale,
 	}
 
-	q := sa.Query.Forward(ctx, hiddenState)
+	// Project, reshape, and apply rotary embeddings to the query states
+	queryStates := sa.Query.Forward(ctx, hiddenState)
+	queryStates = queryStates.Reshape(ctx, headDimension, opts.numAttnHeads, batchSize)
+	queryStates = queryStates.RoPE(ctx, ropeConfig)
 
-	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, rc)
+	// Project, reshape, and apply rotary embeddings to the key states
+	keyStates := sa.Key.Forward(ctx, hiddenState)
+	keyStates = keyStates.Reshape(ctx, headDimension, opts.numKVHeads, batchSize)
+	keyStates = keyStates.RoPE(ctx, ropeConfig)
 
-	k := sa.Key.Forward(ctx, hiddenState)
-	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, rc)
+	// Project and reshape the value states (no rotary embedding is applied)
+	valueStates := sa.Value.Forward(ctx, hiddenState)
+	valueStates = valueStates.Reshape(ctx, headDimension, opts.numKVHeads, batchSize)
 
-	v := sa.Value.Forward(ctx, hiddenState)
-	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
+	// Store the new keys and values, then read back the cached sequence and its attention mask
+	cache.Put(ctx, keyStates, valueStates)
+	keyStates, valueStates, attentionMask := cache.Get(ctx)
 
-	cache.Put(ctx, k, v)
-	k, v, mask := cache.Get(ctx)
+	// Permute tensors into the layout expected by the attention matmuls
+	queryStates = queryStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	keyStates = keyStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	valueStates = valueStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
-	q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+	// Compute attention scores, scale by 1/sqrt(headDimension), and apply the causal mask
+	attentionScores := keyStates.MulmatFullPrec(ctx, queryStates)
+	attentionScores = attentionScores.Scale(ctx, 1.0/math.Sqrt(float64(headDimension)))
+	attentionScores = attentionScores.Add(ctx, attentionMask)
+	// Normalize the scores into attention weights
+	attentionProbs := attentionScores.Softmax(ctx)
 
-	kq := k.MulmatFullPrec(ctx, q)
-	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	kq = kq.Add(ctx, mask)
-	kq = kq.Softmax(ctx)
+	// Apply the attention weights to the values and merge heads back into the hidden dimension
+	weightedStates := valueStates.Mulmat(ctx, attentionProbs)
+	weightedStates = weightedStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	weightedStates = weightedStates.Reshape(ctx, opts.hiddenSize, batchSize)
 
-	kqv := v.Mulmat(ctx, kq)
-	kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
-
-	return sa.Output.Forward(ctx, kqv)
+	// Apply the output projection
+	return sa.Output.Forward(ctx, weightedStates)
 }
 
+// MLP implements the feed-forward network component with SwiGLU activation
 type MLP struct {
 	Up   *nn.Linear `gguf:"ffn_up"`
 	Down *nn.Linear `gguf:"ffn_down"`
@@ -131,10 +149,16 @@ type MLP struct {
 }
 
 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
-	return mlp.Down.Forward(ctx, hiddenState)
+	// SwiGLU gating: SiLU(gate projection) multiplied elementwise by the up projection
+	gateActivation := mlp.Gate.Forward(ctx, hiddenState).SILU(ctx)
+	upProjection := mlp.Up.Forward(ctx, hiddenState)
+	intermediateStates := gateActivation.Mul(ctx, upProjection)
+
+	// Project back to the hidden dimension
+	return mlp.Down.Forward(ctx, intermediateStates)
 }
 
+// Layer represents a single transformer layer combining self-attention and feed-forward components
 type Layer struct {
 	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
 	SelfAttention *SelfAttention
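The MLP hunk above documents the SwiGLU feed-forward block: the gate projection is passed through SiLU and multiplied elementwise with the up projection before the down projection. A minimal scalar sketch of that gating in plain Go follows. It assumes the gate and up projections have already been computed as vectors, omits the down projection, and the silu/swiGLU helper names are invented for illustration, not part of the nn package.

package main

import (
	"fmt"
	"math"
)

// silu is the SiLU (swish) activation: x * sigmoid(x).
func silu(x float64) float64 {
	return x / (1 + math.Exp(-x))
}

// swiGLU combines precomputed gate and up projection vectors the same way
// the MLP above does: SiLU(gate) * up, elementwise.
func swiGLU(gate, up []float64) []float64 {
	out := make([]float64, len(gate))
	for i := range gate {
		out[i] = silu(gate[i]) * up[i]
	}
	return out
}

func main() {
	gate := []float64{-1.0, 0.5, 2.0}
	up := []float64{0.3, -0.7, 1.2}
	fmt.Println(swiGLU(gate, up))
}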
@@ -143,52 +167,54 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+	// Self-attention branch with a residual connection
 	residual := hiddenState
-	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
+	normalizedAttention := l.AttentionNorm.Forward(ctx, hiddenState, opts.modelEpsilon)
+	attentionOutput := l.SelfAttention.Forward(ctx, normalizedAttention, positionIDs, cache, opts)
+	hiddenState = attentionOutput.Add(ctx, residual)
 
-	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
-
-	hiddenState = hiddenState.Add(ctx, residual)
+	// Feed-forward branch with a residual connection
 	residual = hiddenState
-
-	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
-
-	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
-
-	output := hiddenState.Add(ctx, residual)
+	normalizedMLP := l.MLPNorm.Forward(ctx, hiddenState, opts.modelEpsilon)
+	mlpOutput := l.MLP.Forward(ctx, normalizedMLP, opts)
+	output := mlpOutput.Add(ctx, residual)
 
 	return output
 }
 
 func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
-	inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
+	// Convert the input token IDs and their positions to tensors
+	inputTensor, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
 	if err != nil {
 		return nil, err
 	}
 
-	positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
+	positionsTensor, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
 	if err != nil {
 		return nil, err
 	}
 
-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	// Embed the input tokens
+	hiddenStates := m.TokenEmbedding.Forward(ctx, inputTensor)
 
+	// Process the hidden states through every transformer layer
 	for i, layer := range m.Layers {
 		m.Cache.SetLayer(i)
-		hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
+		hiddenStates = layer.Forward(ctx, hiddenStates, positionsTensor, m.Cache, m.Options)
 	}
 
-	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
+	// Final normalization and projection to vocabulary logits
+	normalizedOutput := m.OutputNorm.Forward(ctx, hiddenStates, m.modelEpsilon)
+	logits := m.Output.Forward(ctx, normalizedOutput)
 
-	hiddenState = m.Output.Forward(ctx, hiddenState)
-
-	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
+	// Gather only the logits for the requested output positions
+	outputsTensor, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}
 
-	return hiddenState.Rows(ctx, outputs), nil
+	return logits.Rows(ctx, outputsTensor), nil
 }
 
 func init() {
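For reference, the computation documented in SelfAttention.Forward reduces, for a single head and a single query, to softmax(q·Kᵀ/√d + mask)·V. The sketch below shows that reduction in plain Go over float64 slices; it ignores batching, multiple heads, grouped-query sharing of key/value heads, and the KV cache, and the attention helper name is invented for illustration rather than taken from the ml package.

package main

import (
	"fmt"
	"math"
)

// attention computes single-head scaled dot-product attention over plain
// float64 slices: softmax(q·kᵀ/√d + mask)·v. q is one query vector, k and v
// hold one row per cached position, and mask is 0 or -Inf per position.
func attention(q []float64, k, v [][]float64, mask []float64) []float64 {
	d := float64(len(q))
	scores := make([]float64, len(k))
	for i, ki := range k {
		for j := range q {
			scores[i] += q[j] * ki[j]
		}
		scores[i] = scores[i]/math.Sqrt(d) + mask[i]
	}

	// Softmax over positions (subtract the max for numerical stability).
	maxScore := math.Inf(-1)
	for _, s := range scores {
		maxScore = math.Max(maxScore, s)
	}
	var sum float64
	for i, s := range scores {
		scores[i] = math.Exp(s - maxScore)
		sum += scores[i]
	}

	// Weighted sum of the value rows.
	out := make([]float64, len(v[0]))
	for i, vi := range v {
		w := scores[i] / sum
		for j := range vi {
			out[j] += w * vi[j]
		}
	}
	return out
}

func main() {
	q := []float64{1, 0}
	k := [][]float64{{1, 0}, {0, 1}}
	v := [][]float64{{10, 0}, {0, 10}}
	mask := []float64{0, 0} // no positions masked
	fmt.Println(attention(q, k, v, mask))
}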