cache is king

Patrick Devine 2025-02-07 16:17:19 -08:00
parent fad98fabab
commit d231229122


@@ -4,8 +4,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/cache"
"github.com/ollama/ollama/cache/causal"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
@@ -54,7 +53,10 @@ func New(c ml.Config) (model.Model, error) {
ropeScale: c.Float("rope.freq_scale", 1.0),
},
}
-m.Cache = causal.NewCausalCache(m.Shift)
+slidingWindowLen := int32(c.Uint("attention.sliding_window"))
+m.Cache = kvcache.NewWrapperCache(kvcache.NewCausalCache(m.Shift), kvcache.NewSWACache(slidingWindowLen, m.Shift))
return &m, nil
}
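
Note (not part of the diff): the constructor above replaces the single causal cache with one kvcache.WrapperCache backed by two caches, one per attention style. A minimal sketch of that wiring, using only names visible in this commit; the exact semantics of the kvcache constructors and of m.Shift are assumed rather than shown here:

slidingWindowLen := int32(c.Uint("attention.sliding_window"))
globalCache := kvcache.NewCausalCache(m.Shift)             // layers attending over the full context
swaCache := kvcache.NewSWACache(slidingWindowLen, m.Shift) // layers limited to the sliding window
m.Cache = kvcache.NewWrapperCache(globalCache, swaCache)   // single cache object exposed as m.Cache

Registering the causal cache first and the SWA cache second is presumably what the per-layer cacheType selection further down (i % 2) indexes into, in argument order.
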
@@ -65,7 +67,7 @@ type SelfAttention struct {
Output *nn.Linear `gguf:"attn_output"`
}
-func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache cache.Cache, opts *Options) ml.Tensor {
+func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
@@ -124,7 +126,7 @@ type Layer struct {
MLP *MLP
}
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache cache.Cache, opts *Options) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -154,10 +156,11 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
ctx.Forward(hiddenState)
fmt.Printf("hidden state = %s\n", ml.Dump(hiddenState))
for i, layer := range m.Layers {
+cacheType := i % 2
+m.Cache.SetLayer(i)
+wc := m.Cache.(*kvcache.WrapperCache)
+wc.SetLayerType(cacheType)
hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
}
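
The loop above first picks which wrapped cache a layer should use (cacheType alternates 0, 1, 0, 1, ... via i % 2) and then tells that cache which layer is running. Below is a self-contained toy sketch of this routing, with entirely hypothetical names and no dependency on the real kvcache package, just to make the control flow explicit:

package main

import "fmt"

// toyCache stands in for a per-layer KV cache; the real interface in
// github.com/ollama/ollama/kvcache is richer than this.
type toyCache interface {
	SetLayer(layer int)
	Name() string
}

type namedCache struct {
	name  string
	layer int
}

func (c *namedCache) SetLayer(layer int) { c.layer = layer }
func (c *namedCache) Name() string       { return c.name }

// toyWrapper routes calls to one of several underlying caches, picked per
// layer with SetLayerType — the same idea as the wrapper cache above.
type toyWrapper struct {
	caches  []toyCache
	current int
}

func (w *toyWrapper) SetLayerType(t int) { w.current = t }
func (w *toyWrapper) SetLayer(layer int) { w.caches[w.current].SetLayer(layer) }
func (w *toyWrapper) Name() string       { return w.caches[w.current].Name() }

func main() {
	wrapper := &toyWrapper{caches: []toyCache{
		&namedCache{name: "causal (full context)"},
		&namedCache{name: "SWA (sliding window)"},
	}}
	for i := 0; i < 4; i++ {
		cacheType := i % 2 // alternate cache type per layer, as in the diff
		wrapper.SetLayerType(cacheType)
		wrapper.SetLayer(i)
		fmt.Printf("layer %d uses the %s cache\n", i, wrapper.Name())
	}
}

Assuming the real wrapper indexes its caches in the order they were passed to NewWrapperCache, even layers would land in the causal cache and odd layers in the sliding-window cache; flipping either the argument order or the parity would swap that mapping.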