diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go
index 5daec7669..796987088 100644
--- a/kvcache/causal_test.go
+++ b/kvcache/causal_test.go
@@ -570,78 +570,6 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return out
 }
 
-func (t *testTensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, sections [4]int, config ml.RoPEConfig) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) Exp(ctx ml.Context) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) Cos(ctx ml.Context) ml.Tensor  { panic("not implemented") }
-func (t *testTensor) Sin(ctx ml.Context) ml.Tensor  { panic("not implemented") }
-func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") }
-func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { panic("not implemented") }
-
-func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
-	panic("not implemented")
-}
-
 func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	offset /= t.elementSize
 
diff --git a/ml/backend.go b/ml/backend.go
index 0def04b85..51806b735 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -119,53 +119,6 @@ type Context interface {
 	Layer(int) Context
 }
 
-// RopeType represents different RoPE (Rotary Position Embedding) implementation types
-type RopeType int
-
-// Available RoPE implementation types
-const (
-	RopeTypeNormal RopeType = iota // Standard RoPE implementation
-	RopeTypeNeox                   // NeoX-style RoPE implementation
-	RopeTypeMRoPE                  // Multi-scale RoPE implementation
-	RopeTypeVision                 // Vision-specific RoPE implementation
-)
-
-type YarnConfig struct {
-	YarnCtxTrain   int     // Context size used during training (for YaRN scaling)
-	YarnExtFactor  float32 // Extension factor for YaRN
-	YarnAttnFactor float32 // Attention scaling factor for YaRN
-	YarnBetaFast   float32 // Fast decay parameter for YaRN
-	YarnBetaSlow   float32 // Slow decay parameter for YaRN
-}
-
-// DefaultYarnConfig returns a default configuration for YaRN (Yet Another Recurrent Network)
-func DefaultYarnConfig(nCtx int32) *YarnConfig {
-	return &YarnConfig{
-		YarnCtxTrain:   int(nCtx),
-		YarnExtFactor:  0.0,
-		YarnAttnFactor: 1.0,
-		YarnBetaFast:   32.0,
-		YarnBetaSlow:   1.0,
-	}
-}
-
-// RoPEConfig holds configuration for Rotary Position Embedding
-type RoPEConfig struct {
-	// Dim is the dimensionality for applying rotary embeddings
-	Dim uint32
-
-	// Type specifies the RoPE implementation variant
-	Type RopeType
-
-	// Base controls frequency decay for the embeddings
-	Base float32
-
-	// Scale allows scaling the effective context length
-	Scale float32
-
-	*YarnConfig
-}
-
 type Tensor interface {
 	Dim(n int) int
 	Stride(n int) int
@@ -178,8 +131,6 @@ type Tensor interface {
 
 	Neg(ctx Context) Tensor
 	Add(ctx Context, t2 Tensor) Tensor
-	// Div computes the element-wise division (t1 / t2) for all values in the tensor
-	Div(ctx Context, t2 Tensor) Tensor
 	Mul(ctx Context, t2 Tensor) Tensor
 	Mulmat(ctx Context, t2 Tensor) Tensor
 	MulmatFullPrec(ctx Context, t2 Tensor) Tensor
@@ -193,15 +144,14 @@ type Tensor interface {
 	AvgPool2D(ctx Context, k, s int, p float32) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
 
+	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
+	// RoPEWithLen allows the caller to specify the rope default context length
+	RoPEWithLen(ctx Context, positionIDs, ropeFactors Tensor, ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) Tensor
 	IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, config RoPEConfig) Tensor
-	RoPEMulti(ctx Context, positionIDs, ropeFactors Tensor, sections [4]int, config RoPEConfig) Tensor
 
 	Sin(ctx Context) Tensor
 	Cos(ctx Context) Tensor
 	Tanh(ctx Context) Tensor
-	// Exp computes the element-wise exponential (e^t) for all values in the tensor
-	Exp(ctx Context) Tensor
 	GELU(ctx Context) Tensor
 	SILU(ctx Context) Tensor
 	Sigmoid(ctx Context) Tensor
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 735c716dd..a4faed4bf 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -860,13 +860,6 @@ func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
 	}
 }
 
