From 51ad65f831a9758016420cba0e1b56b13cb61de8 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Tue, 1 Apr 2025 14:03:48 -0700 Subject: [PATCH] ml: structured rope config to allow specifying context len This commit refactors the Rotary Position Embedding (RoPE) implementation across the codebase to use a structured configuration approach instead of individual parameters. Key changes: - Add new RoPEConfig struct with fields for dimension, type, base frequency, and scaling - Add RopeType enum to formalize different RoPE implementation variants - Add YarnConfig struct and related configuration for YaRN (Yet Another RoPE extensioN) context extension - Update RoPE method signature across all tensor interfaces and implementations - Refactor all model implementations (llama, gemma2, gemma3, mllama) to use the new configuration structure This change improves code organization, makes the RoPE configuration more explicit, and provides better support for different RoPE variants and context extension methods. --- kvcache/causal_test.go | 58 +++++++++++++++++++++++++++++++ ml/backend.go | 49 +++++++++++++++++++++++++- ml/backend/ggml/ggml.go | 47 +++++++++++++++++++------ model/models/gemma2/model.go | 19 ++++++---- model/models/gemma3/model_text.go | 53 +++++++++++++++++----------- model/models/llama/model.go | 21 ++++++----- model/models/mllama/model_text.go | 23 ++++++------ 7 files changed, 212 insertions(+), 58 deletions(-) diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 796987088..dc4b81ecc 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -570,6 +570,64 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return out } +func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") } +func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") } +func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { panic("not implemented") } +func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") } +func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { panic("not implemented") } + +func (t *testTensor) Reshape(ctx ml.Context, shape ...int) 
ml.Tensor {
+	panic("not implemented")
+}
+
 func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	offset /= t.elementSize
 
diff --git a/ml/backend.go b/ml/backend.go
index ba24ecb45..8af79f069 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -119,6 +119,53 @@ type Context interface {
 	Layer(int) Context
 }
 
+// RopeType represents different RoPE (Rotary Position Embedding) implementation types
+type RopeType int
+
+// Available RoPE implementation types
+const (
+	RopeTypeNormal RopeType = iota // Standard RoPE implementation
+	RopeTypeNeox                   // NeoX-style RoPE implementation
+	RopeTypeMRoPE                  // Multimodal RoPE (M-RoPE) implementation
+	RopeTypeVision                 // Vision-specific RoPE implementation
+)
+
+type YarnConfig struct {
+	YarnCtxTrain   int     // Context size used during training (for YaRN scaling)
+	YarnExtFactor  float32 // Extension factor for YaRN
+	YarnAttnFactor float32 // Attention scaling factor for YaRN
+	YarnBetaFast   float32 // Fast decay parameter for YaRN
+	YarnBetaSlow   float32 // Slow decay parameter for YaRN
+}
+
+// DefaultYarnConfig returns a default configuration for YaRN (Yet another RoPE extensioN)
+func DefaultYarnConfig(nCtx int32) *YarnConfig {
+	return &YarnConfig{
+		YarnCtxTrain:   int(nCtx),
+		YarnExtFactor:  0.0,
+		YarnAttnFactor: 1.0,
+		YarnBetaFast:   32.0,
+		YarnBetaSlow:   1.0,
+	}
+}
+
+// RoPEConfig holds configuration for Rotary Position Embedding
+type RoPEConfig struct {
+	// Dim is the dimensionality for applying rotary embeddings
+	Dim uint32
+
+	// Type specifies the RoPE implementation variant
+	Type RopeType
+
+	// Base controls frequency decay for the embeddings
+	Base float32
+
+	// Scale is the frequency scale factor, used to extend the effective context length
+	Scale float32
+
+	*YarnConfig
+}
+
 type Tensor interface {
 	Dim(n int) int
 	Stride(n int) int
@@ -144,8 +191,8 @@ type Tensor interface {
 	AvgPool2D(ctx Context, k, s int, p float32) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
 
-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
 	IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
+	RoPE(ctx Context, positionIDs, ropeFactors Tensor, config RoPEConfig) Tensor
 
 	Sin(ctx Context) Tensor
 	Cos(ctx Context) Tensor
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index e97795a69..5376d20ae 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -1064,6 +1064,8 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	}
 }
 
+// GGML RoPE types
+// These are the types used in the C implementation of RoPE
 const (
 	ropeTypeNorm   C.int = 0
 	ropeTypeNeox   C.int = 2
@@ -1071,7 +1073,8 @@
 	ropeTypeVision C.int = 24
 )
 
-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
+// RoPE applies Rotary Position Embeddings to the tensor
+func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
 	if ropeFactors == nil {
 		ropeFactors = &Tensor{b: t.b}
 	}
@@ -1081,19 +1084,41 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
 		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
 	}
 
+	if config.YarnConfig == nil {
+		config.YarnConfig = ml.DefaultYarnConfig(131072) // 131072 is the default for LLaMA models, so it is a common value at the time of writing
+	}
+
+	// Map Go RopeType to C implementation constants
+	var ropeTypeC C.int
+	switch config.Type {
+	case ml.RopeTypeNormal:
+		ropeTypeC = ropeTypeNorm
+	case ml.RopeTypeNeox:
+		ropeTypeC = ropeTypeNeox
+	case ml.RopeTypeMRoPE:
+		ropeTypeC = ropeTypeMrope
+	case ml.RopeTypeVision:
+		ropeTypeC = ropeTypeVision
+	default:
+		ropeTypeC = ropeTypeNorm
+	}
+
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_rope_ext(
-			ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
-			C.int(ropeDim),
-			C.int(ropeType),
-			131072, // YaRN n_ctx_train
-			C.float(ropeBase),
-			C.float(ropeScale),
-			0.,  // YaRN ext_factor
-			1.,  // YaRN attn_factor
-			32., // YaRN beta_fast
-			1.,  // YaRN beta_slow
+			ctx.(*Context).ctx,
+			dequant,
+			positionIDs.(*Tensor).t,
+			ropeFactors.(*Tensor).t,
+			C.int(config.Dim),
+			ropeTypeC,
+			C.int(config.YarnCtxTrain),
+			C.float(config.Base),
+			C.float(config.Scale),
+			C.float(config.YarnExtFactor),
+			C.float(config.YarnAttnFactor),
+			C.float(config.YarnBetaFast),
+			C.float(config.YarnBetaSlow),
 		),
 	}
 }
diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go
index d418f6827..fe0e59173 100644
--- a/model/models/gemma2/model.go
+++ b/model/models/gemma2/model.go
@@ -14,10 +14,11 @@ import (
 type Options struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
-	eps, ropeBase, ropeScale         float32
+	eps                              float32
 	attnLogitSoftcap                 float32
 	finalLogitSoftcap                float32
 	largeModelScaling                bool
+	ropeConfig                       ml.RoPEConfig
 }
@@ -55,10 +56,15 @@ func New(c fs.Config) (model.Model, error) {
 			attnKeyLen:        int(c.Uint("attention.key_length")),
 			attnValLen:        int(c.Uint("attention.value_length")),
 			eps:               c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:          c.Float("rope.freq_base", 10000.0),
-			ropeScale:         c.Float("rope.freq_scale", 1.0),
 			attnLogitSoftcap:  c.Float("attn_logit_softcapping"),
 			finalLogitSoftcap: c.Float("final_logit_softcapping"),
+			ropeConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.freq_base", 10000.0),
+				Scale:      c.Float("rope.freq_scale", 1.0),
+				Dim:        c.Uint("attention.key_length"),
+				Type:       ml.RopeTypeNeox,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
 		},
 	}
@@ -78,11 +84,10 @@ type SelfAttention struct {
 
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	ropeType := uint32(2)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
 
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -92,7 +97,7 @@
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -122,7 +127,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
+	return key.RoPE(ctx, shift, nil, m.ropeConfig), nil
 }
 
 type MLP struct {
diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index c1e843d8f..b05acbcad 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -14,9 +14,11 @@ import (
 type TextConfig struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
-	eps, ropeScale                   float32
-	ropeLocalBase, ropeGlobalBase    float32
+	eps                              float32
 	largeModelScaling                bool
+
+	ropeLocalConfig  ml.RoPEConfig
+	ropeGlobalConfig ml.RoPEConfig
 }
@@ -55,16 +57,28 @@ func newTextModel(c fs.Config) *TextModel {
 			},
 		),
 		Layers: make([]TextLayer, numBlocks),
-		TextConfig: &TextConfig{
-			hiddenSize:     int(c.Uint("embedding_length")),
-			numHeads:       int(c.Uint("attention.head_count")),
-			numKVHeads:     int(c.Uint("attention.head_count_kv")),
-			attnKeyLen:     int(c.Uint("attention.key_length", 256)),
-			attnValLen:     int(c.Uint("attention.value_length", 256)),
-			eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06),
-			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
-			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
-			ropeScale:      c.Float("rope.freq_scale", 1.0),
+		TextConfig: &TextConfig{
+			hiddenSize: int(c.Uint("embedding_length")),
+			numHeads:   int(c.Uint("attention.head_count")),
+			numKVHeads: int(c.Uint("attention.head_count_kv")),
+			attnKeyLen: int(c.Uint("attention.key_length", 256)),
+			attnValLen: int(c.Uint("attention.value_length", 256)),
+			eps:        c.Float("attention.layer_norm_rms_epsilon", 1e-06),
+
+			ropeLocalConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.local.freq_base", 10000.0),
+				Scale:      c.Float("rope.freq_scale", 1.0),
+				Dim:        c.Uint("attention.key_length", 256),
+				Type:       ml.RopeTypeNeox,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
+			ropeGlobalConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.global.freq_base", 1000000.0),
+				Scale:      c.Float("rope.freq_scale", 1.0),
+				Dim:        c.Uint("attention.key_length", 256),
+				Type:       ml.RopeTypeNeox,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
 		},
 	}
@@ -86,17 +100,16 @@ type TextSelfAttention struct {
 
 func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	ropeType := uint32(2)
-	ropeBase := opts.ropeLocalBase
+	ropeConfig := opts.ropeLocalConfig
 	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeBase = opts.ropeGlobalBase
+		ropeConfig = opts.ropeGlobalConfig
 	}
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
 	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, nil, ropeConfig)
 
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -107,7 +120,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
 	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, nil, ropeConfig)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -120,12 +133,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	ropeBase := 
m.TextConfig.ropeLocalBase + ropeConfig := m.ropeLocalConfig if (layer+1)%gemmaGlobalCacheCount == 0 { - ropeBase = m.TextConfig.ropeGlobalBase + ropeConfig = m.ropeGlobalConfig } - return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil + return key.RoPE(ctx, shift, nil, ropeConfig), nil } type TextMLP struct { diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 3e5a54278..ec5da5b62 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -15,8 +15,8 @@ import ( type Options struct { hiddenSize, numHeads, numKVHeads int - eps, ropeBase, ropeScale float32 - ropeDim uint32 + eps float32 + ropeConfig ml.RoPEConfig } type Model struct { @@ -55,9 +55,13 @@ func New(c fs.Config) (model.Model, error) { numHeads: int(c.Uint("attention.head_count")), numKVHeads: int(c.Uint("attention.head_count_kv")), eps: c.Float("attention.layer_norm_rms_epsilon"), - ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), - ropeDim: c.Uint("rope.dimension_count"), + ropeConfig: ml.RoPEConfig{ + Base: c.Float("rope.freq_base"), + Scale: c.Float("rope.freq_scale", 1), + Dim: c.Uint("rope.dimension_count"), + Type: ml.RopeTypeNormal, + YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))), + }, }, } @@ -77,15 +81,14 @@ type SelfAttention struct { func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := opts.hiddenSize / opts.numHeads - ropeType := uint32(0) q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -98,7 +101,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil + return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil } type MLP struct { diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 490eb696c..dccb084d5 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -21,15 +21,14 @@ type TextSelfAttention struct { func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := opts.hiddenSize / opts.numHeads - ropeType := uint32(0) query := sa.Query.Forward(ctx, hiddenState) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) - query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig) key := sa.Key.Forward(ctx, hiddenState) key = key.Reshape(ctx, headDim, opts.numKVHeads, 
batchSize)
-	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -44,7 +43,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	// This will only get called for layers in the cache, which are just the self attention layers
 	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
-		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
+		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeConfig), nil
 	}
 
 	return key, nil
@@ -199,8 +198,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
 
 type TextModelOptions struct {
 	hiddenSize, numHeads, numKVHeads int
-	eps, ropeBase, ropeScale         float32
-	ropeDim                          uint32
+	eps        float32
+	ropeConfig ml.RoPEConfig
 
 	crossAttentionLayers []int32
 }
@@ -241,10 +240,14 @@ func newTextModel(c fs.Config) *TextModel {
 			numHeads:             int(c.Uint("attention.head_count")),
 			numKVHeads:           int(c.Uint("attention.head_count_kv")),
 			eps:                  c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:             c.Float("rope.freq_base"),
-			ropeScale:            c.Float("rope.freq_scale", 1),
-			ropeDim:              c.Uint("rope.dimension_count"),
-			crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
+			crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
+			ropeConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.freq_base"),
+				Scale:      c.Float("rope.freq_scale", 1),
+				Dim:        c.Uint("rope.dimension_count"),
+				Type:       ml.RopeTypeNormal,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
 		},
 	}
 }
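
For reference, a minimal sketch of what a call site looks like after this change. It is illustrative only: the helper names (newRopeConfig, applyRope), the example package, and the import paths are assumptions, while the ml.RoPEConfig fields, defaults, and the RoPE signature are taken from the diff above.

package example

import (
	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/ml"
)

// newRopeConfig reads the RoPE hyperparameters once at model construction time,
// the pattern this commit introduces in New()/newTextModel().
func newRopeConfig(c fs.Config) ml.RoPEConfig {
	return ml.RoPEConfig{
		Base:  c.Float("rope.freq_base", 10000.0),
		Scale: c.Float("rope.freq_scale", 1.0),
		Dim:   c.Uint("rope.dimension_count"),
		Type:  ml.RopeTypeNormal,
		// Pins the training context length used for YaRN scaling; if left nil,
		// the ggml backend falls back to ml.DefaultYarnConfig(131072).
		YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
	}
}

// applyRope shows the new call-site shape: the whole configuration travels with
// the call, so query/key rotation and cache shifting stay in sync by construction.
func applyRope(ctx ml.Context, q, k, positions ml.Tensor, cfg ml.RoPEConfig) (ml.Tensor, ml.Tensor) {
	return q.RoPE(ctx, positions, nil, cfg), k.RoPE(ctx, positions, nil, cfg)
}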