Compare commits

2 Commits: main...brucemacd/

- c259747acb
- eb086514da
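These two commits refactor the `RoPE` tensor operation from six positional arguments into a single `ml.RopeConfig` struct, and they thread two values that were previously hardcoded in the GGML backend, the rope variant and the original training context length, through that struct. The hunks below update, in order: the test tensor stub, the `ml` backend interface (where `RopeType` and `RopeConfig` are introduced), the GGML tensor implementation, and two model implementations that build the config, including their KV-cache `Shift` paths.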
```diff
@@ -430,7 +430,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0
 	panic("not implemented")
 }
 
-func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
+func (t *testTensor) RoPE(ctx ml.Context, rc ml.RopeConfig) ml.Tensor {
 	panic("not implemented")
 }
 
```
```diff
@@ -43,6 +43,42 @@ func NewBackend(f *os.File) (Backend, error) {
 	return nil, fmt.Errorf("unsupported backend")
 }
 
+// RopeType specifies the type of RoPE (Rotary Position Embedding) to use, these types are implemented in the backend
+type RopeType int
+
+const (
+	RopeTypeStandard RopeType = iota
+	_ // not yet used
+	RopeTypeNeoX
+)
+
+// RopeConfig contains all configuration for the RoPE (Rotary Position Embedding) operation
+type RopeConfig struct {
+	// PositionIDs contains the position indices for each token in the sequence
+	// These indices are used to calculate the rotary embeddings
+	PositionIDs Tensor
+
+	// RopeFactors is an optional tensor containing pre-computed rotation factors
+	RopeFactors Tensor
+
+	// RopeDim specifies the dimension size for the rotary embeddings
+	RopeDim uint32
+
+	// RopeType indicates which RoPE variant to use (e.g. normal or neox)
+	RopeType RopeType
+
+	// OrigCtxLen stores the original context length the model was trained with
+	OrigCtxLen int
+
+	// RopeBase is the base value used in the frequency calculation
+	RopeBase float32
+
+	// RopeScale is a scaling factor applied to position indices
+	RopeScale float32
+
+	// YaRN parameters can be added here if they need to be configurable
+}
+
 type Context interface {
 	Zeros(dtype DType, shape ...int) Tensor
 	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
```
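As a rough sketch of the new call shape (not part of the diff; the helper name and its parameters are illustrative assumptions), a caller now fills in one struct instead of passing six positional arguments. `RopeBase` and `RopeScale` feed the usual RoPE frequency rule, roughly angle_i = pos * scale * base^(-2i/dim):

```go
// applyStandardRoPE is a hypothetical helper, not part of this diff,
// showing how a model would call the new API given the ml package above.
func applyStandardRoPE(ctx ml.Context, t, positions, factors ml.Tensor,
	dim uint32, trainCtxLen int, base, scale float32) ml.Tensor {
	return t.RoPE(ctx, ml.RopeConfig{
		PositionIDs: positions,
		RopeFactors: factors, // optional; the GGML backend tolerates nil
		RopeDim:     dim,
		RopeType:    ml.RopeTypeStandard,
		OrigCtxLen:  trainCtxLen,
		RopeBase:    base,
		RopeScale:   scale,
	})
}
```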
```diff
@@ -75,7 +111,7 @@ type Tensor interface {
 	Scale(ctx Context, s float64) Tensor
 
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
+	RoPE(ctx Context, rc RopeConfig) Tensor
 
 	Tanh(ctx Context) Tensor
 	GELU(ctx Context) Tensor
```
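Because `RoPE` is part of the `Tensor` interface, this one-line signature change is what forces the rest of the diff: every implementation (the test stub above, the GGML-backed tensor next) and every model call site must move to the struct form.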
```diff
@@ -579,13 +579,9 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	}
 }
 
-const (
-	ropeTypeNorm C.int = iota
-)
-
-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
-	if ropeFactors == nil {
-		ropeFactors = &Tensor{}
+func (t *Tensor) RoPE(ctx ml.Context, rc ml.RopeConfig) ml.Tensor {
+	if rc.RopeFactors == nil {
+		rc.RopeFactors = &Tensor{}
 	}
 
 	dequant := t.t
```
```diff
@@ -595,12 +591,15 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
 
 	return &Tensor{
 		t: C.ggml_rope_ext(
-			ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
-			C.int(ropeDim),
-			131072, // YaRN n_ctx_train
-			ropeTypeNorm, // ROPE_TYPE_NORM
-			C.float(ropeBase),
-			C.float(ropeScale),
+			ctx.(*Context).ctx,
+			dequant,
+			rc.PositionIDs.(*Tensor).t,
+			rc.RopeFactors.(*Tensor).t,
+			C.int(rc.RopeDim),
+			C.int(rc.RopeType),
+			C.int(rc.OrigCtxLen),
+			C.float(rc.RopeBase),
+			C.float(rc.RopeScale),
 			0., // YaRN ext_factor
 			1., // YaRN attn_factor
 			32., // YaRN beta_fast
```
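Two previously hardcoded values in the GGML call now come from the config: the training context length (the literal `131072` before) is taken from `rc.OrigCtxLen`, and the rope mode (the deleted local `ropeTypeNorm` constant) from `rc.RopeType`. The blank identifier in the `RopeType` const block looks deliberate: it keeps `RopeTypeStandard` at 0 and `RopeTypeNeoX` at 2, matching GGML's NORM and NEOX rope modes, so `C.int(rc.RopeType)` can be passed to `ggml_rope_ext` unchanged.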
```diff
@@ -10,10 +10,10 @@ import (
 )
 
 type Options struct {
 	RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
-	hiddenSize, numHeads, numKVHeads int
+	origCtxLen, hiddenSize, numHeads, numKVHeads int
 	eps, ropeBase, ropeScale float32
 	ropeDim uint32
 }
 
 type Model struct {
```
```diff
@@ -46,6 +46,7 @@ func New(c ml.Config) (model.Model, error) {
 		numHeads:   int(c.Uint("attention.head_count")),
 		numKVHeads: int(c.Uint("attention.head_count_kv")),
 		eps:        c.Float("attention.layer_norm_rms_epsilon"),
+		origCtxLen: int(c.Uint("context_length")),
 		ropeBase:   c.Float("rope.freq_base"),
 		ropeScale:  c.Float("rope.freq_scale", 1),
 		ropeDim:    c.Uint("rope.dimension_count"),
```
```diff
@@ -67,14 +68,23 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
+	rc := ml.RopeConfig{
+		PositionIDs: positionIDs,
+		RopeFactors: opts.RopeFactors,
+		RopeDim:     opts.ropeDim,
+		RopeType:    ml.RopeTypeStandard,
+		OrigCtxLen:  opts.origCtxLen,
+		RopeBase:    opts.ropeBase,
+		RopeScale:   opts.ropeScale,
+	}
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, rc)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, rc)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
```
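Building `rc` once per forward pass and reusing it for the query and key rotations replaces two six-argument calls and guarantees both use identical parameters; the config is a small value struct, so the extra copy is cheap.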
```diff
@@ -99,7 +109,18 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
+	return key.RoPE(
+		ctx,
+		ml.RopeConfig{
+			PositionIDs: shift,
+			RopeFactors: m.Options.RopeFactors,
+			RopeDim:     m.Options.ropeDim,
+			RopeType:    ml.RopeTypeStandard,
+			OrigCtxLen:  m.Options.origCtxLen,
+			RopeBase:    m.Options.ropeBase,
+			RopeScale:   m.Options.ropeScale,
+		},
+	), nil
 }
 
 type MLP struct {
```
```diff
@@ -19,14 +19,23 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
+	rc := ml.RopeConfig{
+		PositionIDs: positions,
+		RopeFactors: opts.RopeFactors,
+		RopeDim:     opts.ropeDim,
+		RopeType:    ml.RopeTypeStandard,
+		OrigCtxLen:  opts.ctxLen,
+		RopeBase:    opts.ropeBase,
+		RopeScale:   opts.ropeScale,
+	}
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	query = query.RoPE(ctx, rc)
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	key = key.RoPE(ctx, rc)
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
```
```diff
@@ -52,7 +61,18 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	// This will only get called for layers in the cache, which are just the self attention layers
-	return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
+	return key.RoPE(
+		ctx,
+		ml.RopeConfig{
+			PositionIDs: shift,
+			RopeFactors: m.RopeFactors,
+			RopeDim:     m.ropeDim,
+			RopeType:    ml.RopeTypeStandard,
+			OrigCtxLen:  m.ctxLen,
+			RopeBase:    m.ropeBase,
+			RopeScale:   m.ropeScale,
+		},
+	), nil
 }
 
 type TextMLP struct {
```
```diff
@@ -189,9 +209,9 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
 type TextModelOptions struct {
 	RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
 
-	hiddenSize, numHeads, numKVHeads int
+	ctxLen, hiddenSize, numHeads, numKVHeads int
 	eps, ropeBase, ropeScale float32
 	ropeDim uint32
 
 	crossAttentionLayers []uint32
 }
```