ctxLen -> origCtxLen
This commit is contained in:
parent
eb086514da
commit
c259747acb
@ -10,10 +10,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
||||||
ctxLen, hiddenSize, numHeads, numKVHeads int
|
origCtxLen, hiddenSize, numHeads, numKVHeads int
|
||||||
eps, ropeBase, ropeScale float32
|
eps, ropeBase, ropeScale float32
|
||||||
ropeDim uint32
|
ropeDim uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
@ -46,7 +46,7 @@ func New(c ml.Config) (model.Model, error) {
|
|||||||
numHeads: int(c.Uint("attention.head_count")),
|
numHeads: int(c.Uint("attention.head_count")),
|
||||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||||
ctxLen: int(c.Uint("context_length")),
|
origCtxLen: int(c.Uint("context_length")),
|
||||||
ropeBase: c.Float("rope.freq_base"),
|
ropeBase: c.Float("rope.freq_base"),
|
||||||
ropeScale: c.Float("rope.freq_scale", 1),
|
ropeScale: c.Float("rope.freq_scale", 1),
|
||||||
ropeDim: c.Uint("rope.dimension_count"),
|
ropeDim: c.Uint("rope.dimension_count"),
|
||||||
@ -73,7 +73,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
|||||||
RopeFactors: opts.RopeFactors,
|
RopeFactors: opts.RopeFactors,
|
||||||
RopeDim: opts.ropeDim,
|
RopeDim: opts.ropeDim,
|
||||||
RopeType: ml.RopeTypeStandard,
|
RopeType: ml.RopeTypeStandard,
|
||||||
OrigCtxLen: opts.ctxLen,
|
OrigCtxLen: opts.origCtxLen,
|
||||||
RopeBase: opts.ropeBase,
|
RopeBase: opts.ropeBase,
|
||||||
RopeScale: opts.ropeScale,
|
RopeScale: opts.ropeScale,
|
||||||
}
|
}
|
||||||
@ -116,7 +116,7 @@ func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tenso
|
|||||||
RopeFactors: m.Options.RopeFactors,
|
RopeFactors: m.Options.RopeFactors,
|
||||||
RopeDim: m.Options.ropeDim,
|
RopeDim: m.Options.ropeDim,
|
||||||
RopeType: ml.RopeTypeStandard,
|
RopeType: ml.RopeTypeStandard,
|
||||||
OrigCtxLen: m.Options.ctxLen,
|
OrigCtxLen: m.Options.origCtxLen,
|
||||||
RopeBase: m.Options.ropeBase,
|
RopeBase: m.Options.ropeBase,
|
||||||
RopeScale: m.Options.ropeScale,
|
RopeScale: m.Options.ropeScale,
|
||||||
},
|
},
|
||||||
|
Loading…
x
Reference in New Issue
Block a user