use with pattern for rope

This commit is contained in:
Bruce MacDonald 2025-05-12 14:14:03 -07:00
parent 533f4c41bd
commit 2d2eb5903d
3 changed files with 42 additions and 16 deletions

View File

@ -119,6 +119,25 @@ type Context interface {
Layer(int) Context
}
// RopeOpts contains optional parameters for RoPE function
type RopeOpts struct {
DefaultContextLen uint32
YarnExtFactor float32
YarnAttnFactor float32
YarnBetaFast float32
YarnBetaSlow float32
}
// RopeOption defines a function that modifies RopeOpts
type RopeOption func(*RopeOpts)
// WithContextLen sets a custom context length
func WithContextLen(len uint32) RopeOption {
return func(opts *RopeOpts) {
opts.DefaultContextLen = len
}
}
type Tensor interface {
Dim(n int) int
Stride(n int) int
@ -144,9 +163,7 @@ type Tensor interface {
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
// RoPEWithLen allows the caller to specify the rope default context length
RoPEWithLen(ctx Context, positionIDs, ropeFactors Tensor, ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32, options ...RopeOption) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Sin(ctx Context) Tensor

View File

@ -1071,12 +1071,21 @@ const (
ropeTypeVision C.int = 24
)
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
defaultContextLen := uint32(131072)
return t.RoPEWithLen(ctx, positionIDs, ropeFactors, ropeDim, ropeType, defaultContextLen, ropeBase, ropeScale)
}
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32, options ...ml.RopeOption) ml.Tensor {
// Default options
opts := &ml.RopeOpts{
DefaultContextLen: 131072,
YarnExtFactor: 0.0,
YarnAttnFactor: 1.0,
YarnBetaFast: 32.0,
YarnBetaSlow: 1.0,
}
// Apply any provided options
for _, option := range options {
option(opts)
}
func (t *Tensor) RoPEWithLen(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) ml.Tensor {
if ropeFactors == nil {
ropeFactors = &Tensor{b: t.b}
}
@ -1095,13 +1104,13 @@ func (t *Tensor) RoPEWithLen(ctx ml.Context, positionIDs, ropeFactors ml.Tensor,
ropeFactors.(*Tensor).t,
C.int(ropeDim),
C.int(ropeType),
C.int(defaultContextLen), // YaRN n_ctx_train
C.int(opts.DefaultContextLen),
C.float(ropeBase),
C.float(ropeScale),
0., // YaRN ext_factor
1., // YaRN attn_factor
32., // YaRN beta_fast
1., // YaRN beta_slow
C.float(opts.YarnExtFactor),
C.float(opts.YarnAttnFactor),
C.float(opts.YarnBetaFast),
C.float(opts.YarnBetaSlow),
),
}
}

View File

@ -78,11 +78,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPEWithLen(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.defaultContextLen, opts.ropeBase, opts.ropeScale)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPEWithLen(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.defaultContextLen, opts.ropeBase, opts.ropeScale)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -96,7 +96,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
// Shift applies rotary position embeddings to the key tensor for causal attention caching
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPEWithLen(ctx, shift, nil, m.ropeDim, 2, m.TextOptions.defaultContextLen, m.ropeBase, m.ropeScale), nil
return key.RoPE(ctx, shift, nil, m.ropeDim, 2, m.ropeBase, m.ropeScale, ml.WithContextLen(m.defaultContextLen)), nil
}
// MLP implements the feed-forward network component with SwiGLU activation