From 9ceee25d8b795bd972019cef60107f5b320ea4f1 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Thu, 8 May 2025 16:31:27 -0700
Subject: [PATCH] chunk vision outputs

---
 model/models/qwen25vl/model.go      | 52 +++++++++++++++++++++++++----
 model/models/qwen25vl/model_text.go | 20 +++++++----
 2 files changed, 58 insertions(+), 14 deletions(-)

diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go
index 31911fa98..3e3ff4d6c 100644
--- a/model/models/qwen25vl/model.go
+++ b/model/models/qwen25vl/model.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"image"
 	"slices"
+	"sync"
 
 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/kvcache"
@@ -71,7 +72,39 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	}
 
 	visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
-	return visionOutputs, nil
+	return &chunks{Model: m, Tensor: visionOutputs}, nil
+}
+
+// chunks bundles the vision encoder's output tensor with the model that
+// produced it, and holds a float32 copy of the tensor's data that is filled
+// in the first time a chunk asks for it.
+type chunks struct {
+	*Model
+	ml.Tensor
+
+	dataOnce sync.Once
+	data     []float32
+}
+
+// chunk is a view of n rows, starting at row s, of its parent's cached
+// float data. Each row holds Dim(0) values.
+type chunk struct {
+	*chunks
+	s, n int
+}
+
+// floats returns this chunk's portion of the parent's float data. The first
+// call through any chunk of the same parent computes the tensor in a
+// temporary context and caches the values; dataOnce makes that happen once.
+func (r *chunk) floats() []float32 {
+	r.dataOnce.Do(func() {
+		temp := r.Backend().NewContext()
+		defer temp.Close()
+		temp.Forward(r.Tensor).Compute(r.Tensor)
+		r.data = r.Floats()
+	})
+
+	return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)]
 }
 
 // PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
@@ -102,18 +135,23 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 			}
 
 			// This is an image token with multimodal data
-			visionOutputs := inp.Multimodal.(ml.Tensor)
-
-			// Calculate tokens per grid based on grid dimensions
+			chunksData := inp.Multimodal.(*chunks)
+			patchesPerChunk := chunksData.Dim(1)
 
 			// First add the vision start token
-			result = append(result, input.Input{Token: visionStartToken, SameBatch: visionOutputs.Dim(1) + 2})
+			result = append(result, input.Input{Token: visionStartToken, SameBatch: patchesPerChunk + 2})
 
 			// Add the image token with the multimodal tensor data at the first position
-			result = append(result, input.Input{Token: imageToken, Multimodal: visionOutputs, MultimodalHash: inp.MultimodalHash})
+			// Create a chunk with proper s and n values
+			result = append(result, input.Input{
+				Token:          imageToken,
+				Multimodal:     &chunk{chunks: chunksData, s: 0, n: patchesPerChunk},
+				MultimodalHash: inp.MultimodalHash,
+				SameBatch:      patchesPerChunk,
+			})
 
 			// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
-			result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, visionOutputs.Dim(1)-1)...)
+			result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
 
 			result = append(result, input.Input{Token: visionEndToken})
 		}
diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go
index c7d5dfc8b..735642e10 100644
--- a/model/models/qwen25vl/model_text.go
+++ b/model/models/qwen25vl/model_text.go
@@ -143,11 +143,17 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 
 func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
 	// Initial token embedding
-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
+	hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
 
-	for _, image := range batch.Multimodal {
-		visionOutputs := image.Multimodal.(ml.Tensor)
-		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
+	// Overwrite the token embeddings at each image's index with its vision outputs.
+	for _, mi := range batch.Multimodal {
+		f32s := mi.Multimodal.(*chunk).floats()
+		img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize)
+		if err != nil {
+			return nil, err
+		}
+
+		ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
 	}
 
 	// Process through transformer layers
@@ -159,9 +165,9 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 			lastLayerOutputs = outputs
 		}
 
-		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
+		hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, m.TextOptions)
 	}
 
-	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
-	return m.Output.Forward(ctx, hiddenState), nil
+	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
+	return m.Output.Forward(ctx, hiddenStates), nil
 }