standard ropeFreq var names

This commit is contained in:
Bruce MacDonald 2025-02-20 09:48:03 -08:00
parent f93bd92027
commit 2c0300073f

View File

@ -16,7 +16,7 @@ type Options struct {
numAttnHeads int numAttnHeads int
numKVHeads int numKVHeads int
modelEpsilon float32 modelEpsilon float32
ropeBaseFreq float32 ropeFreqBase float32
ropeFreqScale float32 ropeFreqScale float32
ropeDimensions uint32 ropeDimensions uint32
} }
@ -52,7 +52,7 @@ func New(c ml.Config) (model.Model, error) {
numKVHeads: int(c.Uint("attention.head_count_kv")), numKVHeads: int(c.Uint("attention.head_count_kv")),
modelEpsilon: c.Float("attention.layer_norm_rms_epsilon"), modelEpsilon: c.Float("attention.layer_norm_rms_epsilon"),
contextLength: int(c.Uint("context_length")), contextLength: int(c.Uint("context_length")),
ropeBaseFreq: c.Float("rope.freq_base"), ropeFreqBase: c.Float("rope.freq_base"),
ropeFreqScale: c.Float("rope.freq_scale", 1), ropeFreqScale: c.Float("rope.freq_scale", 1),
ropeDimensions: c.Uint("rope.dimension_count", 64), ropeDimensions: c.Uint("rope.dimension_count", 64),
}, },
@ -73,7 +73,7 @@ func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tenso
RopeDim: m.Options.ropeDimensions, RopeDim: m.Options.ropeDimensions,
RopeType: ml.RopeTypeNeoX, RopeType: ml.RopeTypeNeoX,
OrigCtxLen: m.Options.contextLength, OrigCtxLen: m.Options.contextLength,
RopeBase: m.Options.ropeBaseFreq, RopeBase: m.Options.ropeFreqBase,
RopeScale: m.Options.ropeFreqScale, RopeScale: m.Options.ropeFreqScale,
}, },
), nil ), nil
@ -98,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, inputPositions ml.
RopeDim: opts.ropeDimensions, RopeDim: opts.ropeDimensions,
RopeType: ml.RopeTypeNeoX, RopeType: ml.RopeTypeNeoX,
OrigCtxLen: opts.contextLength, OrigCtxLen: opts.contextLength,
RopeBase: opts.ropeBaseFreq, RopeBase: opts.ropeFreqBase,
RopeScale: opts.ropeFreqScale, RopeScale: opts.ropeFreqScale,
} }