connect vision to text

commit 470af8ab89
parent 178761aef3

@@ -3,6 +3,8 @@ package llama4
 import (
 	"bytes"
 	"image"
+	"slices"
+	"sync"
 
 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/kvcache"
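
Note: the two new imports back the machinery added below. slices (Go 1.23+) supplies slices.Repeat, used to stamp out runs of <|patch|> placeholder tokens, and sync supplies the sync.Once that guards the one-time readback of the projected image tensor. A minimal, self-contained sketch of the slices.Repeat idiom, outside the ollama types:

package main

import (
	"fmt"
	"slices"
)

func main() {
	// One "real" entry followed by count-1 identical placeholders,
	// mirroring how PostTokenize fills a tile with patch tokens.
	run := append([]int{200092}, slices.Repeat([]int{200092}, 3)...)
	fmt.Println(run) // [200092 200092 200092 200092]
}
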
@@ -78,7 +80,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
 		return nil, err
 	}
 
-	ratioW, ratioH := int(size.X/m.imageSize), int(size.Y/m.imageSize)
+	ratioW, ratioH := size.X/m.imageSize, size.Y/m.imageSize
 
 	tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW, ratioW, size.Y, m.numChannels).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW*size.Y/ratioH, ratioH, ratioW, m.numChannels).Permute(ctx, 0, 3, 2, 1).Contiguous(ctx)
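
Note: dropping the int(...) conversions is a cleanup, not a behavior change: size is an image.Point, so size.X, size.Y, and m.imageSize are already int, and the division is integer division either way. The ratios are the tile-grid dimensions of the pre-scaled image. A sketch of the arithmetic, assuming the 336-pixel tile edge used by Llama 4's vision tower (tileEdge and tileGrid are illustrative names, not part of the diff):

// tileGrid returns how many vision tiles fit along each image edge,
// assuming the image was already scaled to whole multiples of the tile.
const tileEdge = 336 // stands in for m.imageSize

func tileGrid(w, h int) (ratioW, ratioH int) {
	return w / tileEdge, h / tileEdge // e.g. 672x1008 -> 2x3 tiles
}
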
@@ -97,11 +99,75 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
 
 	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
 	visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3))
-	return m.Projector.Forward(ctx, visionOutputs), nil
+	projectedOutputs := m.Projector.Forward(ctx, visionOutputs)
+	return &chunks{Model: m, Tensor: projectedOutputs, aspectRatio: image.Point{ratioW, ratioH}}, nil
+}
+
+type chunks struct {
+	*Model
+	ml.Tensor
+	aspectRatio image.Point
+
+	dataOnce sync.Once
+	data     []float32
+}
+
+type chunk struct {
+	*chunks
+	s, n int
+}
+
+func (r *chunk) floats() []float32 {
+	r.dataOnce.Do(func() {
+		temp := r.Backend().NewContext()
+		defer temp.Close()
+		temp.Forward(r.Tensor).Compute(r.Tensor)
+		r.data = r.Floats()
+	})
+
+	return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)]
 }
 
 func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
-	return inputs, nil
+	var result []input.Input
+
+	for _, inp := range inputs {
+		if inp.Multimodal == nil {
+			result = append(result, inp)
+			continue
+		}
+
+		t := inp.Multimodal.(*chunks)
+		var imageInputs []input.Input
+		imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>
+
+		var offset int
+		patchesPerChunk := t.Dim(1)
+		if t.aspectRatio.Y*t.aspectRatio.X > 1 {
+			patchesPerChunk = t.Dim(1) / (t.aspectRatio.X*t.aspectRatio.Y + 1)
+
+			for range t.aspectRatio.Y {
+				for x := range t.aspectRatio.X {
+					imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
+					imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
+					if x < t.aspectRatio.X-1 {
+						imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
+					}
+					offset += patchesPerChunk
+				}
+
+				imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
+			}
+		}
+
+		imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
+		imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
+		imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
+		imageInputs = append(imageInputs, input.Input{Token: 200081}) // <|image_end|>
+
+		result = append(result, imageInputs...)
+	}
+
+	return result, nil
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
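
Note: the chunks/chunk pair defers GPU work until the text model actually needs the image features. EncodeMultimodal returns an unevaluated tensor; the first floats() call runs the graph on a throwaway context and downloads the whole buffer exactly once (sync.Once), and each chunk then sub-slices the cached []float32 by patch offset s and length n, scaled by the row width Dim(0). The same lazy-materialization shape, reduced to a standalone sketch with illustrative names:

package main

import "sync"

// lazyRows mimics chunk.floats(): the expensive load runs at most once,
// then every caller sub-slices the cached flat buffer by row range.
type lazyRows struct {
	once sync.Once
	data []float32
	dim  int              // row width; stands in for Tensor.Dim(0)
	load func() []float32 // compute + readback, e.g. a device graph
}

func (l *lazyRows) rows(s, n int) []float32 {
	l.once.Do(func() { l.data = l.load() })
	return l.data[s*l.dim : (s+n)*l.dim]
}

func main() {
	l := &lazyRows{dim: 2, load: func() []float32 { return []float32{1, 2, 3, 4, 5, 6} }}
	_ = l.rows(1, 2) // rows 1..2 -> [3 4 5 6]; load ran exactly once
}
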
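Note: for a tiled image the projector output holds aspectRatio.X*aspectRatio.Y local tiles plus one global tile, which is why patchesPerChunk divides Dim(1) by X*Y+1. The emitted stream is <|image_start|>, one patch run per local tile with <|tile_x_separator|> between columns and <|tile_y_separator|> closing each row, then <|image|> followed by the global tile's patches, then <|image_end|>. Only the first patch token of each run carries the Multimodal chunk; SameBatch keeps the whole run in one batch. A sketch of the resulting token count (imageTokenCount is a hypothetical helper mirroring the loops above):

// imageTokenCount: tokens emitted for one image, given the tile grid
// (rx, ry) and p patches per chunk.
func imageTokenCount(rx, ry, p int) int {
	n := 1 // <|image_start|>
	if rx*ry > 1 {
		n += rx * ry * p   // one patch run per local tile
		n += (rx - 1) * ry // <|tile_x_separator|> between columns
		n += ry            // <|tile_y_separator|> after each row
	}
	n += 1 + p   // <|image|> plus the global tile's patch run
	return n + 1 // <|image_end|>
}

For a 2x3 grid this gives 1 + 6p + 3 + 3 + (1 + p) + 1 = 7p + 9 tokens.
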
@@ -195,7 +195,17 @@ func newTextModel(c fs.Config) *TextModel {
 }
 
 func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
-	hiddenStates := m.TokenEmbedding.Forward(ctx, inputs)
+	hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
+
+	for _, mi := range batch.Multimodal {
+		f32s := mi.Multimodal.(*chunk).floats()
+		img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize)
+		if err != nil {
+			panic(err)
+		}
+
+		ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
+	}
 
 	for i, layer := range m.Layers {
 		cache.SetLayer(i)
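
Note: Duplicate gives the token embeddings a dedicated buffer that the in-place Copy below can safely overwrite. Each multimodal entry is materialized with floats(), re-uploaded as a [patches, hiddenSize] tensor, and copied over the placeholder rows starting at the entry's batch index: View selects the destination rows, Copy writes them. The equivalent row splice on flat slices (an illustrative stand-in, not the ml tensor API):

// spliceRows overwrites rows starting at index of a flat [seq, hidden]
// embedding buffer with projected image features.
func spliceRows(hiddenStates []float32, hidden, index int, img []float32) {
	copy(hiddenStates[index*hidden:], img) // len(img) = rows*hidden
}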