-func (t *Tensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_div(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
-	}
-}
-
 func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return &Tensor{
 		b: t.b,
@@ -1024,13 +1017,6 @@ func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
 	}
 }
 
-func (t *Tensor) Exp(ctx ml.Context) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_exp_inplace(ctx.(*Context).ctx, t.t),
-	}
-}
-
 func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
 	if len(shape) != 4 {
 		panic("expected 4 dimensions")
@@ -1078,8 +1064,19 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	}
 }
 
-// RoPE applies Rotary Position Embeddings to the tensor
-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
+const (
+	ropeTypeNorm   C.int = 0
+	ropeTypeNeox   C.int = 2
+	ropeTypeMrope  C.int = 8
+	ropeTypeVision C.int = 24
+)
+
+func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
+	defaultContextLen := uint32(131072)
+	return t.RoPEWithLen(ctx, positionIDs, ropeFactors, ropeDim, ropeType, defaultContextLen, ropeBase, ropeScale)
+}
+
+func (t *Tensor) RoPEWithLen(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) ml.Tensor {
 	if ropeFactors == nil {
 		ropeFactors = &Tensor{b: t.b}
 	}
@@ -1089,10 +1086,6 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config
 		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
 	}
 
-	if config.YarnConfig == nil {
-		config.YarnConfig = ml.DefaultYarnConfig(131072) // 131072 is the default for LLaMA, so it is common at the time of writing
-	}
-
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_rope_ext(
@@ -1100,15 +1093,15 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config
 			dequant,
 			positionIDs.(*Tensor).t,
 			ropeFactors.(*Tensor).t,
-			C.int(config.Dim),
-			ropeTypeToC(config.Type),
-			C.int(config.YarnCtxTrain),
-			C.float(config.Base),
-			C.float(config.Scale),
-			C.float(config.YarnExtFactor),
-			C.float(config.YarnAttnFactor),
-			C.float(config.YarnBetaFast),
-			C.float(config.YarnBetaSlow),
+			C.int(ropeDim),
+			C.int(ropeType),
+			C.int(defaultContextLen), // YaRN n_ctx_train
+			C.float(ropeBase),
+			C.float(ropeScale),
+			0.,  // YaRN ext_factor
+			1.,  // YaRN attn_factor
+			32., // YaRN beta_fast
+			1.,  // YaRN beta_slow
 		),
 	}
 }
@@ -1119,60 +1112,6 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
 		t: C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32),
 	}
 }
