use with pattern for rope
This commit is contained in:
parent
533f4c41bd
commit
2d2eb5903d
@ -119,6 +119,25 @@ type Context interface {
|
||||
Layer(int) Context
|
||||
}
|
||||
|
||||
// RopeOpts contains optional parameters for RoPE function
|
||||
type RopeOpts struct {
|
||||
DefaultContextLen uint32
|
||||
YarnExtFactor float32
|
||||
YarnAttnFactor float32
|
||||
YarnBetaFast float32
|
||||
YarnBetaSlow float32
|
||||
}
|
||||
|
||||
// RopeOption defines a function that modifies RopeOpts
|
||||
type RopeOption func(*RopeOpts)
|
||||
|
||||
// WithContextLen sets a custom context length
|
||||
func WithContextLen(len uint32) RopeOption {
|
||||
return func(opts *RopeOpts) {
|
||||
opts.DefaultContextLen = len
|
||||
}
|
||||
}
|
||||
|
||||
type Tensor interface {
|
||||
Dim(n int) int
|
||||
Stride(n int) int
|
||||
@ -144,9 +163,7 @@ type Tensor interface {
|
||||
AvgPool2D(ctx Context, k, s int, p float32) Tensor
|
||||
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
|
||||
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
|
||||
// RoPEWithLen allows the caller to specify the rope default context length
|
||||
RoPEWithLen(ctx Context, positionIDs, ropeFactors Tensor, ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) Tensor
|
||||
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32, options ...RopeOption) Tensor
|
||||
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
|
||||
Sin(ctx Context) Tensor
|
||||
|
@ -1071,12 +1071,21 @@ const (
|
||||
ropeTypeVision C.int = 24
|
||||
)
|
||||
|
||||
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
|
||||
defaultContextLen := uint32(131072)
|
||||
return t.RoPEWithLen(ctx, positionIDs, ropeFactors, ropeDim, ropeType, defaultContextLen, ropeBase, ropeScale)
|
||||
}
|
||||
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32, options ...ml.RopeOption) ml.Tensor {
|
||||
// Default options
|
||||
opts := &ml.RopeOpts{
|
||||
DefaultContextLen: 131072,
|
||||
YarnExtFactor: 0.0,
|
||||
YarnAttnFactor: 1.0,
|
||||
YarnBetaFast: 32.0,
|
||||
YarnBetaSlow: 1.0,
|
||||
}
|
||||
|
||||
// Apply any provided options
|
||||
for _, option := range options {
|
||||
option(opts)
|
||||
}
|
||||
|
||||
func (t *Tensor) RoPEWithLen(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) ml.Tensor {
|
||||
if ropeFactors == nil {
|
||||
ropeFactors = &Tensor{b: t.b}
|
||||
}
|
||||
@ -1095,13 +1104,13 @@ func (t *Tensor) RoPEWithLen(ctx ml.Context, positionIDs, ropeFactors ml.Tensor,
|
||||
ropeFactors.(*Tensor).t,
|
||||
C.int(ropeDim),
|
||||
C.int(ropeType),
|
||||
C.int(defaultContextLen), // YaRN n_ctx_train
|
||||
C.int(opts.DefaultContextLen),
|
||||
C.float(ropeBase),
|
||||
C.float(ropeScale),
|
||||
0., // YaRN ext_factor
|
||||
1., // YaRN attn_factor
|
||||
32., // YaRN beta_fast
|
||||
1., // YaRN beta_slow
|
||||
C.float(opts.YarnExtFactor),
|
||||
C.float(opts.YarnAttnFactor),
|
||||
C.float(opts.YarnBetaFast),
|
||||
C.float(opts.YarnBetaSlow),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
@ -78,11 +78,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = q.RoPEWithLen(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.defaultContextLen, opts.ropeBase, opts.ropeScale)
|
||||
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = k.RoPEWithLen(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.defaultContextLen, opts.ropeBase, opts.ropeScale)
|
||||
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
@ -96,7 +96,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
|
||||
// Shift applies rotary position embeddings to the key tensor for causal attention caching
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return key.RoPEWithLen(ctx, shift, nil, m.ropeDim, 2, m.TextOptions.defaultContextLen, m.ropeBase, m.ropeScale), nil
|
||||
return key.RoPE(ctx, shift, nil, m.ropeDim, 2, m.ropeBase, m.ropeScale, ml.WithContextLen(m.defaultContextLen)), nil
|
||||
}
|
||||
|
||||
// MLP implements the feed-forward network component with SwiGLU activation
|
||||
|
Loading…
x
Reference in New Issue
Block a user