duplicate input embeddings

This commit is contained in:
Michael Yang 2025-04-29 10:09:44 -07:00 committed by Bruce MacDonald
parent 88b231f903
commit ff5d1a3dc0

View File

@ -146,7 +146,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
// Initial token embedding
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
for _, image := range batch.Multimodal {
visionOutputs := image.Multimodal.(ml.Tensor)