use non-causal mask for inputs with images

This commit is contained in:
Michael Yang 2025-03-10 13:00:09 -07:00
parent 2e54d72fc3
commit 9d2a20a763

View File

@ -181,6 +181,11 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1))
if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, false)
defer causal.SetCausal(ctx, true)
}
}
for i, layer := range m.Layers {