cache is king
parent fad98fabab
commit d231229122
@@ -4,8 +4,7 @@ import (
 	"fmt"
 	"math"
 
-	"github.com/ollama/ollama/cache"
-	"github.com/ollama/ollama/cache/causal"
+	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/model"
@@ -54,7 +53,10 @@ func New(c ml.Config) (model.Model, error) {
 			ropeScale: c.Float("rope.freq_scale", 1.0),
 		},
 	}
 
-	m.Cache = causal.NewCausalCache(m.Shift)
+	slidingWindowLen := int32(c.Uint("attention.sliding_window"))
+
+	m.Cache = kvcache.NewWrapperCache(kvcache.NewCausalCache(m.Shift), kvcache.NewSWACache(slidingWindowLen, m.Shift))
 
 	return &m, nil
 }
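Note: the wrapper composes two KV caches, a full-context causal cache and a sliding-window (SWA) cache, in the argument order passed to NewWrapperCache. As a rough illustration of the dispatch idea, here is a minimal, runnable sketch; the Cache interface and every type below are simplified stand-ins for this page only, not the actual kvcache API:

package main

import "fmt"

// Cache is a simplified stand-in for the kvcache.Cache interface; the
// real interface operates on tensors and batches, not strings.
type Cache interface {
	SetLayer(layer int)
	Describe() string
}

// causalCache stands in for a full-context causal KV cache.
type causalCache struct{ layer int }

func (c *causalCache) SetLayer(layer int) { c.layer = layer }
func (c *causalCache) Describe() string {
	return fmt.Sprintf("causal cache, layer %d", c.layer)
}

// swaCache stands in for a sliding-window cache that only retains the
// most recent windowLen positions.
type swaCache struct {
	layer     int
	windowLen int32
}

func (c *swaCache) SetLayer(layer int) { c.layer = layer }
func (c *swaCache) Describe() string {
	return fmt.Sprintf("sliding-window cache (window %d), layer %d", c.windowLen, c.layer)
}

// wrapperCache multiplexes its children; SetLayerType chooses which
// underlying cache subsequent calls are routed to.
type wrapperCache struct {
	caches  []Cache
	current int
}

func (w *wrapperCache) SetLayerType(t int) { w.current = t }
func (w *wrapperCache) SetLayer(layer int) { w.caches[w.current].SetLayer(layer) }
func (w *wrapperCache) Describe() string   { return w.caches[w.current].Describe() }

func main() {
	wc := &wrapperCache{caches: []Cache{&causalCache{}, &swaCache{windowLen: 4096}}}
	for i := 0; i < 4; i++ {
		wc.SetLayerType(i % 2) // even layers: causal; odd layers: sliding window
		wc.SetLayer(i)
		fmt.Println(wc.Describe())
	}
}

Running it prints which cache each of four layers lands on, mirroring the i % 2 layer-type selection used later in Forward (assuming the wrapper indexes its children in argument order).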
@@ -65,7 +67,7 @@ type SelfAttention struct {
 	Output *nn.Linear `gguf:"attn_output"`
 }
 
-func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache cache.Cache, opts *Options) ml.Tensor {
+func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
 
@@ -124,7 +126,7 @@ type Layer struct {
 	MLP *MLP
 }
 
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache cache.Cache, opts *Options) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -154,10 +156,11 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 	ctx.Forward(hiddenState)
 
-	fmt.Printf("hidden state = %s\n", ml.Dump(hiddenState))
-
 	for i, layer := range m.Layers {
+		cacheType := i % 2
 		m.Cache.SetLayer(i)
+		wc := m.Cache.(*kvcache.WrapperCache)
+		wc.SetLayerType(cacheType)
 		hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
 	}
 
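Note: m.Cache.(*kvcache.WrapperCache) is an unchecked type assertion, so it would panic at runtime if m.Cache ever held a different implementation; it is safe here only because New installs a WrapperCache above. A defensive alternative is Go's comma-ok assertion form, sketched below with stand-in types rather than the real kvcache package:

package main

import "fmt"

// Stand-ins for the real cache types; only the assertion pattern matters here.
type Cache interface{ SetLayer(int) }

type WrapperCache struct{ layerType int }

func (w *WrapperCache) SetLayer(int)       {}
func (w *WrapperCache) SetLayerType(t int) { w.layerType = t }

func main() {
	var c Cache = &WrapperCache{}
	for i := 0; i < 4; i++ {
		c.SetLayer(i)
		// The comma-ok form degrades gracefully instead of panicking
		// when c is not a *WrapperCache.
		if wc, ok := c.(*WrapperCache); ok {
			wc.SetLayerType(i % 2) // even layers: causal; odd: sliding window
			fmt.Printf("layer %d uses cache type %d\n", i, i%2)
		}
	}
}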