From 9d2a20a7638a7c4e10cb119c3c3b6bf6e470ca3e Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 10 Mar 2025 13:00:09 -0700 Subject: [PATCH] use non-causal mask for inputs with images --- model/models/gemma3/model_text.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index f63c2ed7b..bf7f6b4c0 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -181,6 +181,11 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor visionOutputs := multimodal[0].Multimodal.(ml.Tensor) offset := multimodal[0].Index - 1 - visionOutputs.Dim(1) hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1)) + + if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok { + causal.SetCausal(ctx, false) + defer causal.SetCausal(ctx, true) + } } for i, layer := range m.Layers {