From 0e886595bf3d4ee33737f4b30154210b0df2d2df Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Fri, 7 Mar 2025 16:05:56 -0800 Subject: [PATCH] Fix tests and drift from main --- kvcache/causal_test.go | 4 ++++ model/models/gemma2/model.go | 2 +- model/models/mllama/model_text.go | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 431a79b53..ed23cad6a 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -499,6 +499,10 @@ func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor { panic("not implemented") } +func (t *testTensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor { + panic("not implemented") +} + func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor { panic("not implemented") } diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index a82d68d37..2b8597c42 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -179,7 +179,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { return nil, err } - outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs)) + outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs)) if err != nil { return nil, err } diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 40c9a9707..1cf30d89b 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -28,7 +28,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m key := sa.Key.Forward(ctx, hiddenState) key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)