-func (t *Tensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, sections [4]int, config ml.RoPEConfig) ml.Tensor {
-	if ropeFactors == nil {
-		ropeFactors = &Tensor{b: t.b}
-	}
-
-	dequant := t.t
-	if C.ggml_is_quantized(t.t._type) {
-		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
-	}
-
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_rope_multi(
-			ctx.(*Context).ctx,
-			dequant,
-			positionIDs.(*Tensor).t,
-			ropeFactors.(*Tensor).t,
-			C.int(config.Dim),
-			(*C.int)(unsafe.Pointer(&sections[0])),
-			ropeTypeToC(config.Type),
-			C.int(config.YarnCtxTrain),
-			C.float(config.Base),
-			C.float(config.Scale),
-			C.float(config.YarnExtFactor),
-			C.float(config.YarnAttnFactor),
-			C.float(config.YarnBetaFast),
-			C.float(config.YarnBetaSlow),
-		),
-	}
-}
-
-// GGML RoPE types
-// These are the types used in the C implementation of RoPE
-const (
-	ropeTypeNorm   C.int = 0
-	ropeTypeNeox   C.int = 2
-	ropeTypeMrope  C.int = 8
-	ropeTypeVision C.int = 24
-)
-
-func ropeTypeToC(ropeType ml.RopeType) C.int {
-	switch ropeType {
-	case ml.RopeTypeNormal:
-		return ropeTypeNorm
-	case ml.RopeTypeNeox:
-		return ropeTypeNeox
-	case ml.RopeTypeMRoPE:
-		return ropeTypeMrope
-	case ml.RopeTypeVision:
-		return ropeTypeVision
-	default:
-		return ropeTypeNorm
-	}
-}
 
 func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
 	return &Tensor{
diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go
index fe0e59173..d418f6827 100644
--- a/model/models/gemma2/model.go
+++ b/model/models/gemma2/model.go
@@ -14,11 +14,10 @@ import (
 type Options struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
-	eps                              float32
+	eps, ropeBase, ropeScale         float32
 	attnLogitSoftcap                 float32
 	finalLogitSoftcap                float32
 	largeModelScaling                bool
-	ropeConfig                       ml.RoPEConfig
 }
 
 type Model struct {
@@ -56,15 +55,10 @@ func New(c fs.Config) (model.Model, error) {
 			attnKeyLen:        int(c.Uint("attention.key_length")),
 			attnValLen:        int(c.Uint("attention.value_length")),
 			eps:               c.Float("attention.layer_norm_rms_epsilon"),
+			ropeBase:          c.Float("rope.freq_base", 10000.0),
+			ropeScale:         c.Float("rope.freq_scale", 1.0),
 			attnLogitSoftcap:  c.Float("attn_logit_softcapping"),
 			finalLogitSoftcap: c.Float("final_logit_softcapping"),
-			ropeConfig: ml.RoPEConfig{
-				Base:       c.Float("rope.freq_base", 10000.0),
-				Scale:      c.Float("rope.freq_scale", 1.0),
-				Dim:        c.Uint("attention.key_length"),
-				Type:       ml.RopeTypeNormal,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
 		},
 	}
 
@@ -84,10 +78,11 @@ type SelfAttention struct {
 
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
+	ropeType := uint32(2)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
+	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
 
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -97,7 +92,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
+	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -127,7 +122,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, nil, m.ropeConfig), nil
+	return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
 }
 
 type MLP struct {
diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index 218d84c02..c1e843d8f 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -14,11 +14,9 @@ import (
 type TextConfig struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
-	eps                              float32
+	eps, ropeScale                   float32
+	ropeLocalBase, ropeGlobalBase    float32
 	largeModelScaling                bool
-
-	ropeLocalConfig  ml.RoPEConfig
-	ropeGlobalConfig ml.RoPEConfig
 }
 
 type TextModel struct {
@@ -58,27 +56,15 @@ func newTextModel(c fs.Config) *TextModel {
 		),
 		Layers: make([]TextLayer, numBlocks),
 		TextConfig: &TextConfig{
-			hiddenSize: int(c.Uint("embedding_length")),
-			numHeads:   int(c.Uint("attention.head_count")),
-			numKVHeads: int(c.Uint("attention.head_count_kv")),
-			attnKeyLen: int(c.Uint("attention.key_length", 256)),
-			attnValLen: int(c.Uint("attention.value_length", 256)),
-			eps:        c.Float("attention.layer_norm_rms_epsilon", 1e-06),
-
-			ropeLocalConfig: ml.RoPEConfig{
-				Base:       c.Float("rope.local.freq_base", 10000.0),
-				Scale:      c.Float("rope.freq_scale", 1.0),
-				Dim:        c.Uint("attention.key_length", 256),
-				Type:       ml.RopeTypeNeox,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
-			ropeGlobalConfig: ml.RoPEConfig{
-				Base:       c.Float("rope.global.freq_base", 1000000.0),
-				Scale:      c.Float("rope.freq_scale", 1.0),
-				Dim:        c.Uint("attention.key_length", 256),
-				Type:       ml.RopeTypeNeox,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
+			hiddenSize:     int(c.Uint("embedding_length")),
+			numHeads:       int(c.Uint("attention.head_count")),
+			numKVHeads:     int(c.Uint("attention.head_count_kv")),
+			attnKeyLen:     int(c.Uint("attention.key_length", 256)),
+			attnValLen:     int(c.Uint("attention.value_length", 256)),
+			eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06),
+			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
+			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
+			ropeScale:      c.Float("rope.freq_scale", 1.0),
 		},
 	}
 
@@ -100,16 +86,17 @@ type TextSelfAttention struct {
 
 func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
+	ropeType := uint32(2)
 
-	ropeConfig := opts.ropeLocalConfig
+	ropeBase := opts.ropeLocalBase
 	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeConfig = opts.ropeGlobalConfig
+		ropeBase = opts.ropeGlobalBase
 	}
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
 	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-	q = q.RoPE(ctx, positionIDs, nil, ropeConfig)
+	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
 
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -120,7 +107,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
 	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-	k = k.RoPE(ctx, positionIDs, nil, ropeConfig)
+	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -133,12 +120,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	ropeConfig := m.ropeLocalConfig
+	ropeBase := m.TextConfig.ropeLocalBase
 	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeConfig = m.ropeGlobalConfig
+		ropeBase = m.TextConfig.ropeGlobalBase
 	}
 
-	return key.RoPE(ctx, shift, nil, ropeConfig), nil
+	return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
 }
 
 type TextMLP struct {
diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index ec5da5b62..3e5a54278 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -15,8 +15,8 @@ import (
 
 type Options struct {
 	hiddenSize, numHeads, numKVHeads int
-	eps                              float32
-	ropeConfig                       ml.RoPEConfig
+	eps, ropeBase, ropeScale         float32
+	ropeDim                          uint32
 }
 
 type Model struct {
@@ -55,13 +55,9 @@ func New(c fs.Config) (model.Model, error) {
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ropeConfig: ml.RoPEConfig{
-				Base:       c.Float("rope.freq_base"),
-				Scale:      c.Float("rope.freq_scale", 1),
-				Dim:        c.Uint("rope.dimension_count"),
-				Type:       ml.RopeTypeNormal,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
+			ropeBase:   c.Float("rope.freq_base"),
+			ropeScale:  c.Float("rope.freq_scale", 1),
+			ropeDim:    c.Uint("rope.dimension_count"),
 		},
 	}
 
@@ -81,14 +77,15 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
+	ropeType := uint32(0)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
+	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
+	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -101,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
+	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
 }
 
 type MLP struct {
diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go
index 46174cafa..3f9f578f1 100644
--- a/model/models/llama4/model_text.go
+++ b/model/models/llama4/model_text.go
@@ -31,28 +31,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
 	if useRope {
-		query = query.RoPE(
-			ctx,
-			positions,
-			sa.RopeFactors,
-			ml.RoPEConfig{
-				Dim:   uint32(opts.ropeDim),
-				Type:  ml.RopeTypeNormal,
-				Base:  opts.ropeBase,
-				Scale: opts.ropeScale,
-			},
-		)
-		key = key.RoPE(
-			ctx,
-			positions,
-			sa.RopeFactors,
-			ml.RoPEConfig{
-				Dim:   uint32(opts.ropeDim),
-				Type:  ml.RopeTypeNormal,
-				Base:  opts.ropeBase,
-				Scale: opts.ropeScale,
-			},
-		)
+		query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
+		key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
 	}
 
 	if opts.useQKNorm {
@@ -275,15 +255,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(
-		ctx,
-		shift,
-		m.Layers[layer].Attention.RopeFactors,
-		ml.RoPEConfig{
-			Dim:   uint32(m.TextOptions.ropeDim),
-			Type:  ml.RopeTypeNormal,
-			Base:  m.TextOptions.ropeBase,
-			Scale: m.TextOptions.ropeScale,
-		},
-	), nil
+	return key.RoPE(ctx, shift, m.Layers[layer].Attention.RopeFactors, uint32(m.ropeDim), uint32(0), m.ropeBase, m.ropeScale), nil
 }
diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go
index 250f13eee..1bf72acd8 100644
--- a/model/models/mistral3/model_text.go
+++ b/model/models/mistral3/model_text.go
@@ -17,7 +17,6 @@ type TextOptions struct {
 	hiddenSize, numHeads, numKVHeads, headDim int
 	eps, ropeBase, ropeScale                  float32
 	ropeDim                                   uint32
-	ropeConfig                                ml.RoPEConfig
 }
 
 type TextModel struct {
@@ -41,6 +40,7 @@ type SelfAttention struct {
 
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
+	ropeType := uint32(0)
 	headDim := opts.headDim
 	if headDim == 0 {
 		headDim = opts.hiddenSize / opts.numHeads
 	}
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
+	q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
+	k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -63,7 +63,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, nil, m.TextOptions.ropeConfig), nil
+	return key.RoPE(ctx, shift, nil, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
 }
 
 type MLP struct {
@@ -167,13 +167,9 @@ func NewTextModel(c fs.Config) (*TextModel, error) {
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			headDim:    int(c.Uint("attention.key_length")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ropeConfig: ml.RoPEConfig{
-				Base:       c.Float("rope.freq_base", 10000.0),
-				Scale:      c.Float("rope.freq_scale", 1.0),
-				Dim:        c.Uint("rope.dimension_count"),
-				Type:       ml.RopeTypeNormal,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
+			ropeBase:   c.Float("rope.freq_base"),
+			ropeScale:  c.Float("rope.freq_scale", 1),
+			ropeDim:    c.Uint("rope.dimension_count"),
 		},
 	}
 
diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go
index 0ad300eba..490eb696c 100644
--- a/model/models/mllama/model_text.go
+++ b/model/models/mllama/model_text.go
@@ -21,14 +21,15 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
+	ropeType := uint32(0)
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)
+	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)
+	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -43,7 +44,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	// This will only get called for layers in the cache, which are just the self attention layers
 	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
-		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeConfig), nil
+		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
 	}
 
 	return key, nil
@@ -198,8 +199,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
 
 type TextModelOptions struct {
 	hiddenSize, numHeads, numKVHeads int
-	eps                              float32
-	ropeConfig                       ml.RoPEConfig
+	eps, ropeBase, ropeScale         float32
+	ropeDim                          uint32
 
 	crossAttentionLayers []int32
 }
@@ -240,14 +241,10 @@ func newTextModel(c fs.Config) *TextModel {
 			numHeads:             int(c.Uint("attention.head_count")),
 			numKVHeads:           int(c.Uint("attention.head_count_kv")),
 			eps:                  c.Float("attention.layer_norm_rms_epsilon"),
+			ropeBase:             c.Float("rope.freq_base"),
+			ropeScale:            c.Float("rope.freq_scale", 1),
+			ropeDim:              c.Uint("rope.dimension_count"),
 			crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
-			ropeConfig: ml.RoPEConfig{
-				Base:       c.Float("rope.freq_base"),
-				Scale:      c.Float("rope.freq_scale", 1),
-				Dim:        c.Uint("rope.dimension_count"),
-				Type:       ml.RopeTypeNormal,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
 		},
 	}
 }
diff --git a/model/models/qwen25vl/model_test.go b/model/models/qwen25vl/model_test.go
deleted file mode 100644
index b9e590a90..000000000
--- a/model/models/qwen25vl/model_test.go
+++ /dev/null
@@ -1,59 +0,0 @@
-package qwen25vl
-
-import (
-	"testing"
-
-	"github.com/ollama/ollama/ml/backend/ggml"
-	"github.com/ollama/ollama/model/input"
-)
-
-func TestPostTokenize(t *testing.T) {
-	// Set up test inputs
-	model := &Model{}
-	mockHash := uint64(12345678)
-
-	inputs := []input.Input{
-		{Token: 123}, // Regular token
-		{Token: 456}, // Regular token
-		{Token: 151655, Multimodal: &ggml.Tensor{}, MultimodalHash: mockHash}, // Image token
-		{Token: 789}, // Regular token
-	}
-
-	// Run the function being tested
-	result, err := model.PostTokenize(inputs)
-	if err != nil {
-		t.Fatalf("PostTokenize returned error: %v", err)
-	}
-
-	// Verify the actual length first
-	expectedLength := 21
-	if len(result) != expectedLength {
-		t.Fatalf("Result has wrong length: got %d, expected %d", len(result), expectedLength)
-	}
-
-	// Check key positions only
-	checkPositions := map[int]int32{
-		0:  123,    // First regular token
-		1:  456,    // Second regular token
-		2:  151652, // Vision start token
-		4:  151655, // First placeholder token
-		19: 151653, // Vision end token
-		20: 789,    // Final regular token
-	}
-
-	for pos, expectedToken := range checkPositions {
-		if pos >= len(result) {
-			t.Errorf("Position %d is out of bounds (result length: %d)", pos, len(result))
-			continue
-		}
-		if result[pos].Token != expectedToken {
-			t.Errorf("Position %d: expected token %d, got %d", pos, expectedToken, result[pos].Token)
-		}
-	}
-
-	// Check multimodal data is preserved
-	if result[3].MultimodalHash != mockHash {
-		t.Errorf("Multimodal hash not preserved: got %d, expected %d",
-			result[3].MultimodalHash, mockHash)
-	}
-}
diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go
index 25c140ed4..c7d5dfc8b 100644
--- a/model/models/qwen25vl/model_text.go
+++ b/model/models/qwen25vl/model_text.go
@@ -13,8 +13,8 @@ import (
 
 type TextOptions struct {
 	ctxLen, hiddenSize, numHeads, numKVHeads int
-	eps                                      float32
-	ropeConfig                               ml.RoPEConfig
+	eps, ropeBase, ropeScale                 float32
+	ropeDim, defaultContextLen               uint32
 }
 
 type TextModel struct {
@@ -45,18 +45,15 @@ func NewTextModel(c fs.Config) *TextModel {
 		),
 		Layers: make([]Layer, c.Uint("block_count")),
 		TextOptions: &TextOptions{
-			ctxLen:     int(c.Uint("context_length")),
-			hiddenSize: int(c.Uint("embedding_length")),
-			numHeads:   int(c.Uint("attention.head_count")),
-			numKVHeads: int(c.Uint("attention.head_count_kv")),
-			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ropeConfig: ml.RoPEConfig{
-				Base:       c.Float("rope.freq_base"),
-				Scale:      c.Float("rope.freq_scale", 1),
-				Dim:        c.Uint("rope.dimension_count", 128),
-				Type:       ml.RopeTypeNeox,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 128000))),
-			},
+			ctxLen:            int(c.Uint("context_length")),
+			hiddenSize:        int(c.Uint("embedding_length")),
+			numHeads:          int(c.Uint("attention.head_count")),
+			numKVHeads:        int(c.Uint("attention.head_count_kv")),
+			eps:               c.Float("attention.layer_norm_rms_epsilon"),
+			ropeBase:          c.Float("rope.freq_base"),
+			ropeScale:         c.Float("rope.freq_scale", 1),
+			ropeDim:           c.Uint("rope.dimension_count", 128),
+			defaultContextLen: c.Uint("context_length", 128000),
 		},
 	}
 
@@ -79,11 +76,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
+	q = q.RoPEWithLen(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.defaultContextLen, opts.ropeBase, opts.ropeScale)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
+	k = k.RoPEWithLen(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.defaultContextLen, opts.ropeBase, opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -97,7 +94,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 // Shift applies rotary position embeddings to the key tensor for causal attention caching
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
+	return key.RoPEWithLen(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeDim, 2, m.TextOptions.defaultContextLen, m.ropeBase, m.ropeScale), nil
 }
 
 // MLP implements the feed-forward network component with SwiGLU activation