simplify rope changes
commit 47705b5168
parent 698a92aa4a
@@ -570,78 +570,6 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return out
 }
 
-func (t *testTensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
-	panic("not implemented")
-}
-func (t *testTensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, sections [4]int, config ml.RoPEConfig) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) Exp(ctx ml.Context) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") }
-func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") }
-func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") }
-func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { panic("not implemented") }
-
-func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
-	panic("not implemented")
-}
-
 func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	offset /= t.elementSize
 
@@ -119,53 +119,6 @@ type Context interface {
 	Layer(int) Context
 }
 
-// RopeType represents different RoPE (Rotary Position Embedding) implementation types
-type RopeType int
-
-// Available RoPE implementation types
-const (
-	RopeTypeNormal RopeType = iota // Standard RoPE implementation
-	RopeTypeNeox                   // NeoX-style RoPE implementation
-	RopeTypeMRoPE                  // Multi-scale RoPE implementation
-	RopeTypeVision                 // Vision-specific RoPE implementation
-)
-
-type YarnConfig struct {
-	YarnCtxTrain   int     // Context size used during training (for YaRN scaling)
-	YarnExtFactor  float32 // Extension factor for YaRN
-	YarnAttnFactor float32 // Attention scaling factor for YaRN
-	YarnBetaFast   float32 // Fast decay parameter for YaRN
-	YarnBetaSlow   float32 // Slow decay parameter for YaRN
-}
-
-// DefaultYarnConfig returns a default configuration for YaRN (Yet Another Recurrent Network)
-func DefaultYarnConfig(nCtx int32) *YarnConfig {
-	return &YarnConfig{
-		YarnCtxTrain:   int(nCtx),
-		YarnExtFactor:  0.0,
-		YarnAttnFactor: 1.0,
-		YarnBetaFast:   32.0,
-		YarnBetaSlow:   1.0,
-	}
-}
-
-// RoPEConfig holds configuration for Rotary Position Embedding
-type RoPEConfig struct {
-	// Dim is the dimensionality for applying rotary embeddings
-	Dim uint32
-
-	// Type specifies the RoPE implementation variant
-	Type RopeType
-
-	// Base controls frequency decay for the embeddings
-	Base float32
-
-	// Scale allows scaling the effective context length
-	Scale float32
-
-	*YarnConfig
-}
-
 type Tensor interface {
 	Dim(n int) int
 	Stride(n int) int
@@ -178,8 +131,6 @@ type Tensor interface {
 
 	Neg(ctx Context) Tensor
 	Add(ctx Context, t2 Tensor) Tensor
-	// Div computes the element-wise division (t1 / t2) for all values in the tensor
-	Div(ctx Context, t2 Tensor) Tensor
 	Mul(ctx Context, t2 Tensor) Tensor
 	Mulmat(ctx Context, t2 Tensor) Tensor
 	MulmatFullPrec(ctx Context, t2 Tensor) Tensor
@@ -193,15 +144,14 @@ type Tensor interface {
 	AvgPool2D(ctx Context, k, s int, p float32) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
 
+	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
+	// RoPEWithLen allows the caller to specify the rope default context length
+	RoPEWithLen(ctx Context, positionIDs, ropeFactors Tensor, ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) Tensor
 	IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, config RoPEConfig) Tensor
-	RoPEMulti(ctx Context, positionIDs, ropeFactors Tensor, sections [4]int, config RoPEConfig) Tensor
 
 	Sin(ctx Context) Tensor
 	Cos(ctx Context) Tensor
 	Tanh(ctx Context) Tensor
-	// Exp computes the element-wise exponential (e^t) for all values in the tensor
-	Exp(ctx Context) Tensor
 	GELU(ctx Context) Tensor
 	SILU(ctx Context) Tensor
 	Sigmoid(ctx Context) Tensor
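Note: the net effect on the ml.Tensor interface is that RoPE goes back to plain arguments, with RoPEWithLen as the variant that also takes the default context length. A minimal before/after sketch of a call site (the dim, base, scale, and context-length values here are hypothetical):

	// Before this commit, with the removed config struct:
	//   q = q.RoPE(ctx, positions, nil, ml.RoPEConfig{Dim: 128, Type: ml.RopeTypeNeox, Base: 10000, Scale: 1})
	// After, passing the ggml rope type code (2 = NeoX) directly:
	q = q.RoPE(ctx, positions, nil, 128, 2, 10000, 1)
	// Or, to override the 131072 default context length:
	q = q.RoPEWithLen(ctx, positions, nil, 128, 2, 32768, 10000, 1)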
@@ -860,13 +860,6 @@ func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
 	}
 }
 
-func (t *Tensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_div(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
-	}
-}
-
 func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return &Tensor{
 		b: t.b,
@@ -1024,13 +1017,6 @@ func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
 	}
 }
 
-func (t *Tensor) Exp(ctx ml.Context) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_exp_inplace(ctx.(*Context).ctx, t.t),
-	}
-}
-
 func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
 	if len(shape) != 4 {
 		panic("expected 4 dimensions")
@@ -1078,8 +1064,19 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	}
 }
 
-// RoPE applies Rotary Position Embeddings to the tensor
-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
+const (
+	ropeTypeNorm   C.int = 0
+	ropeTypeNeox   C.int = 2
+	ropeTypeMrope  C.int = 8
+	ropeTypeVision C.int = 24
+)
+
+func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
+	defaultContextLen := uint32(131072)
+	return t.RoPEWithLen(ctx, positionIDs, ropeFactors, ropeDim, ropeType, defaultContextLen, ropeBase, ropeScale)
+}
+
+func (t *Tensor) RoPEWithLen(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType, defaultContextLen uint32, ropeBase, ropeScale float32) ml.Tensor {
 	if ropeFactors == nil {
 		ropeFactors = &Tensor{b: t.b}
 	}
@@ -1089,10 +1086,6 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config
 		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
 	}
 
-	if config.YarnConfig == nil {
-		config.YarnConfig = ml.DefaultYarnConfig(131072) // 131072 is the default for LLaMA, so it is common at the time of writing
-	}
-
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_rope_ext(
@@ -1100,15 +1093,15 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config
 			dequant,
 			positionIDs.(*Tensor).t,
 			ropeFactors.(*Tensor).t,
-			C.int(config.Dim),
-			ropeTypeToC(config.Type),
-			C.int(config.YarnCtxTrain),
-			C.float(config.Base),
-			C.float(config.Scale),
-			C.float(config.YarnExtFactor),
-			C.float(config.YarnAttnFactor),
-			C.float(config.YarnBetaFast),
-			C.float(config.YarnBetaSlow),
+			C.int(ropeDim),
+			C.int(ropeType),
+			C.int(defaultContextLen), // YaRN n_ctx_train
+			C.float(ropeBase),
+			C.float(ropeScale),
+			0.,  // YaRN ext_factor
+			1.,  // YaRN attn_factor
+			32., // YaRN beta_fast
+			1.,  // YaRN beta_slow
 		),
 	}
 }
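Note: RoPEWithLen pins the YaRN parameters that the deleted path read from config.YarnConfig: ext_factor 0, attn_factor 1, beta_fast 32, and beta_slow 1 become literals, exactly the values ml.DefaultYarnConfig returned, and defaultContextLen stands in for YarnCtxTrain. Callers that never set a custom YarnConfig should therefore produce identical ggml_rope_ext calls before and after this change.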
@@ -1119,60 +1112,6 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
 		t: C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32),
 	}
 }
-func (t *Tensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, sections [4]int, config ml.RoPEConfig) ml.Tensor {
-	if ropeFactors == nil {
-		ropeFactors = &Tensor{b: t.b}
-	}
-
-	dequant := t.t
-	if C.ggml_is_quantized(t.t._type) {
-		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
-	}
-
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_rope_multi(
-			ctx.(*Context).ctx,
-			dequant,
-			positionIDs.(*Tensor).t,
-			ropeFactors.(*Tensor).t,
-			C.int(config.Dim),
-			(*C.int)(unsafe.Pointer(&sections[0])),
-			ropeTypeToC(config.Type),
-			C.int(config.YarnCtxTrain),
-			C.float(config.Base),
-			C.float(config.Scale),
-			C.float(config.YarnExtFactor),
-			C.float(config.YarnAttnFactor),
-			C.float(config.YarnBetaFast),
-			C.float(config.YarnBetaSlow),
-		),
-	}
-}
-
-// GGML RoPE types
-// These are the types used in the C implementation of RoPE
-const (
-	ropeTypeNorm   C.int = 0
-	ropeTypeNeox   C.int = 2
-	ropeTypeMrope  C.int = 8
-	ropeTypeVision C.int = 24
-)
-
-func ropeTypeToC(ropeType ml.RopeType) C.int {
-	switch ropeType {
-	case ml.RopeTypeNormal:
-		return ropeTypeNorm
-	case ml.RopeTypeNeox:
-		return ropeTypeNeox
-	case ml.RopeTypeMRoPE:
-		return ropeTypeMrope
-	case ml.RopeTypeVision:
-		return ropeTypeVision
-	default:
-		return ropeTypeNorm
-	}
-}
-
 func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
 	return &Tensor{
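Note: with ropeTypeToC removed, the mapping from the old ml.RopeType enum to ggml's codes survives only in the const block added above (ropeTypeNorm 0, ropeTypeNeox 2, ropeTypeMrope 8, ropeTypeVision 24); the model call sites below hard-code uint32(0) or uint32(2) to pick standard or NeoX rope. The mrope and vision codes keep their constants but lose their Go entry point along with RoPEMulti.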
@@ -14,11 +14,10 @@ import (
 type Options struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
-	eps                              float32
+	eps, ropeBase, ropeScale         float32
 	attnLogitSoftcap                 float32
 	finalLogitSoftcap                float32
 	largeModelScaling                bool
-	ropeConfig                       ml.RoPEConfig
 }
 
 type Model struct {
@@ -56,15 +55,10 @@ func New(c fs.Config) (model.Model, error) {
 			attnKeyLen:        int(c.Uint("attention.key_length")),
 			attnValLen:        int(c.Uint("attention.value_length")),
 			eps:               c.Float("attention.layer_norm_rms_epsilon"),
+			ropeBase:          c.Float("rope.freq_base", 10000.0),
+			ropeScale:         c.Float("rope.freq_scale", 1.0),
 			attnLogitSoftcap:  c.Float("attn_logit_softcapping"),
 			finalLogitSoftcap: c.Float("final_logit_softcapping"),
-			ropeConfig: ml.RoPEConfig{
-				Base:       c.Float("rope.freq_base", 10000.0),
-				Scale:      c.Float("rope.freq_scale", 1.0),
-				Dim:        c.Uint("attention.key_length"),
-				Type:       ml.RopeTypeNormal,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
 		},
 	}
 
@@ -84,10 +78,11 @@ type SelfAttention struct {
 
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
+	ropeType := uint32(2)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
+	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
 
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -97,7 +92,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
+	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -127,7 +122,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, nil, m.ropeConfig), nil
+	return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
}
 
 type MLP struct {
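Note: the old config for this model set Type: ml.RopeTypeNormal, while the replacement passes ropeType := uint32(2), the NeoX code. If the old value was a mistake, this hunk quietly corrects it; if it was intentional, it changes the model's rope behavior, so it is worth confirming which was meant.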
@@ -14,11 +14,9 @@ import (
 type TextConfig struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
-	eps                              float32
+	eps, ropeScale                   float32
+	ropeLocalBase, ropeGlobalBase    float32
 	largeModelScaling                bool
-
-	ropeLocalConfig  ml.RoPEConfig
-	ropeGlobalConfig ml.RoPEConfig
 }
 
 type TextModel struct {
@@ -58,27 +56,15 @@ func newTextModel(c fs.Config) *TextModel {
 		),
 		Layers: make([]TextLayer, numBlocks),
 		TextConfig: &TextConfig{
 			hiddenSize: int(c.Uint("embedding_length")),
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			attnKeyLen: int(c.Uint("attention.key_length", 256)),
 			attnValLen: int(c.Uint("attention.value_length", 256)),
 			eps:        c.Float("attention.layer_norm_rms_epsilon", 1e-06),
-			ropeLocalConfig: ml.RoPEConfig{
-				Base:       c.Float("rope.local.freq_base", 10000.0),
-				Scale:      c.Float("rope.freq_scale", 1.0),
-				Dim:        c.Uint("attention.key_length", 256),
-				Type:       ml.RopeTypeNeox,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
-			ropeGlobalConfig: ml.RoPEConfig{
-				Base:       c.Float("rope.global.freq_base", 1000000.0),
-				Scale:      c.Float("rope.freq_scale", 1.0),
-				Dim:        c.Uint("attention.key_length", 256),
-				Type:       ml.RopeTypeNeox,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
+			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
+			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
+			ropeScale:      c.Float("rope.freq_scale", 1.0),
 		},
 	}
 
@@ -100,16 +86,17 @@ type TextSelfAttention struct {
 
 func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
+	ropeType := uint32(2)
 
-	ropeConfig := opts.ropeLocalConfig
+	ropeBase := opts.ropeLocalBase
 	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeConfig = opts.ropeGlobalConfig
+		ropeBase = opts.ropeGlobalBase
 	}
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
 	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-	q = q.RoPE(ctx, positionIDs, nil, ropeConfig)
+	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
 
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -120,7 +107,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
 	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-	k = k.RoPE(ctx, positionIDs, nil, ropeConfig)
+	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -133,12 +120,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	ropeConfig := m.ropeLocalConfig
+	ropeBase := m.TextConfig.ropeLocalBase
 	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeConfig = m.ropeGlobalConfig
+		ropeBase = m.TextConfig.ropeGlobalBase
 	}
 
-	return key.RoPE(ctx, shift, nil, ropeConfig), nil
+	return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
 }
 
 type TextMLP struct {
@@ -15,8 +15,8 @@ import (
 
 type Options struct {
 	hiddenSize, numHeads, numKVHeads int
-	eps                              float32
-	ropeConfig                       ml.RoPEConfig
+	eps, ropeBase, ropeScale         float32
+	ropeDim                          uint32
 }
 
 type Model struct {
@@ -55,13 +55,9 @@ func New(c fs.Config) (model.Model, error) {
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ropeConfig: ml.RoPEConfig{
-				Base:       c.Float("rope.freq_base"),
-				Scale:      c.Float("rope.freq_scale", 1),
-				Dim:        c.Uint("rope.dimension_count"),
-				Type:       ml.RopeTypeNormal,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
+			ropeBase:  c.Float("rope.freq_base"),
+			ropeScale: c.Float("rope.freq_scale", 1),
+			ropeDim:   c.Uint("rope.dimension_count"),
 		},
 	}
 
@@ -81,14 +77,15 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
+	ropeType := uint32(0)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
+	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
+	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -101,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
+	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
 }
 
 type MLP struct {
@@ -31,28 +31,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
 	if useRope {
-		query = query.RoPE(
-			ctx,
-			positions,
-			sa.RopeFactors,
-			ml.RoPEConfig{
-				Dim:   uint32(opts.ropeDim),
-				Type:  ml.RopeTypeNormal,
-				Base:  opts.ropeBase,
-				Scale: opts.ropeScale,
-			},
-		)
-		key = key.RoPE(
-			ctx,
-			positions,
-			sa.RopeFactors,
-			ml.RoPEConfig{
-				Dim:   uint32(opts.ropeDim),
-				Type:  ml.RopeTypeNormal,
-				Base:  opts.ropeBase,
-				Scale: opts.ropeScale,
-			},
-		)
+		query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
+		key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
 	}
 
 	if opts.useQKNorm {
@@ -275,15 +255,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(
-		ctx,
-		shift,
-		m.Layers[layer].Attention.RopeFactors,
-		ml.RoPEConfig{
-			Dim:   uint32(m.TextOptions.ropeDim),
-			Type:  ml.RopeTypeNormal,
-			Base:  m.TextOptions.ropeBase,
-			Scale: m.TextOptions.ropeScale,
-		},
-	), nil
+	return key.RoPE(ctx, shift, m.Layers[layer].Attention.RopeFactors, uint32(0), uint32(m.ropeDim), m.ropeBase, m.ropeScale), nil
 }
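Note: the interface declares RoPE(ctx, positionIDs, ropeFactors, dim, ropeType uint32, base, scale float32), dimension first. The Forward paths above follow that order, but the two Shift implementations just replaced pass uint32(0) first and the rope dimension second, which reads as dimension 0 with m.ropeDim used as the type code. Assuming the declared signature is authoritative, the expected call would look like this sketch:

	// Sketch only: argument order per the declared RoPE signature.
	func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
		return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors,
			m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
	}

The Shift further down that passes m.ropeDim, uint32(0) in that order supports reading these as accidental swaps.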
@@ -17,7 +17,6 @@ type TextOptions struct {
 	hiddenSize, numHeads, numKVHeads, headDim int
 	eps, ropeBase, ropeScale                  float32
 	ropeDim                                   uint32
-	ropeConfig                                ml.RoPEConfig
 }
 
 type TextModel struct {
@@ -41,6 +40,7 @@ type SelfAttention struct {
 
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
+	ropeType := uint32(0)
 	headDim := opts.headDim
 	if headDim == 0 {
 		headDim = opts.hiddenSize / opts.numHeads
@@ -48,11 +48,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
+	q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
+	k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -63,7 +63,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, nil, m.TextOptions.ropeConfig), nil
+	return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
 }
 
 type MLP struct {
@@ -167,13 +167,9 @@ func NewTextModel(c fs.Config) (*TextModel, error) {
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			headDim:    int(c.Uint("attention.key_length")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ropeConfig: ml.RoPEConfig{
-				Base:       c.Float("rope.freq_base", 10000.0),
-				Scale:      c.Float("rope.freq_scale", 1.0),
-				Dim:        c.Uint("rope.dimension_count"),
-				Type:       ml.RopeTypeNormal,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
+			ropeBase:  c.Float("rope.freq_base"),
+			ropeScale: c.Float("rope.freq_scale", 1),
+			ropeDim:   c.Uint("rope.dimension_count"),
 		},
 	}
 
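Note: two things ride along in the TextOptions hunks above. First, its Shift repeats the uint32(0), m.ropeDim argument order flagged earlier. Second, NewTextModel previously read c.Float("rope.freq_base", 10000.0) but now reads c.Float("rope.freq_base") with no fallback, so a GGUF that omits rope.freq_base yields 0 rather than 10000; the freq_scale default of 1 is kept. That may be deliberate (the earlier Options-based model also reads freq_base without a default), but it is a behavior change hidden inside the refactor.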
@@ -21,14 +21,15 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
+	ropeType := uint32(0)
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)
+	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)
+	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -43,7 +44,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	// This will only get called for layers in the cache, which are just the self attention layers
 	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
-		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeConfig), nil
+		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
 	}
 
 	return key, nil
@@ -198,8 +199,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
 
 type TextModelOptions struct {
 	hiddenSize, numHeads, numKVHeads int
-	eps                              float32
-	ropeConfig                       ml.RoPEConfig
+	eps, ropeBase, ropeScale         float32
+	ropeDim                          uint32
 
 	crossAttentionLayers []int32
 }
@@ -240,14 +241,10 @@ func newTextModel(c fs.Config) *TextModel {
 			numHeads:             int(c.Uint("attention.head_count")),
 			numKVHeads:           int(c.Uint("attention.head_count_kv")),
 			eps:                  c.Float("attention.layer_norm_rms_epsilon"),
+			ropeBase:             c.Float("rope.freq_base"),
+			ropeScale:            c.Float("rope.freq_scale", 1),
+			ropeDim:              c.Uint("rope.dimension_count"),
 			crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
-			ropeConfig: ml.RoPEConfig{
-				Base:       c.Float("rope.freq_base"),
-				Scale:      c.Float("rope.freq_scale", 1),
-				Dim:        c.Uint("rope.dimension_count"),
-				Type:       ml.RopeTypeNormal,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
-			},
 		},
 	}
 }
@@ -1,59 +0,0 @@
-package qwen25vl
-
-import (
-	"testing"
-
-	"github.com/ollama/ollama/ml/backend/ggml"
-	"github.com/ollama/ollama/model/input"
-)
-
-func TestPostTokenize(t *testing.T) {
-	// Set up test inputs
-	model := &Model{}
-	mockHash := uint64(12345678)
-
-	inputs := []input.Input{
-		{Token: 123},                                                          // Regular token
-		{Token: 456},                                                          // Regular token
-		{Token: 151655, Multimodal: &ggml.Tensor{}, MultimodalHash: mockHash}, // Image token
-		{Token: 789},                                                          // Regular token
-	}
-
-	// Run the function being tested
-	result, err := model.PostTokenize(inputs)
-	if err != nil {
-		t.Fatalf("PostTokenize returned error: %v", err)
-	}
-
-	// Verify the actual length first
-	expectedLength := 21
-	if len(result) != expectedLength {
-		t.Fatalf("Result has wrong length: got %d, expected %d", len(result), expectedLength)
-	}
-
-	// Check key positions only
-	checkPositions := map[int]int32{
-		0:  123,    // First regular token
-		1:  456,    // Second regular token
-		2:  151652, // Vision start token
-		4:  151655, // First placeholder token
-		19: 151653, // Vision end token
-		20: 789,    // Final regular token
-	}
-
-	for pos, expectedToken := range checkPositions {
-		if pos >= len(result) {
-			t.Errorf("Position %d is out of bounds (result length: %d)", pos, len(result))
-			continue
-		}
-		if result[pos].Token != expectedToken {
-			t.Errorf("Position %d: expected token %d, got %d", pos, expectedToken, result[pos].Token)
-		}
-	}
-
-	// Check multimodal data is preserved
-	if result[3].MultimodalHash != mockHash {
-		t.Errorf("Multimodal hash not preserved: got %d, expected %d",
-			result[3].MultimodalHash, mockHash)
-	}
-}
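Note: TestPostTokenize is deleted outright rather than updated. Nothing in the rope refactor touches PostTokenize itself, so the lost coverage looks incidental; dropping the test also removes its direct dependency on the concrete ggml backend package, which may be the motivation.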
@@ -13,8 +13,8 @@ import (
 
 type TextOptions struct {
 	ctxLen, hiddenSize, numHeads, numKVHeads int
-	eps                                      float32
-	ropeConfig                               ml.RoPEConfig
+	eps, ropeBase, ropeScale                 float32
+	ropeDim, defaultContextLen               uint32
 }
 
 type TextModel struct {
@@ -45,18 +45,15 @@ func NewTextModel(c fs.Config) *TextModel {
 		),
 		Layers: make([]Layer, c.Uint("block_count")),
 		TextOptions: &TextOptions{
 			ctxLen:     int(c.Uint("context_length")),
 			hiddenSize: int(c.Uint("embedding_length")),
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ropeConfig: ml.RoPEConfig{
-				Base:       c.Float("rope.freq_base"),
-				Scale:      c.Float("rope.freq_scale", 1),
-				Dim:        c.Uint("rope.dimension_count", 128),
-				Type:       ml.RopeTypeNeox,
-				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 128000))),
-			},
+			ropeBase:          c.Float("rope.freq_base"),
+			ropeScale:         c.Float("rope.freq_scale", 1),
+			ropeDim:           c.Uint("rope.dimension_count", 128),
+			defaultContextLen: c.Uint("context_length", 128000),
 		},
 	}
 
@@ -79,11 +76,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
+	q = q.RoPEWithLen(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.defaultContextLen, opts.ropeBase, opts.ropeScale)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
+	k = k.RoPEWithLen(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.defaultContextLen, opts.ropeBase, opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -97,7 +94,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 // Shift applies rotary position embeddings to the key tensor for causal attention caching
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
+	return key.RoPEWithLen(ctx, shift, nil, m.ropeDim, 2, m.TextOptions.defaultContextLen, m.ropeBase, m.ropeScale), nil
 }
 
 // MLP implements the feed-forward network component with SwiGLU activation
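Note: this text model is the one caller of RoPEWithLen, threading its configured context_length (default 128000) through defaultContextLen rather than inheriting RoPE's 131072. Two details deserve a second look: the rope type is the literal 2 (NeoX, matching the removed ml.RopeTypeNeox), and Shift now passes nil rope factors where it previously forwarded m.Layers[layer].SelfAttention.RopeFactors. With nil, RoPEWithLen substitutes an empty tensor, so shifted keys are no longer adjusted by the per-layer factors; that is only safe if this model's rope factors are unused.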