diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go
index 0b614df06..38db48e5e 100644
--- a/kvcache/causal_test.go
+++ b/kvcache/causal_test.go
@@ -430,7 +430,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0
     panic("not implemented")
 }
 
-func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
+func (t *testTensor) RoPE(ctx ml.Context, rc ml.RopeConfig) ml.Tensor {
     panic("not implemented")
 }
 
diff --git a/ml/backend.go b/ml/backend.go
index 0e99ab5a8..e56a16c85 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -43,6 +43,42 @@ func NewBackend(f *os.File) (Backend, error) {
     return nil, fmt.Errorf("unsupported backend")
 }
 
+// RopeType specifies the type of RoPE (Rotary Position Embedding) to use, these types are implemented in the backend
+type RopeType int
+
+const (
+    RopeTypeStandard RopeType = iota
+    _ // not yet used
+    RopeTypeNeoX
+)
+
+// RopeConfig contains all configuration for the RoPE (Rotary Position Embedding) operation
+type RopeConfig struct {
+    // PositionIDs contains the position indices for each token in the sequence
+    // These indices are used to calculate the rotary embeddings
+    PositionIDs Tensor
+
+    // RopeFactors is an optional tensor containing pre-computed rotation factors
+    RopeFactors Tensor
+
+    // RopeDim specifies the dimension size for the rotary embeddings
+    RopeDim uint32
+
+    // RopeType indicates which RoPE variant to use (e.g. normal or neox)
+    RopeType RopeType
+
+    // OrigCtxLen stores the original context length the model was trained with
+    OrigCtxLen int
+
+    // RopeBase is the base value used in the frequency calculation
+    RopeBase float32
+
+    // RopeScale is a scaling factor applied to position indices
+    RopeScale float32
+
+    // YaRN parameters can be added here if they need to be configurable
+}
+
 type Context interface {
     Zeros(dtype DType, shape ...int) Tensor
     FromFloatSlice(s []float32, shape ...int) (Tensor, error)
@@ -75,7 +111,7 @@ type Tensor interface {
     Scale(ctx Context, s float64) Tensor
 
     Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
-    RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
+    RoPE(ctx Context, rc RopeConfig) Tensor
 
     Tanh(ctx Context) Tensor
     GELU(ctx Context) Tensor
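The gap in the RopeType enum is deliberate: the ggml backend below passes the value straight through as ggml's rope mode, where GGML_ROPE_TYPE_NORM is 0 and GGML_ROPE_TYPE_NEOX is 2, so the placeholder keeps RopeTypeNeoX at 2. A caller-side sketch of the new interface follows; it is illustrative only, and the tensor name, dimension, base, and context length are made-up values, not taken from this patch:

    // Sketch, assuming an ml.Context `ctx` and tensors q, k, positions already exist.
    rc := ml.RopeConfig{
        PositionIDs: positions,
        RopeFactors: nil, // optional; the ggml backend substitutes an empty tensor
        RopeDim:     128,
        RopeType:    ml.RopeTypeStandard,
        OrigCtxLen:  131072, // the value the old backend hardcoded as n_ctx_train
        RopeBase:    500000,
        RopeScale:   1,
    }
    q = q.RoPE(ctx, rc)
    k = k.RoPE(ctx, rc) // one config serves both queries and keys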
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 6a727a60c..64a9e889d 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -579,13 +579,9 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
     }
 }
 
-const (
-    ropeTypeNorm C.int = iota
-)
-
-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
-    if ropeFactors == nil {
-        ropeFactors = &Tensor{}
+func (t *Tensor) RoPE(ctx ml.Context, rc ml.RopeConfig) ml.Tensor {
+    if rc.RopeFactors == nil {
+        rc.RopeFactors = &Tensor{}
     }
 
     dequant := t.t
@@ -595,12 +591,15 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
     return &Tensor{
         t: C.ggml_rope_ext(
-            ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
-            C.int(ropeDim),
-            131072,       // YaRN n_ctx_train
-            ropeTypeNorm, // ROPE_TYPE_NORM
-            C.float(ropeBase),
-            C.float(ropeScale),
+            ctx.(*Context).ctx,
+            dequant,
+            rc.PositionIDs.(*Tensor).t,
+            rc.RopeFactors.(*Tensor).t,
+            C.int(rc.RopeDim),
+            C.int(rc.RopeType),
+            C.int(rc.OrigCtxLen),
+            C.float(rc.RopeBase),
+            C.float(rc.RopeScale),
             0.,  // YaRN ext_factor
             1.,  // YaRN attn_factor
             32., // YaRN beta_fast
             1.,  // YaRN beta_slow
diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index b2c5c2c7b..1e4e44b72 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -10,10 +10,10 @@ import (
 )
 
 type Options struct {
-    RopeFactors                      ml.Tensor `gguf:"rope_freqs.weight"`
-    hiddenSize, numHeads, numKVHeads int
-    eps, ropeBase, ropeScale         float32
-    ropeDim                          uint32
+    RopeFactors                              ml.Tensor `gguf:"rope_freqs.weight"`
+    ctxLen, hiddenSize, numHeads, numKVHeads int
+    eps, ropeBase, ropeScale                 float32
+    ropeDim                                  uint32
 }
 
 type Model struct {
@@ -46,6 +46,7 @@ func New(c ml.Config) (model.Model, error) {
             numHeads:   int(c.Uint("attention.head_count")),
             numKVHeads: int(c.Uint("attention.head_count_kv")),
             eps:        c.Float("attention.layer_norm_rms_epsilon"),
+            ctxLen:     int(c.Uint("context_length")),
             ropeBase:   c.Float("rope.freq_base"),
             ropeScale:  c.Float("rope.freq_scale", 1),
             ropeDim:    c.Uint("rope.dimension_count"),
@@ -67,14 +68,23 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
     batchSize := hiddenState.Dim(1)
     headDim := opts.hiddenSize / opts.numHeads
+    rc := ml.RopeConfig{
+        PositionIDs: positionIDs,
+        RopeFactors: opts.RopeFactors,
+        RopeDim:     opts.ropeDim,
+        RopeType:    ml.RopeTypeStandard,
+        OrigCtxLen:  opts.ctxLen,
+        RopeBase:    opts.ropeBase,
+        RopeScale:   opts.ropeScale,
+    }
 
     q := sa.Query.Forward(ctx, hiddenState)
     q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-    q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+    q = q.RoPE(ctx, rc)
 
     k := sa.Key.Forward(ctx, hiddenState)
     k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-    k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+    k = k.RoPE(ctx, rc)
 
     v := sa.Value.Forward(ctx, hiddenState)
     v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -99,7 +109,18 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-    return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
+    return key.RoPE(
+        ctx,
+        ml.RopeConfig{
+            PositionIDs: shift,
+            RopeFactors: m.Options.RopeFactors,
+            RopeDim:     m.Options.ropeDim,
+            RopeType:    ml.RopeTypeStandard,
+            OrigCtxLen:  m.Options.ctxLen,
+            RopeBase:    m.Options.ropeBase,
+            RopeScale:   m.Options.ropeScale,
+        },
+    ), nil
 }
 
 type MLP struct {
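Everything in this patch pins RopeType to ml.RopeTypeStandard. For contrast, a hypothetical model using the NeoX rotation order, the other variant the enum currently names, would change only that one field; this sketch is not part of the patch:

    // Hypothetical NeoX-style attention setup; fields mirror the llama code above.
    rc := ml.RopeConfig{
        PositionIDs: positionIDs,
        RopeDim:     opts.ropeDim,
        RopeType:    ml.RopeTypeNeoX, // rotate the two halves of each head against each other rather than adjacent pairs
        OrigCtxLen:  opts.ctxLen,
        RopeBase:    opts.ropeBase,
        RopeScale:   opts.ropeScale,
    }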
diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go
index 1e48086a3..d30c9a17e 100644
--- a/model/models/mllama/model_text.go
+++ b/model/models/mllama/model_text.go
@@ -19,14 +19,23 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
     batchSize := hiddenState.Dim(1)
     headDim := opts.hiddenSize / opts.numHeads
+    rc := ml.RopeConfig{
+        PositionIDs: positions,
+        RopeFactors: opts.RopeFactors,
+        RopeDim:     opts.ropeDim,
+        RopeType:    ml.RopeTypeStandard,
+        OrigCtxLen:  opts.ctxLen,
+        RopeBase:    opts.ropeBase,
+        RopeScale:   opts.ropeScale,
+    }
 
     query := sa.Query.Forward(ctx, hiddenState)
     query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-    query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+    query = query.RoPE(ctx, rc)
 
     key := sa.Key.Forward(ctx, hiddenState)
     key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-    key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+    key = key.RoPE(ctx, rc)
 
     value := sa.Value.Forward(ctx, hiddenState)
     value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -52,7 +61,18 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
     // This will only get called for layers in the cache, which are just the self attention layers
-    return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
+    return key.RoPE(
+        ctx,
+        ml.RopeConfig{
+            PositionIDs: shift,
+            RopeFactors: m.RopeFactors,
+            RopeDim:     m.ropeDim,
+            RopeType:    ml.RopeTypeStandard,
+            OrigCtxLen:  m.ctxLen,
+            RopeBase:    m.ropeBase,
+            RopeScale:   m.ropeScale,
+        },
+    ), nil
 }
 
 type TextMLP struct {
@@ -189,9 +209,9 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
 type TextModelOptions struct {
     RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
 
-    hiddenSize, numHeads, numKVHeads int
-    eps, ropeBase, ropeScale         float32
-    ropeDim                          uint32
+    ctxLen, hiddenSize, numHeads, numKVHeads int
+    eps, ropeBase, ropeScale                 float32
+    ropeDim                                  uint32
 
     crossAttentionLayers []uint32
 }
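With this change the same seven-field RopeConfig literal is spelled out four times across llama and mllama (Forward and Shift in each). A possible follow-up, sketched here as a hypothetical helper that is not part of this diff, would centralize the construction on each options struct:

    // Hypothetical, not in this patch: one place to build the config that
    // both Forward and Shift need for a given set of position IDs.
    func (o *Options) ropeConfig(positionIDs ml.Tensor) ml.RopeConfig {
        return ml.RopeConfig{
            PositionIDs: positionIDs,
            RopeFactors: o.RopeFactors,
            RopeDim:     o.ropeDim,
            RopeType:    ml.RopeTypeStandard,
            OrigCtxLen:  o.ctxLen,
            RopeBase:    o.ropeBase,
            RopeScale:   o.ropeScale,
        }
    }

Call sites would then shrink to q.RoPE(ctx, opts.ropeConfig(positionIDs)) in Forward and key.RoPE(ctx, m.Options.ropeConfig(shift)) in Shift.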