ml: structured rope config to allow specifying context len

This commit refactors the Rotary Position Embedding (RoPE) implementation across the codebase to use a structured configuration approach instead of individual parameters.

Key changes:
- Add new RoPEConfig struct with fields for dimension, type, base frequency, and scaling
- Add RopeType enum to formalize different RoPE implementation variants
- Add YarnConfig struct and related configuration for YaRN (Yet Another RoPE extensioN) context extension
- Update RoPE method signature across all tensor interfaces and implementations
- Refactor all model implementations (llama, gemma2, gemma3, mllama) to use the new configuration structure

This change improves code organization, makes the RoPE configuration more explicit, and provides better support for different RoPE variants and context extension methods.
Author: Bruce MacDonald
Date:   2025-04-01 14:03:48 -07:00
Parent: 0cefd46f23
Commit: 51ad65f831

7 changed files with 212 additions and 58 deletions
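To illustrate the shape of the API change, here is a hypothetical call site before and after the refactor (the variable names and field values are illustrative, not taken from any particular model):

    // Before: RoPE parameters passed individually
    q = q.RoPE(ctx, positionIDs, nil, ropeDim, ropeType, ropeBase, ropeScale)

    // After: parameters grouped in a structured config
    q = q.RoPE(ctx, positionIDs, nil, ml.RoPEConfig{
        Dim:        ropeDim,
        Type:       ml.RopeTypeNeox,
        Base:       ropeBase,
        Scale:      ropeScale,
        YarnConfig: ml.DefaultYarnConfig(int32(contextLength)),
    })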


@@ -570,6 +570,64 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
     return out
 }
 
+func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
+    panic("not implemented")
+}
+
+func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
+    panic("not implemented")
+}
+
+func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
+    panic("not implemented")
+}
+
+func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
+    panic("not implemented")
+}
+
+func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
+    panic("not implemented")
+}
+
+func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
+    panic("not implemented")
+}
+
+func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
+    panic("not implemented")
+}
+
+func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
+    panic("not implemented")
+}
+
+func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
+    panic("not implemented")
+}
+
+func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
+    panic("not implemented")
+}
+
+func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
+    panic("not implemented")
+}
+
+func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
+    panic("not implemented")
+}
+
+func (t *testTensor) Cos(ctx ml.Context) ml.Tensor  { panic("not implemented") }
+func (t *testTensor) Sin(ctx ml.Context) ml.Tensor  { panic("not implemented") }
+func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { panic("not implemented") }
+func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") }
+func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { panic("not implemented") }
+
+func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
+    panic("not implemented")
+}
+
 func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
     offset /= t.elementSize


@@ -119,6 +119,53 @@ type Context interface {
     Layer(int) Context
 }
 
+// RopeType represents different RoPE (Rotary Position Embedding) implementation types
+type RopeType int
+
+// Available RoPE implementation types
+const (
+    RopeTypeNormal RopeType = iota // Standard RoPE implementation
+    RopeTypeNeox                   // NeoX-style RoPE implementation
+    RopeTypeMRoPE                  // Multi-scale RoPE implementation
+    RopeTypeVision                 // Vision-specific RoPE implementation
+)
+
+type YarnConfig struct {
+    YarnCtxTrain   int     // Context size used during training (for YaRN scaling)
+    YarnExtFactor  float32 // Extension factor for YaRN
+    YarnAttnFactor float32 // Attention scaling factor for YaRN
+    YarnBetaFast   float32 // Fast decay parameter for YaRN
+    YarnBetaSlow   float32 // Slow decay parameter for YaRN
+}
+
+// DefaultYarnConfig returns a default configuration for YaRN (Yet another RoPE extensioN)
+func DefaultYarnConfig(nCtx int32) *YarnConfig {
+    return &YarnConfig{
+        YarnCtxTrain:   int(nCtx),
+        YarnExtFactor:  0.0,
+        YarnAttnFactor: 1.0,
+        YarnBetaFast:   32.0,
+        YarnBetaSlow:   1.0,
+    }
+}
+
+// RoPEConfig holds configuration for Rotary Position Embedding
+type RoPEConfig struct {
+    // Dim is the dimensionality for applying rotary embeddings
+    Dim uint32
+
+    // Type specifies the RoPE implementation variant
+    Type RopeType
+
+    // Base controls frequency decay for the embeddings
+    Base float32
+
+    // Scale allows scaling the effective context length
+    Scale float32
+
+    *YarnConfig
+}
+
 type Tensor interface {
     Dim(n int) int
     Stride(n int) int
@@ -144,8 +191,8 @@ type Tensor interface {
     AvgPool2D(ctx Context, k, s int, p float32) Tensor
     Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
-    RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
     IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
+    RoPE(ctx Context, positionIDs, ropeFactors Tensor, config RoPEConfig) Tensor
 
     Sin(ctx Context) Tensor
     Cos(ctx Context) Tensor
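As a rough usage sketch (not part of the diff), a model would typically populate the new structure from its hyperparameters and pass it to every RoPE call; the key names mirror the llama changes further below, while the surrounding code is assumed:

    ropeConfig := ml.RoPEConfig{
        Base:       c.Float("rope.freq_base", 10000.0),
        Scale:      c.Float("rope.freq_scale", 1.0),
        Dim:        c.Uint("rope.dimension_count"),
        Type:       ml.RopeTypeNormal,
        YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
    }

    // Later, in attention:
    q = q.RoPE(ctx, positionIDs, nil, ropeConfig)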


@@ -1064,6 +1064,8 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
     }
 }
 
+// GGML RoPE types
+// These are the types used in the C implementation of RoPE
 const (
     ropeTypeNorm C.int = 0
     ropeTypeNeox C.int = 2
@@ -1071,7 +1073,8 @@ const (
     ropeTypeVision C.int = 24
 )
 
-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
+// RoPE applies Rotary Position Embeddings to the tensor
+func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
     if ropeFactors == nil {
         ropeFactors = &Tensor{b: t.b}
     }
@@ -1081,19 +1084,41 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
         dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
     }
 
+    if config.YarnConfig == nil {
+        config.YarnConfig = ml.DefaultYarnConfig(131072) // 131072 is the default for LLaMA, so it is common at the time of writing
+    }
+
+    // Map Go RopeType to C implementation constants
+    var ropeTypeC C.int
+    switch config.Type {
+    case ml.RopeTypeNormal:
+        ropeTypeC = ropeTypeNorm
+    case ml.RopeTypeNeox:
+        ropeTypeC = ropeTypeNeox
+    case ml.RopeTypeMRoPE:
+        ropeTypeC = ropeTypeMrope
+    case ml.RopeTypeVision:
+        ropeTypeC = ropeTypeVision
+    default:
+        ropeTypeC = ropeTypeNorm
+    }
+
     return &Tensor{
         b: t.b,
         t: C.ggml_rope_ext(
-            ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
-            C.int(ropeDim),
-            C.int(ropeType),
-            131072, // YaRN n_ctx_train
-            C.float(ropeBase),
-            C.float(ropeScale),
-            0.,  // YaRN ext_factor
-            1.,  // YaRN attn_factor
-            32., // YaRN beta_fast
-            1.,  // YaRN beta_slow
+            ctx.(*Context).ctx,
+            dequant,
+            positionIDs.(*Tensor).t,
+            ropeFactors.(*Tensor).t,
+            C.int(config.Dim),
+            ropeTypeC,
+            C.int(config.YarnCtxTrain),
+            C.float(config.Base),
+            C.float(config.Scale),
+            C.float(config.YarnExtFactor),
+            C.float(config.YarnAttnFactor),
+            C.float(config.YarnBetaFast),
+            C.float(config.YarnBetaSlow),
         ),
     }
 }
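Note that the backend substitutes ml.DefaultYarnConfig(131072) when config.YarnConfig is nil, so callers that do not need YaRN scaling can omit it. A minimal sketch, with placeholder dimensions and an assumed tensor/position pair:

    // YarnConfig omitted: the backend falls back to ml.DefaultYarnConfig(131072)
    rotated := hidden.RoPE(ctx, positions, nil, ml.RoPEConfig{
        Dim:   64, // placeholder head dimension
        Type:  ml.RopeTypeNeox,
        Base:  10000.0,
        Scale: 1.0,
    })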


@@ -14,10 +14,11 @@ import (
 
 type Options struct {
     hiddenSize, numHeads, numKVHeads int
     attnKeyLen, attnValLen int
-    eps, ropeBase, ropeScale float32
+    eps float32
     attnLogitSoftcap float32
     finalLogitSoftcap float32
     largeModelScaling bool
+    ropeConfig ml.RoPEConfig
 }
 
 type Model struct {
@@ -55,10 +56,15 @@ func New(c fs.Config) (model.Model, error) {
         attnKeyLen: int(c.Uint("attention.key_length")),
         attnValLen: int(c.Uint("attention.value_length")),
         eps: c.Float("attention.layer_norm_rms_epsilon"),
-        ropeBase: c.Float("rope.freq_base", 10000.0),
-        ropeScale: c.Float("rope.freq_scale", 1.0),
         attnLogitSoftcap: c.Float("attn_logit_softcapping"),
         finalLogitSoftcap: c.Float("final_logit_softcapping"),
+        ropeConfig: ml.RoPEConfig{
+            Base:       c.Float("rope.freq_base", 10000.0),
+            Scale:      c.Float("rope.freq_scale", 1.0),
+            Dim:        c.Uint("attention.key_length"),
+            Type:       ml.RopeTypeNormal,
+            YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+        },
     },
 }
 
@@ -78,11 +84,10 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
     batchSize := hiddenState.Dim(1)
-    ropeType := uint32(2)
 
     q := sa.Query.Forward(ctx, hiddenState)
     q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
-    q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+    q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
 
     if opts.largeModelScaling {
         q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -92,7 +97,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
     k := sa.Key.Forward(ctx, hiddenState)
     k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
-    k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+    k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
 
     v := sa.Value.Forward(ctx, hiddenState)
     v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -122,7 +127,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-    return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
+    return key.RoPE(ctx, shift, nil, m.ropeConfig), nil
 }
 
 type MLP struct {


@@ -14,9 +14,11 @@ import (
 
 type TextConfig struct {
     hiddenSize, numHeads, numKVHeads int
     attnKeyLen, attnValLen int
-    eps, ropeScale float32
-    ropeLocalBase, ropeGlobalBase float32
+    eps float32
     largeModelScaling bool
+
+    ropeLocalConfig ml.RoPEConfig
+    ropeGlobalConfig ml.RoPEConfig
 }
 
 type TextModel struct {
@@ -55,16 +57,28 @@ func newTextModel(c fs.Config) *TextModel {
             },
         ),
         Layers: make([]TextLayer, numBlocks),
         TextConfig: &TextConfig{
             hiddenSize: int(c.Uint("embedding_length")),
             numHeads: int(c.Uint("attention.head_count")),
             numKVHeads: int(c.Uint("attention.head_count_kv")),
             attnKeyLen: int(c.Uint("attention.key_length", 256)),
             attnValLen: int(c.Uint("attention.value_length", 256)),
             eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
-            ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
-            ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
-            ropeScale: c.Float("rope.freq_scale", 1.0),
+            ropeLocalConfig: ml.RoPEConfig{
+                Base:       c.Float("rope.local.freq_base", 10000.0),
+                Scale:      c.Float("rope.freq_scale", 1.0),
+                Dim:        c.Uint("attention.key_length", 256),
+                Type:       ml.RopeTypeNeox,
+                YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+            },
+
+            ropeGlobalConfig: ml.RoPEConfig{
+                Base:       c.Float("rope.global.freq_base", 1000000.0),
+                Scale:      c.Float("rope.freq_scale", 1.0),
+                Dim:        c.Uint("attention.key_length", 256),
+                Type:       ml.RopeTypeNeox,
+                YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+            },
         },
     }
@@ -86,17 +100,16 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
     batchSize := hiddenState.Dim(1)
-    ropeType := uint32(2)
-    ropeBase := opts.ropeLocalBase
+    ropeConfig := opts.ropeLocalConfig
     if (layer+1)%gemmaGlobalCacheCount == 0 {
-        ropeBase = opts.ropeGlobalBase
+        ropeConfig = opts.ropeGlobalConfig
     }
 
     q := sa.Query.Forward(ctx, hiddenState)
     q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
     q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-    q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+    q = q.RoPE(ctx, positionIDs, nil, ropeConfig)
 
     if opts.largeModelScaling {
         q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -107,7 +120,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 
     k := sa.Key.Forward(ctx, hiddenState)
     k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
     k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-    k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+    k = k.RoPE(ctx, positionIDs, nil, ropeConfig)
 
     v := sa.Value.Forward(ctx, hiddenState)
     v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -120,12 +133,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-    ropeBase := m.TextConfig.ropeLocalBase
+    ropeConfig := m.ropeLocalConfig
     if (layer+1)%gemmaGlobalCacheCount == 0 {
-        ropeBase = m.TextConfig.ropeGlobalBase
+        ropeConfig = m.ropeGlobalConfig
     }
 
-    return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
+    return key.RoPE(ctx, shift, nil, ropeConfig), nil
 }
 
 type TextMLP struct {


@@ -15,8 +15,8 @@ import (
 
 type Options struct {
     hiddenSize, numHeads, numKVHeads int
-    eps, ropeBase, ropeScale float32
-    ropeDim uint32
+    eps float32
+    ropeConfig ml.RoPEConfig
 }
 
 type Model struct {
@@ -55,9 +55,13 @@ func New(c fs.Config) (model.Model, error) {
         numHeads: int(c.Uint("attention.head_count")),
         numKVHeads: int(c.Uint("attention.head_count_kv")),
         eps: c.Float("attention.layer_norm_rms_epsilon"),
-        ropeBase: c.Float("rope.freq_base"),
-        ropeScale: c.Float("rope.freq_scale", 1),
-        ropeDim: c.Uint("rope.dimension_count"),
+        ropeConfig: ml.RoPEConfig{
+            Base:       c.Float("rope.freq_base"),
+            Scale:      c.Float("rope.freq_scale", 1),
+            Dim:        c.Uint("rope.dimension_count"),
+            Type:       ml.RopeTypeNormal,
+            YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+        },
     },
 }
 
@@ -77,15 +81,14 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
     batchSize := hiddenState.Dim(1)
     headDim := opts.hiddenSize / opts.numHeads
-    ropeType := uint32(0)
 
     q := sa.Query.Forward(ctx, hiddenState)
     q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-    q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+    q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
 
     k := sa.Key.Forward(ctx, hiddenState)
     k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-    k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+    k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
 
     v := sa.Value.Forward(ctx, hiddenState)
     v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -98,7 +101,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-    return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
+    return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
 }
 
 type MLP struct {


@@ -21,15 +21,14 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
     batchSize := hiddenState.Dim(1)
     headDim := opts.hiddenSize / opts.numHeads
-    ropeType := uint32(0)
 
     query := sa.Query.Forward(ctx, hiddenState)
     query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-    query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+    query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)
 
     key := sa.Key.Forward(ctx, hiddenState)
     key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-    key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+    key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)
 
     value := sa.Value.Forward(ctx, hiddenState)
     value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -44,7 +43,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
     // This will only get called for layers in the cache, which are just the self attention layers
     if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
-        return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
+        return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeConfig), nil
     }
 
     return key, nil
@@ -199,8 +198,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
 type TextModelOptions struct {
     hiddenSize, numHeads, numKVHeads int
-    eps, ropeBase, ropeScale float32
-    ropeDim uint32
+    eps float32
+    ropeConfig ml.RoPEConfig
 
     crossAttentionLayers []int32
 }
@@ -241,10 +240,14 @@ func newTextModel(c fs.Config) *TextModel {
         numHeads: int(c.Uint("attention.head_count")),
         numKVHeads: int(c.Uint("attention.head_count_kv")),
         eps: c.Float("attention.layer_norm_rms_epsilon"),
-        ropeBase: c.Float("rope.freq_base"),
-        ropeScale: c.Float("rope.freq_scale", 1),
-        ropeDim: c.Uint("rope.dimension_count"),
-        crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
+        crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
+        ropeConfig: ml.RoPEConfig{
+            Base:       c.Float("rope.freq_base"),
+            Scale:      c.Float("rope.freq_scale", 1),
+            Dim:        c.Uint("rope.dimension_count"),
+            Type:       ml.RopeTypeNormal,
+            YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+        },
     },
 }
 }