Compare commits
3 commits: main ... brucemacd/

Commits: 159821594c, cbeb2aab4f, 96df15edfc
@@ -462,7 +462,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0
 	panic("not implemented")
 }

-func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
+func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
 	panic("not implemented")
 }

@@ -118,6 +118,53 @@ type Context interface {
 	Layer(int) Context
 }

+// RopeType represents different RoPE (Rotary Position Embedding) implementation types
+type RopeType int
+
+// Available RoPE implementation types
+const (
+	RopeTypeNormal RopeType = iota // Standard RoPE implementation
+	RopeTypeNeox                   // NeoX-style RoPE implementation
+	RopeTypeMRoPE                  // Multimodal RoPE implementation
+	RopeTypeVision                 // Vision-specific RoPE implementation
+)
+
+type YarnConfig struct {
+	YarnCtxTrain   int     // Context size used during training (for YaRN scaling)
+	YarnExtFactor  float32 // Extension factor for YaRN
+	YarnAttnFactor float32 // Attention scaling factor for YaRN
+	YarnBetaFast   float32 // Fast decay parameter for YaRN
+	YarnBetaSlow   float32 // Slow decay parameter for YaRN
+}
+
+// DefaultYarnConfig returns a default configuration for YaRN (Yet Another Rope Extension)
+func DefaultYarnConfig(nCtx int32) *YarnConfig {
+	return &YarnConfig{
+		YarnCtxTrain:   int(nCtx),
+		YarnExtFactor:  0.0,
+		YarnAttnFactor: 1.0,
+		YarnBetaFast:   32.0,
+		YarnBetaSlow:   1.0,
+	}
+}
+
+// RoPEConfig holds configuration for Rotary Position Embedding
+type RoPEConfig struct {
+	// Dim is the dimensionality for applying rotary embeddings
+	Dim uint32
+
+	// Type specifies the RoPE implementation variant
+	Type RopeType
+
+	// Base controls frequency decay for the embeddings
+	Base float32
+
+	// Scale allows scaling the effective context length
+	Scale float32
+
+	*YarnConfig
+}
+
 type Tensor interface {
 	Dim(n int) int
 	Stride(n int) int
@@ -141,7 +188,7 @@ type Tensor interface {
 	AvgPool2D(ctx Context, k, s int, p float32) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
+	RoPE(ctx Context, positionIDs, ropeFactors Tensor, config RoPEConfig) Tensor

 	Tanh(ctx Context) Tensor
 	GELU(ctx Context) Tensor
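For orientation, a minimal sketch of how a caller might use the new interface. The helper name and the literal values here are illustrative assumptions, not taken from any model in this change; only the `ml.RoPEConfig` fields, `ml.RopeTypeNeox`, and `ml.DefaultYarnConfig` come from the diff above.

// applyRoPE is a hypothetical helper showing the new call shape; ctx, q,
// and positionIDs would come from the surrounding model code.
func applyRoPE(ctx ml.Context, q, positionIDs ml.Tensor) ml.Tensor {
	cfg := ml.RoPEConfig{
		Dim:        128,             // rotary dimensions (illustrative)
		Type:       ml.RopeTypeNeox, // implementation variant
		Base:       10000.0,         // frequency base
		Scale:      1.0,             // context-length scaling
		YarnConfig: ml.DefaultYarnConfig(131072), // fills the YaRN fields
	}
	// ropeFactors may be nil; the ggml backend substitutes an empty tensor.
	return q.RoPE(ctx, positionIDs, nil, cfg)
}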
@@ -907,6 +907,8 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	}
 }

+// GGML RoPE types
+// These are the types used in the C implementation of RoPE
 const (
 	ropeTypeNorm C.int = 0
 	ropeTypeNeox C.int = 2
@@ -914,7 +916,8 @@ const (
 	ropeTypeVision C.int = 24
 )

-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
+// RoPE applies Rotary Position Embeddings to the tensor
+func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
 	if ropeFactors == nil {
 		ropeFactors = &Tensor{b: t.b}
 	}
@@ -924,19 +927,41 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
 		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
 	}

+	if config.YarnConfig == nil {
+		config.YarnConfig = ml.DefaultYarnConfig(131072) // 131072 is the default for LLaMA, so it is common at the time of writing
+	}
+
+	// Map Go RopeType to C implementation constants
+	var ropeTypeC C.int
+	switch config.Type {
+	case ml.RopeTypeNormal:
+		ropeTypeC = ropeTypeNorm
+	case ml.RopeTypeNeox:
+		ropeTypeC = ropeTypeNeox
+	case ml.RopeTypeMRoPE:
+		ropeTypeC = ropeTypeMrope
+	case ml.RopeTypeVision:
+		ropeTypeC = ropeTypeVision
+	default:
+		ropeTypeC = ropeTypeNorm
+	}
+
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_rope_ext(
-			ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
-			C.int(ropeDim),
-			C.int(ropeType),
-			131072, // YaRN n_ctx_train
-			C.float(ropeBase),
-			C.float(ropeScale),
-			0., // YaRN ext_factor
-			1., // YaRN attn_factor
-			32., // YaRN beta_fast
-			1., // YaRN beta_slow
+			ctx.(*Context).ctx,
+			dequant,
+			positionIDs.(*Tensor).t,
+			ropeFactors.(*Tensor).t,
+			C.int(config.Dim),
+			ropeTypeC,
+			C.int(config.YarnCtxTrain),
+			C.float(config.Base),
+			C.float(config.Scale),
+			C.float(config.YarnExtFactor),
+			C.float(config.YarnAttnFactor),
+			C.float(config.YarnBetaFast),
+			C.float(config.YarnBetaSlow),
 		),
 	}
 }
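Aside: the Go-side `RopeType` enum is a contiguous `iota` sequence (Normal=0, Neox=1, MRoPE=2, Vision=3), while the GGML constants visible above are sparse (0, 2, ..., 24), which is why the switch maps values explicitly instead of casting. A standalone sketch of the mismatch the mapping avoids; the mirrored constants below are copied from the diff, not imported:

package main

import "fmt"

// Mirror of the Go-side enum added in this change (plain iota sequence).
const (
	RopeTypeNormal = iota // 0
	RopeTypeNeox          // 1
	RopeTypeMRoPE         // 2
	RopeTypeVision        // 3
)

func main() {
	// The GGML constant for NeoX-style RoPE is 2 (ropeTypeNeox above), so a
	// direct integer cast of the Go value would select the wrong C variant.
	fmt.Println(RopeTypeNeox) // prints 1, not 2
}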
@@ -13,10 +13,11 @@ import (
 type Options struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
-	eps, ropeBase, ropeScale         float32
+	eps                              float32
 	attnLogitSoftcap                 float32
 	finalLogitSoftcap                float32
 	largeModelScaling                bool
+	ropeConfig                       ml.RoPEConfig
 }

 type Model struct {
@@ -55,10 +56,15 @@ func New(c ml.Config) (model.Model, error) {
 			attnKeyLen:        int(c.Uint("attention.key_length")),
 			attnValLen:        int(c.Uint("attention.value_length")),
 			eps:               c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:          c.Float("rope.freq_base", 10000.0),
-			ropeScale:         c.Float("rope.freq_scale", 1.0),
 			attnLogitSoftcap:  c.Float("attn_logit_softcapping"),
 			finalLogitSoftcap: c.Float("final_logit_softcapping"),
+			ropeConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.freq_base", 10000.0),
+				Scale:      c.Float("rope.freq_scale", 1.0),
+				Dim:        c.Uint("attention.key_length"),
+				Type:       ml.RopeTypeNormal,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
 		},
 	}

@@ -78,11 +84,10 @@ type SelfAttention struct {

 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	ropeType := uint32(2)

 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)

 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -92,7 +97,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten

 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)

 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -122,7 +127,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }

 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
+	return key.RoPE(ctx, shift, nil, m.ropeConfig), nil
 }

 type MLP struct {
@@ -13,9 +13,11 @@ import (
 type TextOptions struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
-	eps, ropeScale                   float32
-	ropeLocalBase, ropeGlobalBase    float32
+	eps                              float32
 	largeModelScaling                bool
+
+	ropeLocalConfig  ml.RoPEConfig
+	ropeGlobalConfig ml.RoPEConfig
 }

 type TextModel struct {
@@ -56,15 +58,27 @@ func newTextModel(c ml.Config) *TextModel {
 		),
 		Layers: make([]TextLayer, numBlocks),
 		TextOptions: &TextOptions{
 			hiddenSize: int(c.Uint("embedding_length")),
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			attnKeyLen: int(c.Uint("attention.key_length", 256)),
 			attnValLen: int(c.Uint("attention.value_length", 256)),
 			eps:        c.Float("attention.layer_norm_rms_epsilon", 1e-06),
-			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
-			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
-			ropeScale:      c.Float("rope.freq_scale", 1.0),
+			ropeLocalConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.local.freq_base", 10000.0),
+				Scale:      c.Float("rope.freq_scale", 1.0),
+				Dim:        c.Uint("attention.key_length", 256),
+				Type:       ml.RopeTypeNeox,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
+			ropeGlobalConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.global.freq_base", 1000000.0),
+				Scale:      c.Float("rope.freq_scale", 1.0),
+				Dim:        c.Uint("attention.key_length", 256),
+				Type:       ml.RopeTypeNeox,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
 		},
 	}

@@ -86,17 +100,16 @@ type TextSelfAttention struct {

 func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	ropeType := uint32(2)

-	ropeBase := opts.ropeLocalBase
+	ropeConfig := opts.ropeLocalConfig
 	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeBase = opts.ropeGlobalBase
+		ropeConfig = opts.ropeGlobalConfig
 	}

 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
 	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, nil, ropeConfig)

 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -107,7 +120,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
 	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, nil, ropeConfig)

 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -120,12 +133,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }

 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	ropeBase := m.TextOptions.ropeLocalBase
+	ropeConfig := m.ropeLocalConfig
 	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeBase = m.TextOptions.ropeGlobalBase
+		ropeConfig = m.ropeGlobalConfig
 	}

-	return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
+	return key.RoPE(ctx, shift, nil, ropeConfig), nil
 }

 type TextMLP struct {
@@ -14,8 +14,8 @@ import (
 type Options struct {
 	hiddenSize, numHeads, numKVHeads int
-	eps, ropeBase, ropeScale         float32
-	ropeDim                          uint32
+	eps                              float32
+	ropeConfig                       ml.RoPEConfig
 }

 type Model struct {
@@ -54,9 +54,13 @@ func New(c ml.Config) (model.Model, error) {
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:   c.Float("rope.freq_base"),
-			ropeScale:  c.Float("rope.freq_scale", 1),
-			ropeDim:    c.Uint("rope.dimension_count"),
+			ropeConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.freq_base"),
+				Scale:      c.Float("rope.freq_scale", 1),
+				Dim:        c.Uint("rope.dimension_count"),
+				Type:       ml.RopeTypeNormal,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
 		},
 	}

@@ -76,15 +80,14 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
-	ropeType := uint32(0)

 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)

 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)

 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -97,7 +100,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }

 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
+	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
 }

 type MLP struct {
@@ -20,15 +20,14 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
-	ropeType := uint32(0)

 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)

 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)

 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -43,7 +42,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	// This will only get called for layers in the cache, which are just the self attention layers
 	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
-		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
+		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeConfig), nil
 	}

 	return key, nil
@@ -198,8 +197,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
 type TextModelOptions struct {
 	hiddenSize, numHeads, numKVHeads int
-	eps, ropeBase, ropeScale         float32
-	ropeDim                          uint32
+	eps                              float32
+	ropeConfig                       ml.RoPEConfig

 	crossAttentionLayers []uint32
 }
@@ -240,10 +239,14 @@ func newTextModel(c ml.Config) *TextModel {
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:   c.Float("rope.freq_base"),
-			ropeScale:  c.Float("rope.freq_scale", 1),
-			ropeDim:    c.Uint("rope.dimension_count"),
 			crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
+			ropeConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.freq_base"),
+				Scale:      c.Float("rope.freq_scale", 1),
+				Dim:        c.Uint("rope.dimension_count"),
+				Type:       ml.RopeTypeNormal,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
 		},
 	}
 }