use With pattern for RoPE
parent 533f4c41bd
commit 2d2eb5903d
@@ -119,6 +119,25 @@ type Context interface {
 	Layer(int) Context
 }
 
+// RopeOpts contains optional parameters for RoPE function
+type RopeOpts struct {
+	DefaultContextLen uint32
+	YarnExtFactor     float32
+	YarnAttnFactor    float32
+	YarnBetaFast      float32
+	YarnBetaSlow      float32
+}
+
+// RopeOption defines a function that modifies RopeOpts
+type RopeOption func(*RopeOpts)
+
+// WithContextLen sets a custom context length
+func WithContextLen(len uint32) RopeOption {
+	return func(opts *RopeOpts) {
+		opts.DefaultContextLen = len
+	}
+}
+
 type Tensor interface {
 	Dim(n int) int
 	Stride(n int) int
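Only `WithContextLen` gets a constructor in this hunk; the YaRN fields can still only be set by writing a `RopeOption` closure by hand. As a sketch of how further constructors could look if they were added later (these are hypothetical and not part of this commit), they would sit next to `WithContextLen` in the same package and follow the same shape:

```go
// Hypothetical companion constructors, not part of this commit; shown only to
// illustrate how the RopeOption pattern extends to the remaining YaRN fields.
func WithYarnExtFactor(f float32) RopeOption {
	return func(opts *RopeOpts) {
		opts.YarnExtFactor = f
	}
}

func WithYarnAttnFactor(f float32) RopeOption {
	return func(opts *RopeOpts) {
		opts.YarnAttnFactor = f
	}
}
```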
@@ -144,9 +163,7 @@ type Tensor interface {
 	AvgPool2D(ctx Context, k, s int, p float32) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
 
-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
-	// RoPEWithLen allows the caller to specify the rope default context length
-	RoPEWithLen(ctx Context, positionIDs, ropeFactors Tensor, ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) Tensor
+	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32, options ...RopeOption) Tensor
 	IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
 
 	Sin(ctx Context) Tensor
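Because the new parameter is variadic, callers that relied on the built-in default keep working unchanged, while callers that previously used `RoPEWithLen` pass an option instead. A rough sketch of both call shapes, where the tensor and parameter names (`hidden`, `positions`, `factors`, and so on) are placeholders rather than names from this commit:

```go
// Default context length (the plain RoPE call, unchanged for existing callers):
out := hidden.RoPE(ctx, positions, factors, ropeDim, ropeType, ropeBase, ropeScale)

// Explicit context length (what RoPEWithLen used to express):
out = hidden.RoPE(ctx, positions, factors, ropeDim, ropeType, ropeBase, ropeScale,
	ml.WithContextLen(4096))
```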
@@ -1071,12 +1071,21 @@ const (
 	ropeTypeVision C.int = 24
 )
 
-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
-	defaultContextLen := uint32(131072)
-	return t.RoPEWithLen(ctx, positionIDs, ropeFactors, ropeDim, ropeType, defaultContextLen, ropeBase, ropeScale)
-}
+func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32, options ...ml.RopeOption) ml.Tensor {
+	// Default options
+	opts := &ml.RopeOpts{
+		DefaultContextLen: 131072,
+		YarnExtFactor:     0.0,
+		YarnAttnFactor:    1.0,
+		YarnBetaFast:      32.0,
+		YarnBetaSlow:      1.0,
+	}
+
+	// Apply any provided options
+	for _, option := range options {
+		option(opts)
+	}
 
-func (t *Tensor) RoPEWithLen(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) ml.Tensor {
 	if ropeFactors == nil {
 		ropeFactors = &Tensor{b: t.b}
 	}
@@ -1095,13 +1104,13 @@ func (t *Tensor) RoPEWithLen(ctx ml.Context, positionIDs, ropeFactors ml.Tensor,
 			ropeFactors.(*Tensor).t,
 			C.int(ropeDim),
 			C.int(ropeType),
-			C.int(defaultContextLen), // YaRN n_ctx_train
+			C.int(opts.DefaultContextLen),
 			C.float(ropeBase),
 			C.float(ropeScale),
-			0.,  // YaRN ext_factor
-			1.,  // YaRN attn_factor
-			32., // YaRN beta_fast
-			1.,  // YaRN beta_slow
+			C.float(opts.YarnExtFactor),
+			C.float(opts.YarnAttnFactor),
+			C.float(opts.YarnBetaFast),
+			C.float(opts.YarnBetaSlow),
 		),
 	}
 }
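The mechanic here is the standard Go functional-options pattern: fill a struct with defaults, let each supplied option mutate it, then hand the resulting values to the backend call. A self-contained sketch of that mechanic, using standalone copies of the types rather than the real ml package:

```go
package main

import "fmt"

// Standalone copies of the option machinery, for illustration only;
// the real RopeOpts/RopeOption definitions live in the ml package.
type ropeOpts struct {
	defaultContextLen uint32
	yarnAttnFactor    float32
}

type ropeOption func(*ropeOpts)

func withContextLen(n uint32) ropeOption {
	return func(o *ropeOpts) { o.defaultContextLen = n }
}

// applyRopeOptions mirrors what the RoPE method does: start from defaults,
// then apply each caller-supplied option in order.
func applyRopeOptions(options ...ropeOption) ropeOpts {
	opts := ropeOpts{defaultContextLen: 131072, yarnAttnFactor: 1.0}
	for _, option := range options {
		option(&opts)
	}
	return opts
}

func main() {
	fmt.Println(applyRopeOptions())                     // {131072 1}
	fmt.Println(applyRopeOptions(withContextLen(4096))) // {4096 1}
}
```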
@@ -78,11 +78,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPEWithLen(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.defaultContextLen, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPEWithLen(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.defaultContextLen, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
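At these call sites the model's trained context length is now threaded through `ml.WithContextLen` instead of a positional argument, so any further knobs would simply be appended as extra options. A hedged illustration of such a combined call, where `WithYarnAttnFactor` is the hypothetical constructor sketched earlier and not part of this commit:

```go
// Hypothetical: combining several options at a call site.
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale,
	ml.WithContextLen(opts.defaultContextLen),
	ml.WithYarnAttnFactor(1.0), // hypothetical option, see the earlier sketch
)
```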
@@ -96,7 +96,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 // Shift applies rotary position embeddings to the key tensor for causal attention caching
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPEWithLen(ctx, shift, nil, m.ropeDim, 2, m.TextOptions.defaultContextLen, m.ropeBase, m.ropeScale), nil
+	return key.RoPE(ctx, shift, nil, m.ropeDim, 2, m.ropeBase, m.ropeScale, ml.WithContextLen(m.defaultContextLen)), nil
 }
 
 // MLP implements the feed-forward network component with SwiGLU activation