diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index f41665154..f8ba82110 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -4,8 +4,7 @@ import ( "fmt" "math" - "github.com/ollama/ollama/cache" - "github.com/ollama/ollama/cache/causal" + "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/model" @@ -54,7 +53,10 @@ func New(c ml.Config) (model.Model, error) { ropeScale: c.Float("rope.freq_scale", 1.0), }, } - m.Cache = causal.NewCausalCache(m.Shift) + + slidingWindowLen := int32(c.Uint("attention.sliding_window")) + m.Cache = kvcache.NewWrapperCache(kvcache.NewCausalCache(m.Shift), kvcache.NewSWACache(slidingWindowLen, m.Shift)) + return &m, nil } @@ -65,7 +67,7 @@ type SelfAttention struct { Output *nn.Linear `gguf:"attn_output"` } -func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache cache.Cache, opts *Options) ml.Tensor { +func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := opts.hiddenSize / opts.numHeads @@ -124,7 +126,7 @@ type Layer struct { MLP *MLP } -func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache cache.Cache, opts *Options) ml.Tensor { +func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { residual := hiddenState hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) @@ -154,10 +156,11 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { hiddenState := m.TokenEmbedding.Forward(ctx, inputs) ctx.Forward(hiddenState) - fmt.Printf("hidden state = %s\n", ml.Dump(hiddenState)) - for i, layer := range m.Layers { + cacheType := i % 2 m.Cache.SetLayer(i) + wc := m.Cache.(*kvcache.WrapperCache) + wc.SetLayerType(cacheType) hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options) }