Compare commits

...

2 Commits

Author SHA1 Message Date
Bruce MacDonald
c259747acb ctxLen -> origCtxLen 2025-02-20 11:16:53 -08:00
Bruce MacDonald
eb086514da ml: let model specify rope configuration
Add support for model-specific RoPE configuration parameters by:

1. Creating a new `RopeConfig` struct to encapsulate all RoPE parameters
2. Adding `RopeType` enum to specify different RoPE variants (Standard/NeoX)
3. Extracting original context length from model config
4. Refactoring `RoPE()` interface to use the new config struct
5. Updating llama and mllama models to use new RoPE configuration

This change allows models to specify their RoPE implementation type and
original context length, which is important for proper position embedding
calculation and model compatibility.
2025-02-14 14:18:51 -08:00
5 changed files with 104 additions and 28 deletions

View File

@ -430,7 +430,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0
panic("not implemented")
}
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
func (t *testTensor) RoPE(ctx ml.Context, rc ml.RopeConfig) ml.Tensor {
panic("not implemented")
}

View File

@ -43,6 +43,42 @@ func NewBackend(f *os.File) (Backend, error) {
return nil, fmt.Errorf("unsupported backend")
}
// RopeType specifies the type of RoPE (Rotary Position Embedding) to use, these types are implemented in the backend
type RopeType int
const (
RopeTypeStandard RopeType = iota
_ // not yet used
RopeTypeNeoX
)
// RopeConfig contains all configuration for the RoPE (Rotary Position Embedding) operation
type RopeConfig struct {
// PositionIDs contains the position indices for each token in the sequence
// These indices are used to calculate the rotary embeddings
PositionIDs Tensor
// RopeFactors is an optional tensor containing pre-computed rotation factors
RopeFactors Tensor
// RopeDim specifies the dimension size for the rotary embeddings
RopeDim uint32
// RopeType indicates which RoPE variant to use (e.g. normal or neox)
RopeType RopeType
// OrigCtxLen stores the original context length the model was trained with
OrigCtxLen int
// RopeBase is the base value used in the frequency calculation
RopeBase float32
// RopeScale is a scaling factor applied to position indices
RopeScale float32
// YaRN parameters can be added here if they need to be configurable
}
type Context interface {
Zeros(dtype DType, shape ...int) Tensor
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
@ -75,7 +111,7 @@ type Tensor interface {
Scale(ctx Context, s float64) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
RoPE(ctx Context, rc RopeConfig) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor

View File

@ -579,13 +579,9 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
}
const (
ropeTypeNorm C.int = iota
)
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
if ropeFactors == nil {
ropeFactors = &Tensor{}
func (t *Tensor) RoPE(ctx ml.Context, rc ml.RopeConfig) ml.Tensor {
if rc.RopeFactors == nil {
rc.RopeFactors = &Tensor{}
}
dequant := t.t
@ -595,12 +591,15 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
return &Tensor{
t: C.ggml_rope_ext(
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
C.int(ropeDim),
131072, // YaRN n_ctx_train
ropeTypeNorm, // ROPE_TYPE_NORM
C.float(ropeBase),
C.float(ropeScale),
ctx.(*Context).ctx,
dequant,
rc.PositionIDs.(*Tensor).t,
rc.RopeFactors.(*Tensor).t,
C.int(rc.RopeDim),
C.int(rc.RopeType),
C.int(rc.OrigCtxLen),
C.float(rc.RopeBase),
C.float(rc.RopeScale),
0., // YaRN ext_factor
1., // YaRN attn_factor
32., // YaRN beta_fast

View File

@ -10,10 +10,10 @@ import (
)
type Options struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
origCtxLen, hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
}
type Model struct {
@ -46,6 +46,7 @@ func New(c ml.Config) (model.Model, error) {
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
origCtxLen: int(c.Uint("context_length")),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
@ -67,14 +68,23 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
rc := ml.RopeConfig{
PositionIDs: positionIDs,
RopeFactors: opts.RopeFactors,
RopeDim: opts.ropeDim,
RopeType: ml.RopeTypeStandard,
OrigCtxLen: opts.origCtxLen,
RopeBase: opts.ropeBase,
RopeScale: opts.ropeScale,
}
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
q = q.RoPE(ctx, rc)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
k = k.RoPE(ctx, rc)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -99,7 +109,18 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
return key.RoPE(
ctx,
ml.RopeConfig{
PositionIDs: shift,
RopeFactors: m.Options.RopeFactors,
RopeDim: m.Options.ropeDim,
RopeType: ml.RopeTypeStandard,
OrigCtxLen: m.Options.origCtxLen,
RopeBase: m.Options.ropeBase,
RopeScale: m.Options.ropeScale,
},
), nil
}
type MLP struct {

View File

@ -19,14 +19,23 @@ type TextSelfAttention struct {
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
rc := ml.RopeConfig{
PositionIDs: positions,
RopeFactors: opts.RopeFactors,
RopeDim: opts.ropeDim,
RopeType: ml.RopeTypeStandard,
OrigCtxLen: opts.ctxLen,
RopeBase: opts.ropeBase,
RopeScale: opts.ropeScale,
}
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
query = query.RoPE(ctx, rc)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, rc)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -52,7 +61,18 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
return key.RoPE(
ctx,
ml.RopeConfig{
PositionIDs: shift,
RopeFactors: m.RopeFactors,
RopeDim: m.ropeDim,
RopeType: ml.RopeTypeStandard,
OrigCtxLen: m.ctxLen,
RopeBase: m.ropeBase,
RopeScale: m.ropeScale,
},
), nil
}
type TextMLP struct {
@ -189,9 +209,9 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
type TextModelOptions struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
ctxLen, hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
crossAttentionLayers []uint32
}