chunk vision outputs

This commit is contained in:
Bruce MacDonald 2025-05-08 16:31:27 -07:00
parent 661bf04696
commit 9ceee25d8b
2 changed files with 49 additions and 14 deletions

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"image" "image"
"slices" "slices"
"sync"
"github.com/ollama/ollama/fs" "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
@ -71,7 +72,31 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
} }
visionOutputs := m.VisionModel.Forward(ctx, pixels, grid) visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
return visionOutputs, nil return &chunks{Model: m, Tensor: visionOutputs}, nil
}
type chunks struct {
*Model
ml.Tensor
dataOnce sync.Once
data []float32
}
type chunk struct {
*chunks
s, n int
}
func (r *chunk) floats() []float32 {
r.dataOnce.Do(func() {
temp := r.Backend().NewContext()
defer temp.Close()
temp.Forward(r.Tensor).Compute(r.Tensor)
r.data = r.Floats()
})
return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)]
} }
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass // PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
@ -102,18 +127,23 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
} }
// This is an image token with multimodal data // This is an image token with multimodal data
visionOutputs := inp.Multimodal.(ml.Tensor) chunksData := inp.Multimodal.(*chunks)
patchesPerChunk := chunksData.Dim(1)
// Calculate tokens per grid based on grid dimensions
// First add the vision start token // First add the vision start token
result = append(result, input.Input{Token: visionStartToken, SameBatch: visionOutputs.Dim(1) + 2}) result = append(result, input.Input{Token: visionStartToken, SameBatch: patchesPerChunk + 2})
// Add the image token with the multimodal tensor data at the first position // Add the image token with the multimodal tensor data at the first position
result = append(result, input.Input{Token: imageToken, Multimodal: visionOutputs, MultimodalHash: inp.MultimodalHash}) // Create a chunk with proper s and n values
result = append(result, input.Input{
Token: imageToken,
Multimodal: &chunk{chunks: chunksData, s: 0, n: patchesPerChunk},
MultimodalHash: inp.MultimodalHash,
SameBatch: patchesPerChunk,
})
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1) // Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, visionOutputs.Dim(1)-1)...) result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
result = append(result, input.Input{Token: visionEndToken}) result = append(result, input.Input{Token: visionEndToken})
} }

View File

@ -143,11 +143,16 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) { func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
// Initial token embedding // Initial token embedding
hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx) hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
for _, image := range batch.Multimodal { for _, mi := range batch.Multimodal {
visionOutputs := image.Multimodal.(ml.Tensor) f32s := mi.Multimodal.(*chunk).floats()
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1)))) img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize)
if err != nil {
panic(err)
}
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
} }
// Process through transformer layers // Process through transformer layers
@ -159,9 +164,9 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
lastLayerOutputs = outputs lastLayerOutputs = outputs
} }
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions) hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, m.TextOptions)
} }
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenState), nil return m.Output.Forward(ctx, hiddenStates), nil
} }