cache is king

parent fad98fabab
commit d231229122
@@ -4,8 +4,7 @@ import (
     "fmt"
     "math"

-    "github.com/ollama/ollama/cache"
-    "github.com/ollama/ollama/cache/causal"
+    "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
     "github.com/ollama/ollama/model"
@@ -54,7 +53,10 @@ func New(c ml.Config) (model.Model, error) {
             ropeScale: c.Float("rope.freq_scale", 1.0),
         },
     }
-    m.Cache = causal.NewCausalCache(m.Shift)
+
+    slidingWindowLen := int32(c.Uint("attention.sliding_window"))
+    m.Cache = kvcache.NewWrapperCache(kvcache.NewCausalCache(m.Shift), kvcache.NewSWACache(slidingWindowLen, m.Shift))

     return &m, nil
 }
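Aside on the new construction above: kvcache.NewWrapperCache bundles two caches behind the model's single Cache field, a full causal cache (kvcache.NewCausalCache) and a sliding-window attention cache (kvcache.NewSWACache) whose window length comes from the "attention.sliding_window" key. As a rough, hypothetical sketch of what the sliding-window variant buys (stand-in types only, not the real kvcache implementation), such a cache only needs to retain the most recent slidingWindowLen positions:

package main

import "fmt"

// swaWindow is a hypothetical stand-in for the bookkeeping a sliding-window
// KV cache needs: once more than `window` positions are stored, the oldest
// entries can be dropped, since attention never reaches past the window.
type swaWindow struct {
    window    int
    positions []int
}

func (s *swaWindow) add(pos int) {
    s.positions = append(s.positions, pos)
    if extra := len(s.positions) - s.window; extra > 0 {
        s.positions = s.positions[extra:] // evict positions outside the window
    }
}

func main() {
    s := &swaWindow{window: 4} // the real value comes from c.Uint("attention.sliding_window")
    for pos := 0; pos < 10; pos++ {
        s.add(pos)
    }
    fmt.Println(s.positions) // prints [6 7 8 9]
}

The wrapper then lets each layer pick whichever of the two sub-caches it needs, which the Forward loop further down does per layer.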
@@ -65,7 +67,7 @@ type SelfAttention struct {
     Output *nn.Linear `gguf:"attn_output"`
 }

-func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache cache.Cache, opts *Options) ml.Tensor {
+func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
     batchSize := hiddenState.Dim(1)
     headDim := opts.hiddenSize / opts.numHeads
@@ -124,7 +126,7 @@ type Layer struct {
     MLP *MLP
 }

-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache cache.Cache, opts *Options) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
     residual := hiddenState

     hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -154,10 +156,11 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
     hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
     ctx.Forward(hiddenState)

     fmt.Printf("hidden state = %s\n", ml.Dump(hiddenState))

     for i, layer := range m.Layers {
+        cacheType := i % 2
+        m.Cache.SetLayer(i)
+        wc := m.Cache.(*kvcache.WrapperCache)
+        wc.SetLayerType(cacheType)
         hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
     }
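The loop above alternates layers between the two wrapped caches: m.Cache.SetLayer(i) points the cache at the current layer, and SetLayerType(cacheType) picks which sub-cache handles it. Assuming the sub-caches are indexed in the order they were passed to kvcache.NewWrapperCache (causal first, sliding-window second), even layers use the full causal cache and odd layers the SWA cache. A minimal sketch of that even/odd mapping, under that ordering assumption:

package main

import "fmt"

func main() {
    // Assumed mapping: type 0 = full causal cache, type 1 = sliding-window
    // cache, matching the argument order of kvcache.NewWrapperCache above.
    names := []string{"causal", "swa"}
    for i := 0; i < 6; i++ {
        cacheType := i % 2 // same selection rule as in (*Model).Forward
        fmt.Printf("layer %d -> cache type %d (%s)\n", i, cacheType, names[cacheType])
    }
}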