diff --git a/ml/backend.go b/ml/backend.go
index 51806b735..f68d6193d 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -119,6 +119,25 @@ type Context interface {
 	Layer(int) Context
 }
 
+// RopeOpts contains optional parameters for the RoPE function
+type RopeOpts struct {
+	DefaultContextLen uint32
+	YarnExtFactor     float32
+	YarnAttnFactor    float32
+	YarnBetaFast      float32
+	YarnBetaSlow      float32
+}
+
+// RopeOption defines a function that modifies RopeOpts
+type RopeOption func(*RopeOpts)
+
+// WithContextLen sets a custom context length
+func WithContextLen(contextLen uint32) RopeOption {
+	return func(opts *RopeOpts) {
+		opts.DefaultContextLen = contextLen
+	}
+}
+
 type Tensor interface {
 	Dim(n int) int
 	Stride(n int) int
@@ -144,9 +163,7 @@ type Tensor interface {
 	AvgPool2D(ctx Context, k, s int, p float32) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
 
-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
-	// RoPEWithLen allows the caller to specify the rope default context length
-	RoPEWithLen(ctx Context, positionIDs, ropeFactors Tensor, ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) Tensor
+	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32, options ...RopeOption) Tensor
 
 	IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
 	Sin(ctx Context) Tensor
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index a4faed4bf..4015b14b8 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -1071,12 +1071,21 @@ const (
 	ropeTypeVision C.int = 24
 )
 
-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
-	defaultContextLen := uint32(131072)
-	return t.RoPEWithLen(ctx, positionIDs, ropeFactors, ropeDim, ropeType, defaultContextLen, ropeBase, ropeScale)
-}
+func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32, options ...ml.RopeOption) ml.Tensor {
+	// Default options
+	opts := &ml.RopeOpts{
+		DefaultContextLen: 131072,
+		YarnExtFactor:     0.0,
+		YarnAttnFactor:    1.0,
+		YarnBetaFast:      32.0,
+		YarnBetaSlow:      1.0,
+	}
+
+	// Apply any provided options
+	for _, option := range options {
+		option(opts)
+	}
 
-func (t *Tensor) RoPEWithLen(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) ml.Tensor {
 	if ropeFactors == nil {
 		ropeFactors = &Tensor{b: t.b}
 	}
@@ -1095,13 +1104,13 @@ func (t *Tensor) RoPEWithLen(ctx ml.Context, positionIDs, ropeFactors ml.Tensor,
 		ropeFactors.(*Tensor).t,
 		C.int(ropeDim),
 		C.int(ropeType),
-		C.int(defaultContextLen), // YaRN n_ctx_train
+		C.int(opts.DefaultContextLen),
 		C.float(ropeBase),
 		C.float(ropeScale),
-		0.,  // YaRN ext_factor
-		1.,  // YaRN attn_factor
-		32., // YaRN beta_fast
-		1.,  // YaRN beta_slow
+		C.float(opts.YarnExtFactor),
+		C.float(opts.YarnAttnFactor),
+		C.float(opts.YarnBetaFast),
+		C.float(opts.YarnBetaSlow),
 		),
 	}
 }
diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go
index d358b21a5..135616dba 100644
--- a/model/models/qwen25vl/model_text.go
+++ b/model/models/qwen25vl/model_text.go
@@ -78,11 +78,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPEWithLen(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.defaultContextLen, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPEWithLen(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.defaultContextLen, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
@@ -96,7 +96,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 // Shift applies rotary position embeddings to the key tensor for causal attention caching
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPEWithLen(ctx, shift, nil, m.ropeDim, 2, m.TextOptions.defaultContextLen, m.ropeBase, m.ropeScale), nil
+	return key.RoPE(ctx, shift, nil, m.ropeDim, 2, m.ropeBase, m.ropeScale, ml.WithContextLen(m.defaultContextLen)), nil
 }
 
 // MLP implements the feed-forward network component with SwiGLU activation
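For reviewers, here is a minimal, self-contained sketch of the functional-options pattern this diff introduces. `RopeOpts`, `RopeOption`, and `WithContextLen` mirror the diff; `resolveRopeOpts` and `WithYarnExtFactor` are hypothetical names used only to illustrate how further YaRN parameters could be exposed later without another signature change.

```go
package main

import "fmt"

// RopeOpts mirrors the struct added to ml/backend.go.
type RopeOpts struct {
	DefaultContextLen uint32
	YarnExtFactor     float32
	YarnAttnFactor    float32
	YarnBetaFast      float32
	YarnBetaSlow      float32
}

// RopeOption defines a function that modifies RopeOpts.
type RopeOption func(*RopeOpts)

// WithContextLen sets a custom context length (as in the diff).
func WithContextLen(contextLen uint32) RopeOption {
	return func(opts *RopeOpts) { opts.DefaultContextLen = contextLen }
}

// WithYarnExtFactor is a hypothetical future option (not part of this diff)
// following the same shape.
func WithYarnExtFactor(f float32) RopeOption {
	return func(opts *RopeOpts) { opts.YarnExtFactor = f }
}

// resolveRopeOpts applies caller-supplied options over the defaults,
// mirroring the loop at the top of the new Tensor.RoPE implementation.
func resolveRopeOpts(options ...RopeOption) *RopeOpts {
	opts := &RopeOpts{
		DefaultContextLen: 131072,
		YarnAttnFactor:    1.0,
		YarnBetaFast:      32.0,
		YarnBetaSlow:      1.0,
	}
	for _, option := range options {
		option(opts)
	}
	return opts
}

func main() {
	// A RoPEWithLen-style caller now passes an option instead of an extra argument:
	fmt.Printf("%+v\n", resolveRopeOpts(WithContextLen(32768)))
	// Callers that never needed the extra parameter pass nothing at all:
	fmt.Printf("%+v\n", resolveRopeOpts())
}
```

This appears to be the motivation for folding `RoPEWithLen` back into `RoPE`: call sites that only need the defaults keep the short signature, qwen25vl overrides just the YaRN training context length via `ml.WithContextLen`, and future YaRN knobs can be added as new options rather than new methods.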