chunk vision outputs
This commit is contained in:
parent
661bf04696
commit
9ceee25d8b
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"image"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
@ -71,7 +72,31 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
||||
}
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
|
||||
return visionOutputs, nil
|
||||
return &chunks{Model: m, Tensor: visionOutputs}, nil
|
||||
}
|
||||
|
||||
type chunks struct {
|
||||
*Model
|
||||
ml.Tensor
|
||||
|
||||
dataOnce sync.Once
|
||||
data []float32
|
||||
}
|
||||
|
||||
type chunk struct {
|
||||
*chunks
|
||||
s, n int
|
||||
}
|
||||
|
||||
func (r *chunk) floats() []float32 {
|
||||
r.dataOnce.Do(func() {
|
||||
temp := r.Backend().NewContext()
|
||||
defer temp.Close()
|
||||
temp.Forward(r.Tensor).Compute(r.Tensor)
|
||||
r.data = r.Floats()
|
||||
})
|
||||
|
||||
return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)]
|
||||
}
|
||||
|
||||
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
|
||||
@ -102,18 +127,23 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
}
|
||||
|
||||
// This is an image token with multimodal data
|
||||
visionOutputs := inp.Multimodal.(ml.Tensor)
|
||||
|
||||
// Calculate tokens per grid based on grid dimensions
|
||||
chunksData := inp.Multimodal.(*chunks)
|
||||
patchesPerChunk := chunksData.Dim(1)
|
||||
|
||||
// First add the vision start token
|
||||
result = append(result, input.Input{Token: visionStartToken, SameBatch: visionOutputs.Dim(1) + 2})
|
||||
result = append(result, input.Input{Token: visionStartToken, SameBatch: patchesPerChunk + 2})
|
||||
|
||||
// Add the image token with the multimodal tensor data at the first position
|
||||
result = append(result, input.Input{Token: imageToken, Multimodal: visionOutputs, MultimodalHash: inp.MultimodalHash})
|
||||
// Create a chunk with proper s and n values
|
||||
result = append(result, input.Input{
|
||||
Token: imageToken,
|
||||
Multimodal: &chunk{chunks: chunksData, s: 0, n: patchesPerChunk},
|
||||
MultimodalHash: inp.MultimodalHash,
|
||||
SameBatch: patchesPerChunk,
|
||||
})
|
||||
|
||||
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
|
||||
result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, visionOutputs.Dim(1)-1)...)
|
||||
result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
|
||||
|
||||
result = append(result, input.Input{Token: visionEndToken})
|
||||
}
|
||||
|
@ -143,11 +143,16 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
|
||||
// Initial token embedding
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
|
||||
|
||||
for _, image := range batch.Multimodal {
|
||||
visionOutputs := image.Multimodal.(ml.Tensor)
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
||||
for _, mi := range batch.Multimodal {
|
||||
f32s := mi.Multimodal.(*chunk).floats()
|
||||
img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
|
||||
}
|
||||
|
||||
// Process through transformer layers
|
||||
@ -159,9 +164,9 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
|
||||
lastLayerOutputs = outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, m.TextOptions)
|
||||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenState), nil
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user