From 470af8ab899aca6a72571f0c1e2ac6f9049aca29 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Thu, 17 Apr 2025 15:46:55 -0700
Subject: [PATCH] connect vision to text

---
 model/models/llama4/model.go      | 72 +++++++++++++++++++++++++++++--
 model/models/llama4/model_text.go | 12 +++++-
 2 files changed, 80 insertions(+), 4 deletions(-)

diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go
index 9a1d06ebd..d3ed45ead 100644
--- a/model/models/llama4/model.go
+++ b/model/models/llama4/model.go
@@ -3,6 +3,8 @@ package llama4
 import (
 	"bytes"
 	"image"
+	"slices"
+	"sync"
 
 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/kvcache"
@@ -78,7 +80,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 		return nil, err
 	}
 
-	ratioW, ratioH := int(size.X/m.imageSize), int(size.Y/m.imageSize)
+	ratioW, ratioH := size.X/m.imageSize, size.Y/m.imageSize
 
 	tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW, ratioW, size.Y, m.numChannels).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW*size.Y/ratioH, ratioH, ratioW, m.numChannels).Permute(ctx, 0, 3, 2, 1).Contiguous(ctx)
@@ -97,11 +99,75 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 
 	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
 	visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3))
-	return m.Projector.Forward(ctx, visionOutputs), nil
+	projectedOutputs := m.Projector.Forward(ctx, visionOutputs)
+	return &chunks{Model: m, Tensor: projectedOutputs, aspectRatio: image.Point{ratioW, ratioH}}, nil
+}
+
+type chunks struct {
+	*Model
+	ml.Tensor
+	aspectRatio image.Point
+
+	dataOnce sync.Once
+	data     []float32
+}
+
+type chunk struct {
+	*chunks
+	s, n int
+}
+
+func (r *chunk) floats() []float32 {
+	r.dataOnce.Do(func() {
+		temp := r.Backend().NewContext()
+		defer temp.Close()
+		temp.Forward(r.Tensor).Compute(r.Tensor)
+		r.data = r.Floats()
+	})
+
+	return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)]
 }
 
 func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
-	return inputs, nil
+	var result []input.Input
+	for _, inp := range inputs {
+		if inp.Multimodal == nil {
+			result = append(result, inp)
+			continue
+		}
+
+		t := inp.Multimodal.(*chunks)
+		var imageInputs []input.Input
+		imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>
+
+		var offset int
+		patchesPerChunk := t.Dim(1)
+		if t.aspectRatio.Y*t.aspectRatio.X > 1 {
+			patchesPerChunk = t.Dim(1) / (t.aspectRatio.X*t.aspectRatio.Y + 1)
+
+			for range t.aspectRatio.Y {
+				for x := range t.aspectRatio.X {
+					imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
+					imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
+					if x < t.aspectRatio.X-1 {
+						imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
+					}
+					offset += patchesPerChunk
+				}
+
+				imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
+			}
+		}
+
+		imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
+		imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
+		imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
+		imageInputs = append(imageInputs, input.Input{Token: 200081}) // <|image_end|>
+
+		result = append(result, imageInputs...)
+	}
+
+	return result, nil
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go
index 77e54e34a..c7ceceec5 100644
--- a/model/models/llama4/model_text.go
+++ b/model/models/llama4/model_text.go
@@ -195,7 +195,17 @@ func newTextModel(c fs.Config) *TextModel {
 }
 
 func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
-	hiddenStates := m.TokenEmbedding.Forward(ctx, inputs)
+	hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
+
+	for _, mi := range batch.Multimodal {
+		f32s := mi.Multimodal.(*chunk).floats()
+		img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize)
+		if err != nil {
+			panic(err)
+		}
+
+		ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
+	}
 
 	for i, layer := range m.Layers {
 		cache.SetLayer(i)
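
Note: the chunks/chunk split above defers the vision compute until a text batch actually needs the floats, then shares the one result across every chunk of the same image via sync.Once. A minimal standalone sketch of that pattern follows; the type and field names here are illustrative, not the ollama API.

package main

import (
	"fmt"
	"sync"
)

// sharedChunks stands in for the patch's chunks type: one lazily
// computed buffer shared by every window into the same image.
type sharedChunks struct {
	once    sync.Once
	data    []float32
	rowSize int              // floats per patch, like Dim(0) above
	compute func() []float32 // stand-in for NewContext/Forward/Compute
}

// window stands in for the patch's chunk type: patches [s, s+n).
type window struct {
	*sharedChunks
	s, n int
}

func (w *window) floats() []float32 {
	// sync.Once guarantees the expensive compute runs at most once,
	// no matter how many windows ask for their slice.
	w.once.Do(func() { w.data = w.compute() })
	return w.data[w.s*w.rowSize : (w.s+w.n)*w.rowSize]
}

func main() {
	sc := &sharedChunks{
		rowSize: 4,
		compute: func() []float32 {
			fmt.Println("computing once")
			return make([]float32, 8*4) // 8 patches x 4 floats each
		},
	}
	a := &window{sharedChunks: sc, s: 0, n: 4}
	b := &window{sharedChunks: sc, s: 4, n: 4}
	fmt.Println(len(a.floats()), len(b.floats())) // "computing once" prints a single time
}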
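The PostTokenize change emits one run of <|patch|> placeholders per aspect-ratio tile, separated by tile separators, followed by <|image|> and a final global chunk. A runnable sketch of the resulting token sequence; the 2x2 ratio and patch count are made-up example values, while the token IDs match the patch.

package main

import "fmt"

// Special tokens, as used in PostTokenize above.
const (
	imageStart = 200080 // <|image_start|>
	imageEnd   = 200081 // <|image_end|>
	tileXSep   = 200084 // <|tile_x_separator|>
	tileYSep   = 200085 // <|tile_y_separator|>
	imageTok   = 200090 // <|image|>
	patchTok   = 200092 // <|patch|>
)

// layout mirrors the loop structure above: per-tile runs of patch
// placeholders with x/y separators, then <|image|> and a global chunk.
func layout(ratioX, ratioY, patchesPerChunk int) []int {
	seq := []int{imageStart}
	if ratioX*ratioY > 1 {
		for y := 0; y < ratioY; y++ {
			for x := 0; x < ratioX; x++ {
				for i := 0; i < patchesPerChunk; i++ {
					seq = append(seq, patchTok)
				}
				if x < ratioX-1 {
					seq = append(seq, tileXSep)
				}
			}
			seq = append(seq, tileYSep)
		}
	}
	seq = append(seq, imageTok)
	for i := 0; i < patchesPerChunk; i++ {
		seq = append(seq, patchTok)
	}
	return append(seq, imageEnd)
}

func main() {
	// 2x2 tiling with 3 patches per chunk: 4 tiles plus 1 global chunk.
	fmt.Println(layout(2, 2, 3))
}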
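Finally, TextModel.Forward duplicates the token embeddings and copies each chunk's floats over the rows belonging to its placeholder <|patch|> tokens, starting at the carrying token's batch index (mi.Index). The same idea on plain slices, as a sketch rather than the tensor API; the sizes and values are hypothetical.

package main

import "fmt"

func main() {
	const hidden = 4
	// Token embeddings for 6 positions, flattened row-major; zeros
	// stand in for the <|patch|> placeholder embeddings.
	hiddenStates := make([]float32, 6*hidden)

	// Projected vision floats for a 2-patch chunk (made-up values).
	img := []float32{1, 1, 1, 1, 2, 2, 2, 2}

	// index plays the role of mi.Index: the batch position of the
	// token that carries the chunk; its row and the next are overwritten.
	index := 3
	copy(hiddenStates[index*hidden:], img)

	fmt.Println(hiddenStates) // rows 3 and 4 now hold the image embeddings
}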