Compare commits

...

3 Commits

Author SHA1 Message Date
Bruce MacDonald
159821594c Update ml/backend.go 2025-04-02 09:46:19 -07:00
Bruce MacDonald
cbeb2aab4f Update backend.go 2025-04-01 15:08:59 -07:00
Bruce MacDonald
96df15edfc ml: structured rope config to allow specifying context len
This commit refactors the Rotary Position Embedding (RoPE) implementation across the codebase to use a structured configuration approach instead of individual parameters.

Key changes:
- Add new RoPEConfig struct with fields for dimension, type, base frequency, and scaling
- Add RopeType enum to formalize different RoPE implementation variants
- Add YarnConfig struct and related configuration for YaRN (Yet Another RoPE extensioN) context extension
- Update RoPE method signature across all tensor interfaces and implementations
- Refactor all model implementations (llama, gemma2, gemma3, mllama) to use the new configuration structure

This change improves code organization, makes the RoPE configuration more explicit, and provides better support for different RoPE variants and context extension methods.
2025-04-01 14:03:48 -07:00
7 changed files with 153 additions and 57 deletions
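At the call sites, the change replaces RoPE's positional arguments with a single configuration struct. A minimal sketch of the new shape (the tensor, context, and position variables are placeholders and the field values are illustrative; the real call sites are in the diffs below):

// Old interface: settings passed positionally.
//   t.RoPE(ctx, positions, ropeFactors, dim, ropeType, base, scale)
//
// New interface: settings carried by a single ml.RoPEConfig value.
cfg := ml.RoPEConfig{
    Dim:   64,              // rotary dimension (placeholder)
    Type:  ml.RopeTypeNeox, // implementation variant
    Base:  10000.0,         // frequency base
    Scale: 1.0,             // frequency scale
    // YarnConfig may be left nil; the ggml backend then falls back to
    // ml.DefaultYarnConfig(131072), as shown in the backend diff below.
}
out := t.RoPE(ctx, positions, nil, cfg)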

View File

@@ -462,7 +462,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0
    panic("not implemented")
}

-func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
+func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
    panic("not implemented")
}

View File

@@ -118,6 +118,53 @@ type Context interface {
    Layer(int) Context
}

+// RopeType represents different RoPE (Rotary Position Embedding) implementation types
+type RopeType int
+
+// Available RoPE implementation types
+const (
+    RopeTypeNormal RopeType = iota // Standard RoPE implementation
+    RopeTypeNeox                   // NeoX-style RoPE implementation
+    RopeTypeMRoPE                  // Multimodal RoPE implementation
+    RopeTypeVision                 // Vision-specific RoPE implementation
+)
+
+type YarnConfig struct {
+    YarnCtxTrain   int     // Context size used during training (for YaRN scaling)
+    YarnExtFactor  float32 // Extension factor for YaRN
+    YarnAttnFactor float32 // Attention scaling factor for YaRN
+    YarnBetaFast   float32 // Fast decay parameter for YaRN
+    YarnBetaSlow   float32 // Slow decay parameter for YaRN
+}
+
+// DefaultYarnConfig returns a default configuration for YaRN (Yet Another Rope Extension)
+func DefaultYarnConfig(nCtx int32) *YarnConfig {
+    return &YarnConfig{
+        YarnCtxTrain:   int(nCtx),
+        YarnExtFactor:  0.0,
+        YarnAttnFactor: 1.0,
+        YarnBetaFast:   32.0,
+        YarnBetaSlow:   1.0,
+    }
+}
+
+// RoPEConfig holds configuration for Rotary Position Embedding
+type RoPEConfig struct {
+    // Dim is the dimensionality for applying rotary embeddings
+    Dim uint32
+
+    // Type specifies the RoPE implementation variant
+    Type RopeType
+
+    // Base controls frequency decay for the embeddings
+    Base float32
+
+    // Scale allows scaling the effective context length
+    Scale float32
+
+    *YarnConfig
+}
+
type Tensor interface {
    Dim(n int) int
    Stride(n int) int
@@ -141,7 +188,7 @@ type Tensor interface {
    AvgPool2D(ctx Context, k, s int, p float32) Tensor
    Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

-   RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
+   RoPE(ctx Context, positionIDs, ropeFactors Tensor, config RoPEConfig) Tensor

    Tanh(ctx Context) Tensor
    GELU(ctx Context) Tensor
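Since *YarnConfig is embedded in RoPEConfig, its fields are promoted onto the config value, which is what the backend relies on when it reads config.YarnCtxTrain and the other YaRN fields. A hedged sketch, assuming the ml package above is imported and using placeholder values:

cfg := ml.RoPEConfig{
    Dim:        128,                         // placeholder rotary dimension
    Type:       ml.RopeTypeNormal,
    Base:       500000.0,                    // placeholder frequency base
    Scale:      1.0,
    YarnConfig: ml.DefaultYarnConfig(32768), // placeholder training context
}

// Fields of the embedded *YarnConfig are accessible directly on the config:
_ = cfg.YarnCtxTrain // 32768
_ = cfg.YarnBetaFast // 32.0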

View File

@@ -907,6 +907,8 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
    }
}

+// GGML RoPE types
+// These are the types used in the C implementation of RoPE
const (
    ropeTypeNorm C.int = 0
    ropeTypeNeox C.int = 2
@@ -914,7 +916,8 @@ const (
    ropeTypeVision C.int = 24
)

-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
+// RoPE applies Rotary Position Embeddings to the tensor
+func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
    if ropeFactors == nil {
        ropeFactors = &Tensor{b: t.b}
    }
@@ -924,19 +927,41 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
        dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
    }

+   if config.YarnConfig == nil {
+       config.YarnConfig = ml.DefaultYarnConfig(131072) // 131072 is the default for LLaMA, so it is common at the time of writing
+   }
+
+   // Map Go RopeType to C implementation constants
+   var ropeTypeC C.int
+   switch config.Type {
+   case ml.RopeTypeNormal:
+       ropeTypeC = ropeTypeNorm
+   case ml.RopeTypeNeox:
+       ropeTypeC = ropeTypeNeox
+   case ml.RopeTypeMRoPE:
+       ropeTypeC = ropeTypeMrope
+   case ml.RopeTypeVision:
+       ropeTypeC = ropeTypeVision
+   default:
+       ropeTypeC = ropeTypeNorm
+   }
+
    return &Tensor{
        b: t.b,
        t: C.ggml_rope_ext(
-           ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
-           C.int(ropeDim),
-           C.int(ropeType),
-           131072, // YaRN n_ctx_train
-           C.float(ropeBase),
-           C.float(ropeScale),
-           0., // YaRN ext_factor
-           1., // YaRN attn_factor
-           32., // YaRN beta_fast
-           1., // YaRN beta_slow
+           ctx.(*Context).ctx,
+           dequant,
+           positionIDs.(*Tensor).t,
+           ropeFactors.(*Tensor).t,
+           C.int(config.Dim),
+           ropeTypeC,
+           C.int(config.YarnCtxTrain),
+           C.float(config.Base),
+           C.float(config.Scale),
+           C.float(config.YarnExtFactor),
+           C.float(config.YarnAttnFactor),
+           C.float(config.YarnBetaFast),
+           C.float(config.YarnBetaSlow),
        ),
    }
}
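In effect the backend normalizes the config before handing it to ggml_rope_ext: a nil YarnConfig is replaced with the LLaMA-style defaults, and an unset Type (the zero value, RopeTypeNormal) maps to ropeTypeNorm. A minimal sketch of that defaulting step as a standalone helper; the helper name is hypothetical and only mirrors the logic above:

// applyRoPEDefaults mirrors the defaulting done inside (*Tensor).RoPE above.
// Hypothetical helper, shown for illustration only.
func applyRoPEDefaults(config ml.RoPEConfig) ml.RoPEConfig {
    if config.YarnConfig == nil {
        // Same fallback the backend uses: 131072 is the common LLaMA training context.
        config.YarnConfig = ml.DefaultYarnConfig(131072)
    }
    return config
}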

View File

@@ -13,10 +13,11 @@ import (
type Options struct {
    hiddenSize, numHeads, numKVHeads int
    attnKeyLen, attnValLen int
-   eps, ropeBase, ropeScale float32
+   eps float32
    attnLogitSoftcap float32
    finalLogitSoftcap float32
    largeModelScaling bool
+   ropeConfig ml.RoPEConfig
}

type Model struct {
@@ -55,10 +56,15 @@ func New(c ml.Config) (model.Model, error) {
            attnKeyLen: int(c.Uint("attention.key_length")),
            attnValLen: int(c.Uint("attention.value_length")),
            eps: c.Float("attention.layer_norm_rms_epsilon"),
-           ropeBase: c.Float("rope.freq_base", 10000.0),
-           ropeScale: c.Float("rope.freq_scale", 1.0),
            attnLogitSoftcap: c.Float("attn_logit_softcapping"),
            finalLogitSoftcap: c.Float("final_logit_softcapping"),
+           ropeConfig: ml.RoPEConfig{
+               Base: c.Float("rope.freq_base", 10000.0),
+               Scale: c.Float("rope.freq_scale", 1.0),
+               Dim: c.Uint("attention.key_length"),
+               Type: ml.RopeTypeNormal,
+               YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+           },
        },
    }
@@ -78,11 +84,10 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
    batchSize := hiddenState.Dim(1)
-   ropeType := uint32(2)

    q := sa.Query.Forward(ctx, hiddenState)
    q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
-   q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+   q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)

    if opts.largeModelScaling {
        q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -92,7 +97,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
    k := sa.Key.Forward(ctx, hiddenState)
    k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
-   k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+   k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)

    v := sa.Value.Forward(ctx, hiddenState)
    v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -122,7 +127,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}

func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-   return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
+   return key.RoPE(ctx, shift, nil, m.ropeConfig), nil
}

type MLP struct {

View File

@@ -13,9 +13,11 @@ import (
type TextOptions struct {
    hiddenSize, numHeads, numKVHeads int
    attnKeyLen, attnValLen int
-   eps, ropeScale float32
-   ropeLocalBase, ropeGlobalBase float32
+   eps float32
    largeModelScaling bool
+
+   ropeLocalConfig ml.RoPEConfig
+   ropeGlobalConfig ml.RoPEConfig
}

type TextModel struct {
@@ -56,15 +58,27 @@ func newTextModel(c ml.Config) *TextModel {
        ),
        Layers: make([]TextLayer, numBlocks),
        TextOptions: &TextOptions{
            hiddenSize: int(c.Uint("embedding_length")),
            numHeads: int(c.Uint("attention.head_count")),
            numKVHeads: int(c.Uint("attention.head_count_kv")),
            attnKeyLen: int(c.Uint("attention.key_length", 256)),
            attnValLen: int(c.Uint("attention.value_length", 256)),
            eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
-           ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
-           ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
-           ropeScale: c.Float("rope.freq_scale", 1.0),
+
+           ropeLocalConfig: ml.RoPEConfig{
+               Base: c.Float("rope.local.freq_base", 10000.0),
+               Scale: c.Float("rope.freq_scale", 1.0),
+               Dim: c.Uint("attention.key_length", 256),
+               Type: ml.RopeTypeNeox,
+               YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+           },
+
+           ropeGlobalConfig: ml.RoPEConfig{
+               Base: c.Float("rope.global.freq_base", 1000000.0),
+               Scale: c.Float("rope.freq_scale", 1.0),
+               Dim: c.Uint("attention.key_length", 256),
+               Type: ml.RopeTypeNeox,
+               YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+           },
        },
    }
@@ -86,17 +100,16 @@ type TextSelfAttention struct {
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
    batchSize := hiddenState.Dim(1)
-   ropeType := uint32(2)

-   ropeBase := opts.ropeLocalBase
+   ropeConfig := opts.ropeLocalConfig
    if (layer+1)%gemmaGlobalCacheCount == 0 {
-       ropeBase = opts.ropeGlobalBase
+       ropeConfig = opts.ropeGlobalConfig
    }

    q := sa.Query.Forward(ctx, hiddenState)
    q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
    q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-   q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+   q = q.RoPE(ctx, positionIDs, nil, ropeConfig)

    if opts.largeModelScaling {
        q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -107,7 +120,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
    k := sa.Key.Forward(ctx, hiddenState)
    k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
    k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-   k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+   k = k.RoPE(ctx, positionIDs, nil, ropeConfig)

    v := sa.Value.Forward(ctx, hiddenState)
    v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -120,12 +133,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
}

func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-   ropeBase := m.TextOptions.ropeLocalBase
+   ropeConfig := m.ropeLocalConfig
    if (layer+1)%gemmaGlobalCacheCount == 0 {
-       ropeBase = m.TextOptions.ropeGlobalBase
+       ropeConfig = m.ropeGlobalConfig
    }

-   return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
+   return key.RoPE(ctx, shift, nil, ropeConfig), nil
}

type TextMLP struct {
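The layer-dependent choice above boils down to: every gemmaGlobalCacheCount-th layer uses the global-frequency config, all other layers use the local one. A hedged sketch of that selection as a standalone helper (the method name is hypothetical; gemmaGlobalCacheCount is defined elsewhere in this model package):

// ropeConfigForLayer picks the same config that Forward and Shift select above.
// Hypothetical helper, for illustration only.
func (o *TextOptions) ropeConfigForLayer(layer int) ml.RoPEConfig {
    if (layer+1)%gemmaGlobalCacheCount == 0 {
        return o.ropeGlobalConfig
    }
    return o.ropeLocalConfig
}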

View File

@@ -14,8 +14,8 @@ import (
type Options struct {
    hiddenSize, numHeads, numKVHeads int
-   eps, ropeBase, ropeScale float32
-   ropeDim uint32
+   eps float32
+   ropeConfig ml.RoPEConfig
}

type Model struct {
@@ -54,9 +54,13 @@ func New(c ml.Config) (model.Model, error) {
            numHeads: int(c.Uint("attention.head_count")),
            numKVHeads: int(c.Uint("attention.head_count_kv")),
            eps: c.Float("attention.layer_norm_rms_epsilon"),
-           ropeBase: c.Float("rope.freq_base"),
-           ropeScale: c.Float("rope.freq_scale", 1),
-           ropeDim: c.Uint("rope.dimension_count"),
+           ropeConfig: ml.RoPEConfig{
+               Base: c.Float("rope.freq_base"),
+               Scale: c.Float("rope.freq_scale", 1),
+               Dim: c.Uint("rope.dimension_count"),
+               Type: ml.RopeTypeNormal,
+               YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+           },
        },
    }
@@ -76,15 +80,14 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
    batchSize := hiddenState.Dim(1)
    headDim := opts.hiddenSize / opts.numHeads
-   ropeType := uint32(0)

    q := sa.Query.Forward(ctx, hiddenState)
    q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-   q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+   q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)

    k := sa.Key.Forward(ctx, hiddenState)
    k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-   k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+   k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)

    v := sa.Value.Forward(ctx, hiddenState)
    v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -97,7 +100,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}

func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-   return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
+   return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
}

type MLP struct {

View File

@@ -20,15 +20,14 @@ type TextSelfAttention struct {
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
    batchSize := hiddenState.Dim(1)
    headDim := opts.hiddenSize / opts.numHeads
-   ropeType := uint32(0)

    query := sa.Query.Forward(ctx, hiddenState)
    query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-   query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+   query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)

    key := sa.Key.Forward(ctx, hiddenState)
    key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-   key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+   key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)

    value := sa.Value.Forward(ctx, hiddenState)
    value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -43,7 +42,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
    // This will only get called for layers in the cache, which are just the self attention layers
    if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
-       return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
+       return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeConfig), nil
    }

    return key, nil
@@ -198,8 +197,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
type TextModelOptions struct {
    hiddenSize, numHeads, numKVHeads int
-   eps, ropeBase, ropeScale float32
-   ropeDim uint32
+   eps float32
+   ropeConfig ml.RoPEConfig

    crossAttentionLayers []uint32
}
@@ -240,10 +239,14 @@ func newTextModel(c ml.Config) *TextModel {
            numHeads: int(c.Uint("attention.head_count")),
            numKVHeads: int(c.Uint("attention.head_count_kv")),
            eps: c.Float("attention.layer_norm_rms_epsilon"),
-           ropeBase: c.Float("rope.freq_base"),
-           ropeScale: c.Float("rope.freq_scale", 1),
-           ropeDim: c.Uint("rope.dimension_count"),
            crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
+           ropeConfig: ml.RoPEConfig{
+               Base: c.Float("rope.freq_base"),
+               Scale: c.Float("rope.freq_scale", 1),
+               Dim: c.Uint("rope.dimension_count"),
+               Type: ml.RopeTypeNormal,
+               YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+           },
        },
    }
}
} }