simplify rope changes

Bruce MacDonald 2025-05-01 13:47:55 -07:00
parent 698a92aa4a
commit 47705b5168
11 changed files with 93 additions and 396 deletions

View File

@@ -570,78 +570,6 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return out
 }
-func (t *testTensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, sections [4]int, config ml.RoPEConfig) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) Exp(ctx ml.Context) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") }
-func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") }
-func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") }
-func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { panic("not implemented") }
-func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
-	panic("not implemented")
-}
 func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	offset /= t.elementSize

View File

@@ -119,53 +119,6 @@ type Context interface {
 	Layer(int) Context
 }
-// RopeType represents different RoPE (Rotary Position Embedding) implementation types
-type RopeType int
-// Available RoPE implementation types
-const (
-	RopeTypeNormal RopeType = iota // Standard RoPE implementation
-	RopeTypeNeox                   // NeoX-style RoPE implementation
-	RopeTypeMRoPE                  // Multi-scale RoPE implementation
-	RopeTypeVision                 // Vision-specific RoPE implementation
-)
-type YarnConfig struct {
-	YarnCtxTrain   int     // Context size used during training (for YaRN scaling)
-	YarnExtFactor  float32 // Extension factor for YaRN
-	YarnAttnFactor float32 // Attention scaling factor for YaRN
-	YarnBetaFast   float32 // Fast decay parameter for YaRN
-	YarnBetaSlow   float32 // Slow decay parameter for YaRN
-}
-// DefaultYarnConfig returns a default configuration for YaRN (Yet Another Recurrent Network)
-func DefaultYarnConfig(nCtx int32) *YarnConfig {
-	return &YarnConfig{
-		YarnCtxTrain:   int(nCtx),
-		YarnExtFactor:  0.0,
-		YarnAttnFactor: 1.0,
-		YarnBetaFast:   32.0,
-		YarnBetaSlow:   1.0,
-	}
-}
-// RoPEConfig holds configuration for Rotary Position Embedding
-type RoPEConfig struct {
-	// Dim is the dimensionality for applying rotary embeddings
-	Dim uint32
-	// Type specifies the RoPE implementation variant
-	Type RopeType
-	// Base controls frequency decay for the embeddings
-	Base float32
-	// Scale allows scaling the effective context length
-	Scale float32
-	*YarnConfig
-}
 type Tensor interface {
 	Dim(n int) int
 	Stride(n int) int
@@ -178,8 +131,6 @@ type Tensor interface {
 	Neg(ctx Context) Tensor
 	Add(ctx Context, t2 Tensor) Tensor
-	// Div computes the element-wise division (t1 / t2) for all values in the tensor
-	Div(ctx Context, t2 Tensor) Tensor
 	Mul(ctx Context, t2 Tensor) Tensor
 	Mulmat(ctx Context, t2 Tensor) Tensor
 	MulmatFullPrec(ctx Context, t2 Tensor) Tensor
@@ -193,15 +144,14 @@ type Tensor interface {
 	AvgPool2D(ctx Context, k, s int, p float32) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
+	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
+	// RoPEWithLen allows the caller to specify the rope default context length
+	RoPEWithLen(ctx Context, positionIDs, ropeFactors Tensor, ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) Tensor
 	IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, config RoPEConfig) Tensor
-	RoPEMulti(ctx Context, positionIDs, ropeFactors Tensor, sections [4]int, config RoPEConfig) Tensor
 	Sin(ctx Context) Tensor
 	Cos(ctx Context) Tensor
 	Tanh(ctx Context) Tensor
-	// Exp computes the element-wise exponential (e^t) for all values in the tensor
-	Exp(ctx Context) Tensor
 	GELU(ctx Context) Tensor
 	SILU(ctx Context) Tensor
 	Sigmoid(ctx Context) Tensor
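
For readers of this interface change: RoPE now takes the rope dimension, rope type, frequency base, and scale as explicit arguments, and RoPEWithLen additionally carries the original training context length. Below is a minimal, self-contained sketch (illustrative stub types, not code from this commit) of how a backend can satisfy the new shape by having RoPE delegate to RoPEWithLen with a default length, which is what the ggml backend does further down.

package main

import "fmt"

// Hypothetical, trimmed-down stand-ins for ml.Context and ml.Tensor,
// used only to illustrate the new method shapes.
type Context interface{}

type Tensor interface {
	// RoPE mirrors the new interface: explicit dim, type, base and scale.
	RoPE(ctx Context, positionIDs, ropeFactors Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) Tensor
	// RoPEWithLen also carries the original training context length.
	RoPEWithLen(ctx Context, positionIDs, ropeFactors Tensor, ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) Tensor
}

type stubTensor struct{ name string }

func (t *stubTensor) RoPE(ctx Context, pos, factors Tensor, dim, typ uint32, base, scale float32) Tensor {
	// Delegate to RoPEWithLen with a common default, as the ggml backend does below.
	return t.RoPEWithLen(ctx, pos, factors, dim, typ, 131072, base, scale)
}

func (t *stubTensor) RoPEWithLen(ctx Context, pos, factors Tensor, dim, typ, ctxLen uint32, base, scale float32) Tensor {
	fmt.Printf("rope(%s): dim=%d type=%d ctxLen=%d base=%g scale=%g\n", t.name, dim, typ, ctxLen, base, scale)
	return t
}

func main() {
	var q Tensor = &stubTensor{name: "q"}
	q.RoPE(nil, nil, nil, 128, 2, 10000.0, 1.0) // e.g. a NeoX-style call, as gemma passes uint32(2)
}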

View File

@@ -860,13 +860,6 @@ func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
 	}
 }
-func (t *Tensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_div(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
-	}
-}
 func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return &Tensor{
 		b: t.b,
@@ -1024,13 +1017,6 @@ func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
 	}
 }
-func (t *Tensor) Exp(ctx ml.Context) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_exp_inplace(ctx.(*Context).ctx, t.t),
-	}
-}
 func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
 	if len(shape) != 4 {
 		panic("expected 4 dimensions")
@@ -1078,8 +1064,19 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	}
 }
-// RoPE applies Rotary Position Embeddings to the tensor
-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
+const (
+	ropeTypeNorm C.int = 0
+	ropeTypeNeox C.int = 2
+	ropeTypeMrope C.int = 8
+	ropeTypeVision C.int = 24
+)
+func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
+	defaultContextLen := uint32(131072)
+	return t.RoPEWithLen(ctx, positionIDs, ropeFactors, ropeDim, ropeType, defaultContextLen, ropeBase, ropeScale)
+}
+func (t *Tensor) RoPEWithLen(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) ml.Tensor {
 	if ropeFactors == nil {
 		ropeFactors = &Tensor{b: t.b}
 	}
@@ -1089,10 +1086,6 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config
 		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
 	}
-	if config.YarnConfig == nil {
-		config.YarnConfig = ml.DefaultYarnConfig(131072) // 131072 is the default for LLaMA, so it is common at the time of writing
-	}
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_rope_ext(
@@ -1100,15 +1093,15 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config
 			dequant,
 			positionIDs.(*Tensor).t,
 			ropeFactors.(*Tensor).t,
-			C.int(config.Dim),
-			ropeTypeToC(config.Type),
-			C.int(config.YarnCtxTrain),
-			C.float(config.Base),
-			C.float(config.Scale),
-			C.float(config.YarnExtFactor),
-			C.float(config.YarnAttnFactor),
-			C.float(config.YarnBetaFast),
-			C.float(config.YarnBetaSlow),
+			C.int(ropeDim),
+			C.int(ropeType),
+			C.int(defaultContextLen), // YaRN n_ctx_train
+			C.float(ropeBase),
+			C.float(ropeScale),
+			0.,  // YaRN ext_factor
+			1.,  // YaRN attn_factor
+			32., // YaRN beta_fast
+			1.,  // YaRN beta_slow
 		),
 	}
 }
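
Note on the rewrite above: the YaRN parameters that RoPEConfig used to carry are now fixed at the ggml_rope_ext call (ext_factor 0, attn_factor 1, beta_fast 32, beta_slow 1), and defaultContextLen is passed through as the YaRN n_ctx_train value. As general background (standard RoPE math, not code from this repository), ropeDim and ropeBase define the per-pair rotation frequencies and ropeScale rescales the position index; a small runnable sketch:

package main

import (
	"fmt"
	"math"
)

// ropeAngles returns the rotation angle applied to each of the ropeDim/2
// dimension pairs for a given token position, using the conventional RoPE
// frequency schedule theta_i = ropeBase^(-2i/ropeDim). ropeScale multiplies
// the position index, matching the freq_scale convention. Illustrative only.
func ropeAngles(position int, ropeDim uint32, ropeBase, ropeScale float32) []float64 {
	angles := make([]float64, ropeDim/2)
	for i := range angles {
		freq := math.Pow(float64(ropeBase), -2*float64(i)/float64(ropeDim))
		angles[i] = float64(position) * float64(ropeScale) * freq
	}
	return angles
}

func main() {
	// With base 10000 and dim 128 (typical values seen in this diff),
	// low dimension pairs rotate quickly and high pairs rotate slowly.
	a := ropeAngles(42, 128, 10000.0, 1.0)
	fmt.Printf("first pair angle: %.4f, last pair angle: %.6f\n", a[0], a[len(a)-1])
}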
@@ -1119,60 +1112,6 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
 		t: C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32),
 	}
 }
-func (t *Tensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, sections [4]int, config ml.RoPEConfig) ml.Tensor {
-	if ropeFactors == nil {
-		ropeFactors = &Tensor{b: t.b}
-	}
-	dequant := t.t
-	if C.ggml_is_quantized(t.t._type) {
-		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
-	}
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_rope_multi(
-			ctx.(*Context).ctx,
-			dequant,
-			positionIDs.(*Tensor).t,
-			ropeFactors.(*Tensor).t,
-			C.int(config.Dim),
-			(*C.int)(unsafe.Pointer(&sections[0])),
-			ropeTypeToC(config.Type),
-			C.int(config.YarnCtxTrain),
-			C.float(config.Base),
-			C.float(config.Scale),
-			C.float(config.YarnExtFactor),
-			C.float(config.YarnAttnFactor),
-			C.float(config.YarnBetaFast),
-			C.float(config.YarnBetaSlow),
-		),
-	}
-}
-// GGML RoPE types
-// These are the types used in the C implementation of RoPE
-const (
-	ropeTypeNorm C.int = 0
-	ropeTypeNeox C.int = 2
-	ropeTypeMrope C.int = 8
-	ropeTypeVision C.int = 24
-)
-func ropeTypeToC(ropeType ml.RopeType) C.int {
-	switch ropeType {
-	case ml.RopeTypeNormal:
-		return ropeTypeNorm
-	case ml.RopeTypeNeox:
-		return ropeTypeNeox
-	case ml.RopeTypeMRoPE:
-		return ropeTypeMrope
-	case ml.RopeTypeVision:
-		return ropeTypeVision
-	default:
-		return ropeTypeNorm
-	}
-}
 func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
 	return &Tensor{

View File

@@ -14,11 +14,10 @@ import (
 type Options struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen int
-	eps float32
+	eps, ropeBase, ropeScale float32
 	attnLogitSoftcap float32
 	finalLogitSoftcap float32
 	largeModelScaling bool
-	ropeConfig ml.RoPEConfig
 }
 type Model struct {
@@ -56,15 +55,10 @@ func New(c fs.Config) (model.Model, error) {
 			attnKeyLen: int(c.Uint("attention.key_length")),
 			attnValLen: int(c.Uint("attention.value_length")),
 			eps: c.Float("attention.layer_norm_rms_epsilon"),
+			ropeBase: c.Float("rope.freq_base", 10000.0),
+			ropeScale: c.Float("rope.freq_scale", 1.0),
 			attnLogitSoftcap: c.Float("attn_logit_softcapping"),
 			finalLogitSoftcap: c.Float("final_logit_softcapping"),
-			ropeConfig: ml.RoPEConfig{
-				Base: c.Float("rope.freq_base", 10000.0),
-				Scale: c.Float("rope.freq_scale", 1.0),
-				Dim: c.Uint("attention.key_length"),
-				Type: ml.RopeTypeNormal,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
 		},
 	}
@@ -84,10 +78,11 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
+	ropeType := uint32(2)
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
+	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -97,7 +92,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
+	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -127,7 +122,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, nil, m.ropeConfig), nil
+	return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
 }
 type MLP struct {
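
With ropeTypeToC gone, the model packages pass the ggml rope-type numbers directly; the uint32(2) hard-coded above is the NeoX variant, and the llama-style models below pass uint32(0). A small reference sketch (named values mirroring the backend constants, not code added by this commit):

package main

import "fmt"

// Named values mirroring the constants kept in the ggml backend
// (ropeTypeNorm = 0, ropeTypeNeox = 2, ropeTypeMrope = 8, ropeTypeVision = 24).
// Model code now passes these numbers directly, e.g. gemma's ropeType := uint32(2).
const (
	ropeTypeNorm uint32 = 0
	ropeTypeNeox uint32 = 2
	ropeTypeMrope uint32 = 8
	ropeTypeVision uint32 = 24
)

func main() {
	fmt.Println("gemma passes", ropeTypeNeox, "(NeoX-style RoPE)")
}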

View File

@@ -14,11 +14,9 @@ import (
 type TextConfig struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen int
-	eps float32
+	eps, ropeScale float32
+	ropeLocalBase, ropeGlobalBase float32
 	largeModelScaling bool
-	ropeLocalConfig ml.RoPEConfig
-	ropeGlobalConfig ml.RoPEConfig
 }
 type TextModel struct {
@@ -58,27 +56,15 @@ func newTextModel(c fs.Config) *TextModel {
 		),
 		Layers: make([]TextLayer, numBlocks),
 		TextConfig: &TextConfig{
 			hiddenSize: int(c.Uint("embedding_length")),
 			numHeads: int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			attnKeyLen: int(c.Uint("attention.key_length", 256)),
 			attnValLen: int(c.Uint("attention.value_length", 256)),
 			eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
-			ropeLocalConfig: ml.RoPEConfig{
-				Base: c.Float("rope.local.freq_base", 10000.0),
-				Scale: c.Float("rope.freq_scale", 1.0),
-				Dim: c.Uint("attention.key_length", 256),
-				Type: ml.RopeTypeNeox,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
-			ropeGlobalConfig: ml.RoPEConfig{
-				Base: c.Float("rope.global.freq_base", 1000000.0),
-				Scale: c.Float("rope.freq_scale", 1.0),
-				Dim: c.Uint("attention.key_length", 256),
-				Type: ml.RopeTypeNeox,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
+			ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
+			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
+			ropeScale: c.Float("rope.freq_scale", 1.0),
 		},
 	}
@@ -100,16 +86,17 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	ropeConfig := opts.ropeLocalConfig
+	ropeType := uint32(2)
+	ropeBase := opts.ropeLocalBase
 	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeConfig = opts.ropeGlobalConfig
+		ropeBase = opts.ropeGlobalBase
 	}
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
 	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-	q = q.RoPE(ctx, positionIDs, nil, ropeConfig)
+	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -120,7 +107,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
 	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-	k = k.RoPE(ctx, positionIDs, nil, ropeConfig)
+	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -133,12 +120,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	ropeConfig := m.ropeLocalConfig
+	ropeBase := m.TextConfig.ropeLocalBase
 	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeConfig = m.ropeGlobalConfig
+		ropeBase = m.TextConfig.ropeGlobalBase
 	}
-	return key.RoPE(ctx, shift, nil, ropeConfig), nil
+	return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
 }
 type TextMLP struct {
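
gemma3 keeps two frequency bases and chooses one per layer: every gemmaGlobalCacheCount-th layer uses the global base, the others use the local one. A self-contained sketch of that selection follows; the constant's value of 6 is an assumption for illustration only, the real value is defined elsewhere in the package.

package main

import "fmt"

// gemmaGlobalCacheCount is defined elsewhere in the gemma3 package; the value
// here is only an assumption so the example runs.
const gemmaGlobalCacheCount = 6

// ropeBaseForLayer mirrors the selection in Forward and Shift above:
// every gemmaGlobalCacheCount-th layer uses the global frequency base,
// all other layers use the local one.
func ropeBaseForLayer(layer int, localBase, globalBase float32) float32 {
	if (layer+1)%gemmaGlobalCacheCount == 0 {
		return globalBase
	}
	return localBase
}

func main() {
	for layer := 0; layer < 12; layer++ {
		fmt.Printf("layer %2d: base=%g\n", layer, ropeBaseForLayer(layer, 10000.0, 1000000.0))
	}
}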

View File

@@ -15,8 +15,8 @@ import (
 type Options struct {
 	hiddenSize, numHeads, numKVHeads int
-	eps float32
-	ropeConfig ml.RoPEConfig
+	eps, ropeBase, ropeScale float32
+	ropeDim uint32
 }
 type Model struct {
@@ -55,13 +55,9 @@ func New(c fs.Config) (model.Model, error) {
 			numHeads: int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			eps: c.Float("attention.layer_norm_rms_epsilon"),
-			ropeConfig: ml.RoPEConfig{
-				Base: c.Float("rope.freq_base"),
-				Scale: c.Float("rope.freq_scale", 1),
-				Dim: c.Uint("rope.dimension_count"),
-				Type: ml.RopeTypeNormal,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
+			ropeBase: c.Float("rope.freq_base"),
+			ropeScale: c.Float("rope.freq_scale", 1),
+			ropeDim: c.Uint("rope.dimension_count"),
 		},
 	}
@@ -81,14 +77,15 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
+	ropeType := uint32(0)
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
+	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
+	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -101,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
+	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
 }
 type MLP struct {

View File

@@ -31,28 +31,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 	if useRope {
-		query = query.RoPE(
-			ctx,
-			positions,
-			sa.RopeFactors,
-			ml.RoPEConfig{
-				Dim: uint32(opts.ropeDim),
-				Type: ml.RopeTypeNormal,
-				Base: opts.ropeBase,
-				Scale: opts.ropeScale,
-			},
-		)
-		key = key.RoPE(
-			ctx,
-			positions,
-			sa.RopeFactors,
-			ml.RoPEConfig{
-				Dim: uint32(opts.ropeDim),
-				Type: ml.RopeTypeNormal,
-				Base: opts.ropeBase,
-				Scale: opts.ropeScale,
-			},
-		)
+		query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
+		key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
 	}
 	if opts.useQKNorm {
@@ -275,15 +255,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 }
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(
-		ctx,
-		shift,
-		m.Layers[layer].Attention.RopeFactors,
-		ml.RoPEConfig{
-			Dim: uint32(m.TextOptions.ropeDim),
-			Type: ml.RopeTypeNormal,
-			Base: m.TextOptions.ropeBase,
-			Scale: m.TextOptions.ropeScale,
-		},
-	), nil
+	return key.RoPE(ctx, shift, m.Layers[layer].Attention.RopeFactors, uint32(0), uint32(m.ropeDim), m.ropeBase, m.ropeScale), nil
 }

View File

@@ -17,7 +17,6 @@ type TextOptions struct {
 	hiddenSize, numHeads, numKVHeads, headDim int
 	eps, ropeBase, ropeScale float32
 	ropeDim uint32
-	ropeConfig ml.RoPEConfig
 }
 type TextModel struct {
@@ -41,6 +40,7 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
+	ropeType := uint32(0)
 	headDim := opts.headDim
 	if headDim == 0 {
 		headDim = opts.hiddenSize / opts.numHeads
@@ -48,11 +48,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
+	q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
+	k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -63,7 +63,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, nil, m.TextOptions.ropeConfig), nil
+	return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
 }
 type MLP struct {
@@ -167,13 +167,9 @@ func NewTextModel(c fs.Config) (*TextModel, error) {
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			headDim: int(c.Uint("attention.key_length")),
 			eps: c.Float("attention.layer_norm_rms_epsilon"),
-			ropeConfig: ml.RoPEConfig{
-				Base: c.Float("rope.freq_base", 10000.0),
-				Scale: c.Float("rope.freq_scale", 1.0),
-				Dim: c.Uint("rope.dimension_count"),
-				Type: ml.RopeTypeNormal,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
+			ropeBase: c.Float("rope.freq_base"),
+			ropeScale: c.Float("rope.freq_scale", 1),
+			ropeDim: c.Uint("rope.dimension_count"),
 		},
 	}

View File

@@ -21,14 +21,15 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
+	ropeType := uint32(0)
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)
+	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)
+	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -43,7 +44,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	// This will only get called for layers in the cache, which are just the self attention layers
 	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
-		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeConfig), nil
+		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
 	}
 	return key, nil
@@ -198,8 +199,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
 type TextModelOptions struct {
 	hiddenSize, numHeads, numKVHeads int
-	eps float32
-	ropeConfig ml.RoPEConfig
+	eps, ropeBase, ropeScale float32
+	ropeDim uint32
 	crossAttentionLayers []int32
 }
@@ -240,14 +241,10 @@ func newTextModel(c fs.Config) *TextModel {
 			numHeads: int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			eps: c.Float("attention.layer_norm_rms_epsilon"),
+			ropeBase: c.Float("rope.freq_base"),
+			ropeScale: c.Float("rope.freq_scale", 1),
+			ropeDim: c.Uint("rope.dimension_count"),
 			crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
-			ropeConfig: ml.RoPEConfig{
-				Base: c.Float("rope.freq_base"),
-				Scale: c.Float("rope.freq_scale", 1),
-				Dim: c.Uint("rope.dimension_count"),
-				Type: ml.RopeTypeNormal,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
 		},
 	}
 }

View File

@@ -1,59 +0,0 @@
-package qwen25vl
-import (
-	"testing"
-	"github.com/ollama/ollama/ml/backend/ggml"
-	"github.com/ollama/ollama/model/input"
-)
-func TestPostTokenize(t *testing.T) {
-	// Set up test inputs
-	model := &Model{}
-	mockHash := uint64(12345678)
-	inputs := []input.Input{
-		{Token: 123}, // Regular token
-		{Token: 456}, // Regular token
-		{Token: 151655, Multimodal: &ggml.Tensor{}, MultimodalHash: mockHash}, // Image token
-		{Token: 789}, // Regular token
-	}
-	// Run the function being tested
-	result, err := model.PostTokenize(inputs)
-	if err != nil {
-		t.Fatalf("PostTokenize returned error: %v", err)
-	}
-	// Verify the actual length first
-	expectedLength := 21
-	if len(result) != expectedLength {
-		t.Fatalf("Result has wrong length: got %d, expected %d", len(result), expectedLength)
-	}
-	// Check key positions only
-	checkPositions := map[int]int32{
-		0: 123, // First regular token
-		1: 456, // Second regular token
-		2: 151652, // Vision start token
-		4: 151655, // First placeholder token
-		19: 151653, // Vision end token
-		20: 789, // Final regular token
-	}
-	for pos, expectedToken := range checkPositions {
-		if pos >= len(result) {
-			t.Errorf("Position %d is out of bounds (result length: %d)", pos, len(result))
-			continue
-		}
-		if result[pos].Token != expectedToken {
-			t.Errorf("Position %d: expected token %d, got %d", pos, expectedToken, result[pos].Token)
-		}
-	}
-	// Check multimodal data is preserved
-	if result[3].MultimodalHash != mockHash {
-		t.Errorf("Multimodal hash not preserved: got %d, expected %d",
-			result[3].MultimodalHash, mockHash)
-	}
-}

View File

@@ -13,8 +13,8 @@ import (
 type TextOptions struct {
 	ctxLen, hiddenSize, numHeads, numKVHeads int
-	eps float32
-	ropeConfig ml.RoPEConfig
+	eps, ropeBase, ropeScale float32
+	ropeDim, defaultContextLen uint32
 }
 type TextModel struct {
@@ -45,18 +45,15 @@ func NewTextModel(c fs.Config) *TextModel {
 		),
 		Layers: make([]Layer, c.Uint("block_count")),
 		TextOptions: &TextOptions{
 			ctxLen: int(c.Uint("context_length")),
 			hiddenSize: int(c.Uint("embedding_length")),
 			numHeads: int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			eps: c.Float("attention.layer_norm_rms_epsilon"),
-			ropeConfig: ml.RoPEConfig{
-				Base: c.Float("rope.freq_base"),
-				Scale: c.Float("rope.freq_scale", 1),
-				Dim: c.Uint("rope.dimension_count", 128),
-				Type: ml.RopeTypeNeox,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 128000))),
-			},
+			ropeBase: c.Float("rope.freq_base"),
+			ropeScale: c.Float("rope.freq_scale", 1),
+			ropeDim: c.Uint("rope.dimension_count", 128),
+			defaultContextLen: c.Uint("context_length", 128000),
 		},
 	}
@@ -79,11 +76,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
+	q = q.RoPEWithLen(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.defaultContextLen, opts.ropeBase, opts.ropeScale)
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
+	k = k.RoPEWithLen(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.defaultContextLen, opts.ropeBase, opts.ropeScale)
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -97,7 +94,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 // Shift applies rotary position embeddings to the key tensor for causal attention caching
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
+	return key.RoPEWithLen(ctx, shift, nil, m.ropeDim, 2, m.TextOptions.defaultContextLen, m.ropeBase, m.ropeScale), nil
 }
 // MLP implements the feed-forward network component with SwiGLU activation
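
qwen25vl is the one call site that uses RoPEWithLen directly, so it can supply its own training context length (context_length, defaulting to 128000) instead of the 131072 default baked into the plain RoPE wrapper. A small illustrative sketch of that distinction (hypothetical applyRoPE helper, not code from this commit):

package main

import "fmt"

// applyRoPE stands in for the two entry points: the plain RoPE wrapper fixes
// defaultContextLen at 131072, while RoPEWithLen lets the model supply its own
// value. Illustrative only.
func applyRoPE(ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) string {
	return fmt.Sprintf("dim=%d type=%d n_ctx_train=%d base=%g scale=%g",
		ropeDim, ropeType, defaultContextLen, ropeBase, ropeScale)
}

func main() {
	ropeBase := float32(10000) // placeholder; the model reads rope.freq_base from the GGUF metadata
	fmt.Println("plain RoPE wrapper:  ", applyRoPE(128, 2, 131072, ropeBase, 1))
	fmt.Println("qwen25vl RoPEWithLen:", applyRoPE(128, 2, 128000, ropeBase, 1))
}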