Compare commits: main ... jessegross

3 Commits:
- 0404eae3da
- fafc332714
- 7037dc9a47
@@ -2,16 +2,30 @@ package input

+import "github.com/ollama/ollama/ml"
+
+// Multimodal is a multimodal embedding or a component of one.
+// For example, it could be a row of an image that can be processed
+// independently.
+type Multimodal struct {
+	// Tensor is the embedding data. Implementations may choose what to
+	// store here or it may be nil if not needed. However, any ml.Tensor
+	// objects must be stored here and not in Data.
+	Tensor ml.Tensor
+
+	// Data is implementation-specific opaque data, such as metadata on how
+	// to layout Tensor. It may be nil if not needed. It may also store larger
+	// objects such as complete images if they are to be processed later.
+	Data any
+}
+
 // Input represents one token in the input stream
 type Input struct {
 	// Token is a single element of text.
 	Token int32

-	// Multimodal is opaque data representing a non-text
-	// element such as an image (or part of one if the image
-	// can be processed in pieces). It may be either together
-	// with Token or on its own.
-	Multimodal any
+	// Multimodal represents a non-text element such as an
+	// image (or part of one if the image can be processed in pieces).
+	// It may be used either together with Token or on its own.
+	Multimodal []Multimodal

 	// MultimodalHash is a unique representation of the data
 	// stored in Multimodal, used for caching and comparing
@@ -32,7 +46,7 @@ type Input struct {
 // Positions slice.
 type MultimodalIndex struct {
 	Index      int
-	Multimodal any
+	Multimodal []Multimodal
 }

 // Batch contains the inputs for a model forward pass
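A minimal sketch of the new shape this hunk introduces (not part of the diff; the helper names and arguments are hypothetical stand-ins): a whole-image embedding becomes a one-element slice, while a tiled image becomes one entry per tile, each carrying optional opaque metadata in Data.

package example

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// wrapWholeImage wraps a single vision-encoder output: the whole
// embedding rides on one placeholder token as a one-element slice.
func wrapWholeImage(embedding ml.Tensor, hash uint64) input.Input {
	return input.Input{
		Multimodal:     []input.Multimodal{{Tensor: embedding}},
		MultimodalHash: hash,
	}
}

// wrapTiles wraps a tiled image: one entry per tile, each carrying
// model-specific metadata in the opaque Data field.
func wrapTiles(tiles []ml.Tensor, meta []any) []input.Multimodal {
	mm := make([]input.Multimodal, len(tiles))
	for i, t := range tiles {
		mm[i] = input.Multimodal{Tensor: t, Data: meta[i]}
	}
	return mm
}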
@@ -40,12 +40,13 @@ type MultimodalProcessor interface {
 	// EncodeMultimodal processes a single input (such as an image) and
 	// generates an output (typically an embedding) that can be used by the model.
 	//
-	// The return value is most typically an ml.Tensor, however, different
-	// type are possible, such as an object containing a tensor plus
-	// additional metadata, a slice of tensors or even just the original input.
+	// The return value is one or more tensors, each with optional model-specific
+	// opaque metadata. Typically, the tensors might be views into an embedding
+	// with each view representing a chunk of data that can be processed independently
+	// in different batches.
 	//
 	// The result may be cached by the runner.
-	EncodeMultimodal(ml.Context, []byte) (any, error)
+	EncodeMultimodal(ml.Context, []byte) ([]input.Multimodal, error)

 	// PostTokenize is called after tokenization to allow the model to edit the
 	// input stream to correctly arrange multimodal elements.
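As a rough illustration of the new contract (a sketch only; splitIntoChunks and chunkSize are invented here, and it assumes ml.Tensor's View/Stride/Dim behave as in the model hunks below), an encoder can return per-chunk views into a single embedding so each chunk can land in a different batch:

package example

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// splitIntoChunks slices one embedding into independent views, each of
// which can be scheduled into a separate batch by the runner.
func splitIntoChunks(ctx ml.Context, embedding ml.Tensor, chunkSize int) []input.Multimodal {
	var out []input.Multimodal
	for offset := 0; offset < embedding.Dim(1); offset += chunkSize {
		n := min(chunkSize, embedding.Dim(1)-offset)
		view := embedding.View(ctx, embedding.Stride(1)*offset,
			embedding.Dim(0), embedding.Stride(1), n)
		out = append(out, input.Multimodal{Tensor: view})
	}
	return out
}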
@@ -82,7 +82,7 @@ func New(c fs.Config) (model.Model, error) {
 	return &m, nil
 }

-func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
+func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
 	if len(m.VisionModel.Layers) == 0 {
 		return nil, model.ErrNoVisionModel
 	}
@@ -108,22 +108,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
 	visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
-	return visionOutputs, nil
+	return []input.Multimodal{{Tensor: visionOutputs}}, nil
 }

 func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	var result []input.Input

 	for _, inp := range inputs {
-		if inp.Multimodal == nil {
+		if len(inp.Multimodal) == 0 {
 			result = append(result, inp)
 		} else {
-			inputMultimodal := inp.Multimodal.(ml.Tensor)
+			inputMultimodal := inp.Multimodal[0].Tensor

 			result = append(result,
 				input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
 				input.Input{Token: 255999}, // "<start_of_image>"
-				input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
+				input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
 			)

 			// add image token placeholders
@@ -178,7 +178,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 	// set image embeddings
 	var except []int
 	for _, image := range batch.Multimodal {
-		visionOutputs := image.Multimodal.(ml.Tensor)
+		visionOutputs := image.Multimodal[0].Tensor
 		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))

 		for i := range visionOutputs.Dim(1) {
@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"image"
 	"slices"
-	"sync"

 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/kvcache"
@@ -60,7 +59,7 @@ func New(c fs.Config) (model.Model, error) {
 	return &m, nil
 }

-func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
+func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
 	if len(m.VisionModel.Layers) < 1 {
 		return nil, model.ErrNoVisionModel
 	}
@@ -100,69 +99,78 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
 	visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3))
 	projectedOutputs := m.Projector.Forward(ctx, visionOutputs)
-	return &chunks{Model: m, Tensor: projectedOutputs, aspectRatio: image.Point{ratioW, ratioH}}, nil
+
+	var multimodal []input.Multimodal
+	aspectRatio := image.Point{ratioW, ratioH}
+
+	var offset int
+	patchesPerChunk := projectedOutputs.Dim(1)
+	if aspectRatio.Y*aspectRatio.X > 1 {
+		patchesPerChunk = projectedOutputs.Dim(1) / (aspectRatio.X*aspectRatio.Y + 1)
+
+		for range aspectRatio.Y {
+			for x := range aspectRatio.X {
+				view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
+					projectedOutputs.Dim(0), projectedOutputs.Stride(1),
+					patchesPerChunk)
+				var separator separator
+				if x < aspectRatio.X-1 {
+					separator.x = true // <|tile_x_separator|>
+				} else {
+					separator.y = true // <|tile_y_separator|>
+				}
+				multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator})
+				offset += patchesPerChunk
+			}
+		}
+	}
+
+	view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
+		projectedOutputs.Dim(0), projectedOutputs.Stride(1),
+		patchesPerChunk)
+	multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator{}})
+
+	return multimodal, nil
 }

-type chunks struct {
-	*Model
-	ml.Tensor
-	aspectRatio image.Point
-
-	dataOnce sync.Once
-	data     []float32
-}
-
-type chunk struct {
-	*chunks
-	s, n int
-}
-
-func (r *chunk) floats() []float32 {
-	r.dataOnce.Do(func() {
-		temp := r.Backend().NewContext()
-		defer temp.Close()
-		temp.Forward(r.Tensor).Compute(r.Tensor)
-		r.data = r.Floats()
-	})
-
-	return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)]
+type separator struct {
+	x bool
+	y bool
 }

 func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	var result []input.Input
 	for _, inp := range inputs {
-		if inp.Multimodal == nil {
+		if len(inp.Multimodal) == 0 {
 			result = append(result, inp)
 			continue
 		}

-		t := inp.Multimodal.(*chunks)
 		var imageInputs []input.Input
 		imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>

-		var offset int
-		patchesPerChunk := t.Dim(1)
-		if t.aspectRatio.Y*t.aspectRatio.X > 1 {
-			patchesPerChunk = t.Dim(1) / (t.aspectRatio.X*t.aspectRatio.Y + 1)
+		for i, mm := range inp.Multimodal {
+			patchesPerChunk := mm.Tensor.Dim(1)

-			for range t.aspectRatio.Y {
-				for x := range t.aspectRatio.X {
-					imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
-					imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
-					if x < t.aspectRatio.X-1 {
-						imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
-					}
-					offset += patchesPerChunk
-				}
+			if i < len(inp.Multimodal)-1 {
+				separator := mm.Data.(*separator)

-				imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
-			}
-		}
+				imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
+				imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)

-		imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
-		imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
-		imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
-		imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
+				if separator.x {
+					imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
+				}
+
+				if separator.y {
+					imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
+				}
+			} else {
+				imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
+				imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
+				imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
+				imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
+			}
+		}

 		result = append(result, imageInputs...)
 	}
@@ -210,12 +210,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 	hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)

 	for _, mi := range batch.Multimodal {
-		f32s := mi.Multimodal.(*chunk).floats()
-		img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize)
-		if err != nil {
-			panic(err)
-		}
-
+		img := mi.Multimodal[0].Tensor
 		ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
 	}
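The hunks above establish a round trip for the opaque Data field: EncodeMultimodal tags each chunk with a *separator, and PostTokenize recovers it with a type assertion to decide which separator token to emit. A condensed sketch of that pattern (tileMeta, tagChunks, emitSeparators, rowSep and colSep are hypothetical stand-ins, not names from the diff):

package example

import "github.com/ollama/ollama/model/input"

// tileMeta plays the role of the diff's separator struct.
type tileMeta struct {
	endOfRow bool
}

// tagChunks attaches metadata to each chunk at encode time.
func tagChunks(chunks []input.Multimodal, tilesPerRow int) {
	for i := range chunks {
		chunks[i].Data = &tileMeta{endOfRow: (i+1)%tilesPerRow == 0}
	}
}

// emitSeparators recovers the metadata at post-tokenize time and emits
// a row or column separator token after each chunk.
func emitSeparators(chunks []input.Multimodal, rowSep, colSep int32) []input.Input {
	var out []input.Input
	for _, mm := range chunks {
		out = append(out, input.Input{Multimodal: []input.Multimodal{mm}})
		meta := mm.Data.(*tileMeta) // same assertion style as mm.Data.(*separator)
		if meta.endOfRow {
			out = append(out, input.Input{Token: rowSep})
		} else {
			out = append(out, input.Input{Token: colSep})
		}
	}
	return out
}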
@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"image"
 	"slices"
-	"sync"

 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/kvcache"
@@ -88,7 +87,7 @@ func newMultiModalProjector(c fs.Config) *MultiModalProjector {
 	}
 }

-func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
+func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
 	if len(m.VisionModel.Layers) == 0 {
 		return nil, model.ErrNoVisionModel
 	}
@@ -112,37 +111,14 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)

 	// split into patches to be sent to the text transformer
-	parent := imageFeatures{tensor: features}
-	rows := make([]*imageRow, size.Y)
+	rows := make([]input.Multimodal, size.Y)
 	for i := range rows {
-		rows[i] = &imageRow{parent: &parent, s: i, shape: []int{features.Dim(0), size.X}}
+		rows[i].Tensor = features.View(ctx, features.Stride(1)*size.X*i, features.Dim(0), features.Stride(1), size.X)
 	}

 	return rows, nil
 }

-type imageFeatures struct {
-	tensor ml.Tensor
-
-	dataOnce sync.Once
-	data     []float32
-}
-
-type imageRow struct {
-	parent *imageFeatures
-	s      int
-	shape  []int
-}
-
-func (r *imageRow) data() []float32 {
-	n := 1
-	for _, s := range r.shape {
-		n *= s
-	}
-
-	return r.parent.data[r.s*n : (r.s+1)*n]
-}
-
 // PostTokenize arranges Mistral 3's inputs for the forward pass
 // In Mistral 3 and Pixtral, the input patches are arranged as follows:
 // [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
@@ -151,15 +127,14 @@ func (r *imageRow) data() []float32 {
 func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	var result []input.Input
 	for _, inp := range inputs {
-		if inp.Multimodal == nil {
+		if len(inp.Multimodal) == 0 {
 			result = append(result, inp)
 		} else {
-			inputMultimodal := inp.Multimodal.([]*imageRow)
-			for i, row := range inputMultimodal {
+			for i, row := range inp.Multimodal {
 				// [IMG]
-				result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.shape[1]})
-				result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.shape[1]-1)...)
-				if i == len(inputMultimodal)-1 {
+				result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
+				result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
+				if i == len(inp.Multimodal)-1 {
 					// [IMG_END]
 					result = append(result, input.Input{Token: 13})
 				} else {
@@ -110,20 +110,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 	// image embeddings
 	for _, image := range batch.Multimodal {
-		row := image.Multimodal.(*imageRow)
-		row.parent.dataOnce.Do(func() {
-			// use a new, throwaway context so the image tensor is not added to the graph
-			temp := m.Backend().NewContext()
-			temp.Forward(row.parent.tensor).Compute(row.parent.tensor)
-			row.parent.data = row.parent.tensor.Floats()
-			temp.Close()
-		})
-
-		imageFeature, err := ctx.Input().FromFloatSlice(row.data(), row.shape...)
-		if err != nil {
-			panic(err)
-		}
-
+		imageFeature := image.Multimodal[0].Tensor
 		ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1))))
 	}
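(Note that the dataOnce/Floats bookkeeping deleted here, like its counterpart removed from the model above, is not simply dropped: the same compute-once-then-copy-to-system-memory step reappears in generalized form in the new runner/ollamarunner/multimodal.go further down.)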
@@ -63,7 +63,7 @@ func New(c fs.Config) (model.Model, error) {
 	return &m, nil
 }

-func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
+func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
 	if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
 		return nil, model.ErrNoVisionModel
 	}
@@ -95,7 +95,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32)
 	crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
-	return m.Projector.Forward(ctx, crossAttentionStates), nil
+	return []input.Multimodal{{Tensor: m.Projector.Forward(ctx, crossAttentionStates)}}, nil
 }

 func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
@@ -103,12 +103,12 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	fnvHash := fnv.New64a()

 	for i := range inputs {
-		if inputs[i].Multimodal == nil {
+		if len(inputs[i].Multimodal) == 0 {
 			if len(images) > 0 {
-				inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)}
+				inputs[i].Multimodal = images[0].Multimodal
 				inputs[i].MultimodalHash = images[0].MultimodalHash
 				for j := 1; j < len(images); j++ {
-					inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[0].Multimodal.(ml.Tensor))
+					inputs[i].Multimodal = append(inputs[i].Multimodal, images[j].Multimodal...)
 					fnvHash.Reset()
 					binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
 					binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
@@ -130,9 +130,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	var crossAttentionStates ml.Tensor
 	if len(batch.Multimodal) > 0 {
-		images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
+		images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal
 		if len(images) > 0 {
-			crossAttentionStates = images[len(images)-1]
+			crossAttentionStates = images[len(images)-1].Tensor
 		}
 	}
@@ -104,8 +104,8 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach
 	slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
 		"used", numPast, "remaining", len(prompt)-numPast)

-	slot.Inputs = prompt[:numPast]
 	prompt = prompt[numPast:]
+	slot.Inputs = slot.Inputs[:numPast]

 	return slot, prompt, nil
 }
@@ -136,8 +136,8 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
 	slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
 		"used", numPast, "remaining", int32(len(prompt))-numPast)

-	slot.Inputs = prompt[:numPast]
 	prompt = prompt[numPast:]
+	slot.Inputs = slot.Inputs[:numPast]

 	return slot, prompt, nil
 }
@@ -3,7 +3,6 @@ package ollamarunner
 import (
 	"errors"
 	"fmt"
-	"image"
 	"testing"
 	"time"
@@ -12,10 +11,6 @@ import (
 )

 func TestCountCommon(t *testing.T) {
-	imgA := image.NewRGBA(image.Rect(0, 0, 100, 100))
-	imgB := image.NewRGBA(image.Rect(0, 0, 50, 50))
-	imgC := image.NewRGBA(image.Rect(50, 50, 100, 100))
-
 	tests := []struct {
 		name string
 		t1   []input.Input
@@ -36,20 +31,20 @@ func TestCountCommon(t *testing.T) {
 		},
 		{
 			name: "Image Prefix",
-			t1:   []input.Input{{Multimodal: imgA, MultimodalHash: 1}},
-			t2:   []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
+			t1:   []input.Input{{MultimodalHash: 1}},
+			t2:   []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
 			expected: 1,
 		},
 		{
 			name: "Mixed",
-			t1:   []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
-			t2:   []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
+			t1:   []input.Input{{Token: 1}, {MultimodalHash: 1}},
+			t2:   []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
 			expected: 2,
 		},
 		{
 			name: "Mixed, Same Length",
-			t1:   []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
-			t2:   []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
+			t1:   []input.Input{{Token: 1}, {MultimodalHash: 1}},
+			t2:   []input.Input{{Token: 1}, {MultimodalHash: 2}},
 			expected: 1,
 		},
 		{
runner/ollamarunner/multimodal.go (new file, 116 lines)
@@ -0,0 +1,116 @@
+package ollamarunner
+
+import (
+	"errors"
+
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
+)
+
+// Tensors can't be used across multiple compute graphs. This is a problem
+// if a single embedding is split across batches using views since all of
+// the views will have the same source tensor. We also don't want to
+// recompute the entire embedding for each batch.
+//
+// To avoid this, we compute all of the tensors for the embedding on the
+// first use and then store the result in system memory. When we need
+// additional tensors, we recreate them from the stored data.
+
+// multimodalEntry represents the embeddings of a single object (such
+// as an image).
+type multimodalEntry struct {
+	// mm is the original set of tensors created by EncodeMultimodal
+	mm []input.Multimodal
+
+	// data is the computed result of mm. Nil if not yet computed
+	data [][]float32
+}
+
+// multimodalStore maps from an individual tensor (of which there
+// may be many in a single multimodal object) to its parent embedding
+type multimodalStore map[ml.Tensor]*multimodalEntry
+
+func newMultimodalStore() multimodalStore {
+	return make(multimodalStore)
+}
+
+// addMultimodal stores an embedding for later use in a compute graph
+func (m multimodalStore) addMultimodal(embedding []input.Multimodal) {
+	entry := &multimodalEntry{mm: embedding}
+
+	for _, e := range embedding {
+		if e.Tensor != nil {
+			m[e.Tensor] = entry
+		}
+	}
+}
+
+// getMultimodal takes a source set of tensors (which may contain a whole or
+// parts of one or more images) and returns the equivalent that can be used in
+// the current context
+func (m multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal, reserve bool) ([]input.Multimodal, error) {
+	out := make([]input.Multimodal, len(in))
+	for i := range out {
+		if in[i].Tensor != nil {
+			var err error
+			out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor, reserve)
+			if err != nil {
+				return nil, err
+			}
+		}
+
+		out[i].Data = in[i].Data
+	}
+
+	return out, nil
+}
+
+func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor, reserve bool) (ml.Tensor, error) {
+	entry := m[in]
+
+	if entry.data == nil {
+		computeCtx := backend.NewContext()
+		defer computeCtx.Close()
+
+		var tensors []ml.Tensor
+		for _, t := range entry.mm {
+			if t.Tensor != nil {
+				tensors = append(tensors, t.Tensor)
+			}
+		}
+
+		if len(tensors) == 0 {
+			return nil, nil
+		}
+
+		computeCtx.Forward(tensors...)
+		entry.data = make([][]float32, len(entry.mm))
+
+		if !reserve {
+			computeCtx.Compute(tensors...)
+
+			for i, t := range entry.mm {
+				if t.Tensor != nil {
+					entry.data[i] = t.Tensor.Floats()
+				}
+			}
+		} else {
+			err := computeCtx.Reserve()
+			if err != nil {
+				return nil, err
+			}
+		}
+	}
+
+	for i, t := range entry.mm {
+		if in == t.Tensor {
+			if !reserve {
+				return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
+			} else {
+				return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil
+			}
+		}
+	}
+
+	return nil, errors.New("multimodal tensor not found")
+}
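A brief sketch of how the runner hunks below use this store (exampleFlow is invented for illustration): tensors are registered once right after EncodeMultimodal, then rematerialized into each batch's context, either with real data (reserve=false) or as empty placeholders for graph reservation (reserve=true).

package ollamarunner

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// exampleFlow registers an embedding and then fetches batch-local
// tensors for it. On first use getMultimodal computes the source graph
// and caches the data; later calls rebuild tensors from that cache.
func exampleFlow(backend ml.Backend, batchCtx ml.Context, embedding []input.Multimodal) ([]input.Multimodal, error) {
	store := newMultimodalStore()
	store.addMultimodal(embedding)

	// reserve=false: compute (if needed) and return real tensors for this batch
	return store.getMultimodal(backend, batchCtx, embedding, false)
}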
@@ -1,12 +1,14 @@
 package ollamarunner

 import (
+	"bytes"
 	"context"
 	"encoding/json"
 	"errors"
 	"flag"
 	"fmt"
 	"hash/maphash"
+	"image"
 	"log"
 	"log/slog"
 	"net"
@@ -20,6 +22,7 @@ import (
 	"time"
 	"unicode/utf8"

+	"golang.org/x/image/bmp"
 	"golang.org/x/sync/semaphore"

 	"github.com/ollama/ollama/api"
@@ -40,6 +43,9 @@ type Sequence struct {
 	// multimodal embeddings
 	ctxs []ml.Context

+	// mmStore holds multimodal embeddings to manage memory and enable splitting across batches
+	mmStore multimodalStore
+
 	// batch index
 	iBatch int

@@ -101,7 +107,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 	startTime := time.Now()

-	inputs, ctxs, err := s.inputs(prompt, images)
+	inputs, ctxs, mmStore, err := s.inputs(prompt, images)
 	if err != nil {
 		return nil, fmt.Errorf("failed to process inputs: %w", err)
 	} else if len(inputs) == 0 {
@@ -156,6 +162,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 	return &Sequence{
 		ctxs:                ctxs,
+		mmStore:             mmStore,
 		inputs:              inputs,
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
@@ -174,9 +181,10 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, error) {
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
 	var inputs []input.Input
 	var ctxs []ml.Context
+	var mmStore multimodalStore

 	var parts []string
 	var matches [][]string
@@ -187,6 +195,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
 		re := regexp.MustCompile(`\[img-(\d+)\]`)
 		parts = re.Split(prompt, -1)
 		matches = re.FindAllStringSubmatch(prompt, -1)
+		mmStore = newMultimodalStore()
 	} else {
 		parts = []string{prompt}
 	}
@@ -196,7 +205,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
 		// text - tokenize
 		tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, nil, err
 		}

 		for _, t := range tokens {
@@ -216,7 +225,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
 			}

 			if imageIndex < 0 {
-				return nil, nil, fmt.Errorf("invalid image index: %d", n)
+				return nil, nil, nil, fmt.Errorf("invalid image index: %d", n)
 			}

 			ctx := s.model.Backend().NewContext()
@@ -224,13 +233,15 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
 			ctxs = append(ctxs, ctx)
 			imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
 			if err != nil {
-				return nil, nil, err
+				return nil, nil, nil, err
 			}

 			s.multimodalHash.Reset()
 			_, _ = s.multimodalHash.Write(images[imageIndex].Data)
 			imageHash := s.multimodalHash.Sum64()

+			mmStore.addMultimodal(imageEmbeddings)
+
 			inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
 			postTokenize = true
 		}
@@ -240,11 +251,11 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
 		var err error
 		inputs, err = multimodalProcessor.PostTokenize(inputs)
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, nil, err
 		}
 	}

-	return inputs, ctxs, nil
+	return inputs, ctxs, mmStore, nil
 }

 type Server struct {
@@ -363,6 +374,9 @@ func (s *Server) processBatch() error {
 	}
 	defer s.mu.Unlock()

+	ctx := s.model.Backend().NewContext()
+	defer ctx.Close()
+
 	var batchInputs []int32
 	var batch input.Batch

@@ -433,7 +447,11 @@ func (s *Server) processBatch() error {
 			batchInputs = append(batchInputs, inp.Token)
 			if inp.Multimodal != nil {
-				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
+				mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
+				if err != nil {
+					return err
+				}
+				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
 			}

 			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
@@ -459,9 +477,6 @@ func (s *Server) processBatch() error {
 		return nil
 	}

-	ctx := s.model.Backend().NewContext()
-	defer ctx.Close()
-
 	modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
 	if err != nil {
 		return fmt.Errorf("failed to decode batch: %w", err)
@@ -720,12 +735,71 @@ func (s *Server) reserveWorstCaseGraph() error {
 	ctx := s.model.Backend().NewContext()
 	defer ctx.Close()

+	var err error
+	inputs := make([]input.Input, s.batchSize)
+	mmStore := newMultimodalStore()
+
+	// Multimodal strategy:
+	// - Encode a 2048x2048 image. This assumes that a single image of this
+	//   size is sufficient to trigger the worst case. This is currently true
+	//   because for existing models, only a single image fits in a batch.
+	// - Add the embedding to a full batch of tokens - this is necessary because
+	//   the model may be looking for non-image data, such as <image> tags.
+	// - Run PostTokenize to execute any transformations between generated
+	//   embeddings and what the forward pass expects.
+	// - The result may now be larger than a batch (images may not fit in a
+	//   single batch), so trim based on what will fit and must be grouped together.
+	// - Fill out the rest of the space with text tokens.
+	if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); ok {
+		mmCtx := s.model.Backend().NewContext()
+		defer mmCtx.Close()
+
+		img := image.NewGray(image.Rect(0, 0, 2048, 2048))
+		var buf bytes.Buffer
+		bmp.Encode(&buf, img)
+
+		if inputs[0].Multimodal, err = multimodalProcessor.EncodeMultimodal(mmCtx, buf.Bytes()); err == nil {
+			mmStore.addMultimodal(inputs[0].Multimodal)
+
+			inputs, err = multimodalProcessor.PostTokenize(inputs)
+			if err != nil {
+				return err
+			}
+
+			for i, inp := range inputs {
+				minBatch := 1 + inp.SameBatch
+				if minBatch > s.batchSize {
+					inputs = inputs[i:min(i+minBatch, len(inputs))]
+					break
+				} else if i+minBatch > s.batchSize {
+					inputs = inputs[:i]
+					break
+				}
+			}
+
+			if len(inputs) < s.batchSize {
+				newInputs := make([]input.Input, s.batchSize)
+				copy(newInputs, inputs)
+				inputs = newInputs
+			}
+		}
+	}

 	var batch input.Batch

-	inputs := make([]int32, s.batchSize)
+	batchInputs := make([]int32, len(inputs))
 	batch.Positions = make([]int32, len(inputs))
 	batch.Sequences = make([]int, len(inputs))
-	for i := range inputs {
+	for i, inp := range inputs {
+		batchInputs[i] = inp.Token
+		if inp.Multimodal != nil {
+			mm, err := mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, true)
+			if err != nil {
+				return err
+			}
+			batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: i, Multimodal: mm})
+		}
+
 		batch.Positions[i] = int32(i)
 	}
@@ -734,8 +808,7 @@ func (s *Server) reserveWorstCaseGraph() error {
 		batch.Outputs[i] = int32(i)
 	}

-	var err error
-	batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
+	batch.Inputs, err = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
 	if err != nil {
 		return err
 	}
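The trimming loop added to reserveWorstCaseGraph is easy to misread, so here is the same logic restated in isolation (trimToBatch is a hypothetical name, not one from the diff): SameBatch-grouped inputs must stay together, so either a whole group is kept, even when it alone exceeds the batch size, or the batch is cut just before the group that would overflow it.

package ollamarunner

import "github.com/ollama/ollama/model/input"

// trimToBatch keeps whole SameBatch groups together: if a group starting
// at i needs more than batchSize slots by itself, reserve for exactly that
// group; if it merely overflows the remaining space, cut before it.
func trimToBatch(inputs []input.Input, batchSize int) []input.Input {
	for i, inp := range inputs {
		minBatch := 1 + inp.SameBatch
		if minBatch > batchSize {
			return inputs[i:min(i+minBatch, len(inputs))]
		} else if i+minBatch > batchSize {
			return inputs[:i]
		}
	}
	return inputs
}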