use functional options ("With") pattern for RoPE

Bruce MacDonald 2025-05-12 14:14:03 -07:00
parent 533f4c41bd
commit 2d2eb5903d
3 changed files with 42 additions and 16 deletions
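The change swaps the dedicated RoPEWithLen method for Go's functional options pattern: RoPE now takes variadic RopeOption values, and callers override individual defaults (such as the training context length) via WithContextLen. A minimal, self-contained sketch of the pattern as used here; the type and field names mirror the diff below, and the free rope function is only a stand-in for Tensor.RoPE:

package main

import "fmt"

// RopeOpts mirrors the optional RoPE parameters introduced in this commit
// (trimmed to two fields for the sketch).
type RopeOpts struct {
	DefaultContextLen uint32
	YarnAttnFactor    float32
}

// RopeOption is a function that mutates RopeOpts.
type RopeOption func(*RopeOpts)

// WithContextLen overrides the default training context length.
func WithContextLen(n uint32) RopeOption {
	return func(opts *RopeOpts) { opts.DefaultContextLen = n }
}

// rope is a stand-in for Tensor.RoPE: set the defaults, then apply each
// option in argument order, so a later option overrides an earlier one.
func rope(options ...RopeOption) RopeOpts {
	opts := RopeOpts{DefaultContextLen: 131072, YarnAttnFactor: 1.0}
	for _, option := range options {
		option(&opts)
	}
	return opts
}

func main() {
	fmt.Println(rope())                     // {131072 1}
	fmt.Println(rope(WithContextLen(4096))) // {4096 1}
}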

View File

@@ -119,6 +119,25 @@ type Context interface {
 	Layer(int) Context
 }
 
+// RopeOpts contains optional parameters for RoPE function
+type RopeOpts struct {
+	DefaultContextLen uint32
+	YarnExtFactor     float32
+	YarnAttnFactor    float32
+	YarnBetaFast      float32
+	YarnBetaSlow      float32
+}
+
+// RopeOption defines a function that modifies RopeOpts
+type RopeOption func(*RopeOpts)
+
+// WithContextLen sets a custom context length
+func WithContextLen(len uint32) RopeOption {
+	return func(opts *RopeOpts) {
+		opts.DefaultContextLen = len
+	}
+}
+
 type Tensor interface {
 	Dim(n int) int
 	Stride(n int) int
@@ -144,9 +163,7 @@ type Tensor interface {
 	AvgPool2D(ctx Context, k, s int, p float32) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
 
-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
-	// RoPEWithLen allows the caller to specify the rope default context length
-	RoPEWithLen(ctx Context, positionIDs, ropeFactors Tensor, ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) Tensor
+	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32, options ...RopeOption) Tensor
 
 	IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
 	Sin(ctx Context) Tensor
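Only WithContextLen is added here; the remaining RopeOpts fields keep their defaults. Exposing another field would follow the same shape; for example, a hypothetical option (not part of this commit):

// WithYarnAttnFactor is hypothetical, shown only to illustrate the pattern;
// this commit does not add it.
func WithYarnAttnFactor(f float32) RopeOption {
	return func(opts *RopeOpts) {
		opts.YarnAttnFactor = f
	}
}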

View File

@@ -1071,12 +1071,21 @@ const (
 	ropeTypeVision C.int = 24
 )
 
-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
-	defaultContextLen := uint32(131072)
-	return t.RoPEWithLen(ctx, positionIDs, ropeFactors, ropeDim, ropeType, defaultContextLen, ropeBase, ropeScale)
-}
+func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32, options ...ml.RopeOption) ml.Tensor {
+	// Default options
+	opts := &ml.RopeOpts{
+		DefaultContextLen: 131072,
+		YarnExtFactor:     0.0,
+		YarnAttnFactor:    1.0,
+		YarnBetaFast:      32.0,
+		YarnBetaSlow:      1.0,
+	}
+
+	// Apply any provided options
+	for _, option := range options {
+		option(opts)
+	}
 
-func (t *Tensor) RoPEWithLen(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) ml.Tensor {
 	if ropeFactors == nil {
 		ropeFactors = &Tensor{b: t.b}
 	}
@@ -1095,13 +1104,13 @@ func (t *Tensor) RoPEWithLen(ctx ml.Context, positionIDs, ropeFactors ml.Tensor,
 			ropeFactors.(*Tensor).t,
 			C.int(ropeDim),
 			C.int(ropeType),
-			C.int(defaultContextLen), // YaRN n_ctx_train
+			C.int(opts.DefaultContextLen),
 			C.float(ropeBase),
 			C.float(ropeScale),
-			0.,  // YaRN ext_factor
-			1.,  // YaRN attn_factor
-			32., // YaRN beta_fast
-			1.,  // YaRN beta_slow
+			C.float(opts.YarnExtFactor),
+			C.float(opts.YarnAttnFactor),
+			C.float(opts.YarnBetaFast),
+			C.float(opts.YarnBetaSlow),
 		),
 	}
 }
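The defaults in opts reproduce the values that were previously hard-coded (131072 for the YaRN n_ctx_train and 0/1/32/1 for the YaRN factors), so a call with no options behaves exactly like the old RoPE, and WithContextLen covers what RoPEWithLen did. A usage sketch, not code from this commit; ctx, t, positions, and the other parameters are assumed to come from the caller, and ml refers to the package defined in the first file:

// Sketch only, illustrating calls against the new signature.
func ropeExamples(ctx ml.Context, t, positions ml.Tensor, dim uint32, base, scale float32) (ml.Tensor, ml.Tensor) {
	unchanged := t.RoPE(ctx, positions, nil, dim, 2, base, scale)                        // old default behavior
	longCtx := t.RoPE(ctx, positions, nil, dim, 2, base, scale, ml.WithContextLen(4096)) // replaces RoPEWithLen
	return unchanged, longCtx
}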

View File

@@ -78,11 +78,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPEWithLen(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.defaultContextLen, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPEWithLen(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.defaultContextLen, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -96,7 +96,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 // Shift applies rotary position embeddings to the key tensor for causal attention caching
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPEWithLen(ctx, shift, nil, m.ropeDim, 2, m.TextOptions.defaultContextLen, m.ropeBase, m.ropeScale), nil
+	return key.RoPE(ctx, shift, nil, m.ropeDim, 2, m.ropeBase, m.ropeScale, ml.WithContextLen(m.defaultContextLen)), nil
 }
 
 // MLP implements the feed-forward network component with SwiGLU activation