Compare commits

...

3 Commits

Author SHA1 Message Date
Jesse Gross
0404eae3da ollamarunner: Multi-modal worst case graph
We currently preallocate compute graph memory for the worst case
batch of text tokens. This adds support for doing the same for
images.

Note that image models are more complicated than text models in
how they process their inputs so there may be cases where this
approach isn't completely generic for all models. It covers all
currently supported models though.
2025-05-12 16:35:02 -07:00
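
The batch-trimming rule that this change adds to reserveWorstCaseGraph can be summarized in a small, self-contained sketch (simplified from the runner.go diff further down; the Input type and batch size here are stand-ins, not the runner's actual definitions):

package main

import "fmt"

// Input mirrors, for this sketch only, the runner's notion of an input:
// SameBatch is how many following inputs must land in the same batch.
type Input struct {
    Token     int32
    SameBatch int
}

// trimToBatch keeps what fits in one batch but never splits a SameBatch group.
func trimToBatch(inputs []Input, batchSize int) []Input {
    for i, inp := range inputs {
        minBatch := 1 + inp.SameBatch
        if minBatch > batchSize {
            // the group alone exceeds a batch; reserve for just that group
            return inputs[i:min(i+minBatch, len(inputs))]
        } else if i+minBatch > batchSize {
            // the group would straddle the batch boundary; cut before it
            return inputs[:i]
        }
    }
    return inputs
}

func main() {
    // one image placeholder needing 3 slots, then text tokens
    inputs := []Input{{Token: 255999, SameBatch: 2}, {}, {}, {Token: 1}, {Token: 2}}
    fmt.Println(len(trimToBatch(inputs, 4))) // 4: the last text token is cut
}
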
Jesse Gross
fafc332714 ollamarunner: Separate text and multimodal graphs
For some multimodal models (such as gemma3), we create a single
graph that generates the image embedding and then use this in the
text model. The embedding tensor is completely opaque to the runner.

However, this doesn't work if we need to use the embedding in multiple
batches. This can arise if the embedding is larger than the batch size.
In these cases (as with llama4), we would like to create views that
are more appropriately sized. However, if we do this then the original
source tensor is used in multiple graphs, which isn't allowed. To
avoid that problem, models with this pattern compute the embedding
tensor on first use and recreate the individual views. There is no
longer a single vision and text graph.

This codifies the pattern of separating vision and text graphs. The
logic of computing tensors on demand is moved to the runner, so models
no longer have to worry about this. It also gives the runner visibility
into the multimodal tensors, which is important for memory management.
2025-05-12 16:26:21 -07:00
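
The compute-on-first-use pattern this commit moves into the runner can be illustrated with a minimal, self-contained sketch; embeddingStore and its fields are illustrative names only, not the runner's real API (the actual implementation is the new multimodalStore file in the diff below):

package main

import "fmt"

// embeddingStore computes an embedding once and keeps the raw values in
// system memory; later batches rebuild fresh per-graph views from that data
// instead of reusing tensors across compute graphs.
type embeddingStore struct {
    data    [][]float32
    compute func() [][]float32
}

func (s *embeddingStore) chunk(i int) []float32 {
    if s.data == nil {
        s.data = s.compute() // heavy vision forward pass runs once, on first use
    }
    // In the real runner this data would be copied into a new tensor that
    // belongs to the current compute graph.
    return s.data[i]
}

func main() {
    calls := 0
    s := &embeddingStore{compute: func() [][]float32 {
        calls++
        return [][]float32{{1, 2}, {3, 4}}
    }}
    fmt.Println(s.chunk(0), s.chunk(1), "compute calls:", calls) // [1 2] [3 4] compute calls: 1
}
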
Jesse Gross
7037dc9a47 ollamarunner: Base cached tokens on current prompt
When we restore a sequence from the cache, we split the prompt into
the already used tokens (stored in the cache) and new tokens that
need to be processed. Currently, the references to the used tokens
are coming from the stored previous sequence.

However, even though we know that the used tokens are semantically
equivalent to the prefix of the prompt, tokens can contain pointers
which are no longer valid. As a result, it is better to get the
used tokens from the prompt, which has currently valid pointers.

This doesn't currently have any impact because it isn't possible
to reuse the pointers (which are tensors) anyways. However, it
becomes an issue once we can.
2025-05-12 16:26:21 -07:00
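
A minimal sketch of the change, simplified from the LoadCacheSlot diffs below; cachedInput is a hypothetical stand-in for the runner's input type:

package main

import "fmt"

// cachedInput stands in for the runner's input type: besides the token it may
// carry multimodal data whose tensors are only valid for current graphs.
type cachedInput struct {
    token      int32
    multimodal any
}

// loadCacheSlot sketches the fix: the reused prefix is taken from the current
// prompt (valid pointers) instead of from the previously stored sequence.
func loadCacheSlot(slotInputs, prompt []cachedInput, numPast int) (cached, remaining []cachedInput) {
    _ = slotInputs            // before: cached = slotInputs[:numPast]
    cached = prompt[:numPast] // after: reference the current prompt
    remaining = prompt[numPast:]
    return cached, remaining
}

func main() {
    stored := []cachedInput{{token: 1}, {token: 2}}
    prompt := []cachedInput{{token: 1}, {token: 2}, {token: 3}}
    cached, remaining := loadCacheSlot(stored, prompt, 2)
    fmt.Println(len(cached), len(remaining)) // 2 1
}
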
14 changed files with 317 additions and 153 deletions

View File

@@ -2,16 +2,30 @@ package input
import "github.com/ollama/ollama/ml"

+// Multimodal is a multimodal embedding or a component of one.
+// For example, it could be a row of an image that can be processed
+// independently.
+type Multimodal struct {
+    // Tensor is the embedding data. Implementations may chose what to
+    // store here or it may be nil if not needed. However, any ml.Tensor
+    // objects must be stored here and not in Data.
+    Tensor ml.Tensor
+
+    // Data is implementation-specific opaque data, such as metadata on how
+    // to layout Tensor. It may be nil if not needed. It may also store larger
+    // objects such as complete images if they are to be processed later.
+    Data any
+}
+
// Input represents one token in the input stream
type Input struct {
    // Token is a single element of text.
    Token int32

-    // Multimodal is opaque data representing a non-text
-    // element such as an image (or part of one if the image
-    // can be processed in pieces). It may be either together
-    // with Token or on its own.
-    Multimodal any
+    // Multimodal is represents a non-text element such as an
+    // image (or part of one if the image can be processed in pieces).
+    // It may be used either together with Token or on its own.
+    Multimodal []Multimodal

    // MultimodalHash is a unique representation of the data
    // stored in Multimodal, used for caching and comparing
@@ -32,7 +46,7 @@ type Input struct {
// Positions slice.
type MultimodalIndex struct {
    Index      int
-    Multimodal any
+    Multimodal []Multimodal
}

// Batch contains the inputs for a model forward pass

View File

@@ -40,12 +40,13 @@ type MultimodalProcessor interface {
    // EncodeMultimodal processes a single input (such as an image) and
    // generates an output (typically an embedding) that can be used by the model.
    //
-    // The return value is most typically an ml.Tensor, however, different
-    // type are possible, such as an object containing a tensor plus
-    // additional metadata, a slice of tensors or even just the original input.
+    // The return value is one or more tensors, each with optional model-specific
+    // opaque metadata. Typically, the tensors might be views into an embedding
+    // with each view representing a chunk of data that can be processed independently
+    // in different batches.
    //
    // The result may be cached by the runner.
-    EncodeMultimodal(ml.Context, []byte) (any, error)
+    EncodeMultimodal(ml.Context, []byte) ([]input.Multimodal, error)

    // PostTokenize is called after tokenization to allow the model to edit the
    // input stream to correctly arrange multimodal elements.

View File

@@ -82,7 +82,7 @@ func New(c fs.Config) (model.Model, error) {
    return &m, nil
}

-func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
+func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
    if len(m.VisionModel.Layers) == 0 {
        return nil, model.ErrNoVisionModel
    }
@@ -108,22 +108,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
    visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
    visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
-    return visionOutputs, nil
+    return []input.Multimodal{{Tensor: visionOutputs}}, nil
}

func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
    var result []input.Input

    for _, inp := range inputs {
-        if inp.Multimodal == nil {
+        if len(inp.Multimodal) == 0 {
            result = append(result, inp)
        } else {
-            inputMultimodal := inp.Multimodal.(ml.Tensor)
+            inputMultimodal := inp.Multimodal[0].Tensor

            result = append(result,
                input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
                input.Input{Token: 255999}, // "<start_of_image>"
-                input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
+                input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
            )

            // add image token placeholders

View File

@@ -178,7 +178,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
    // set image embeddings
    var except []int
    for _, image := range batch.Multimodal {
-        visionOutputs := image.Multimodal.(ml.Tensor)
+        visionOutputs := image.Multimodal[0].Tensor
        ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))

        for i := range visionOutputs.Dim(1) {

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"image" "image"
"slices" "slices"
"sync"
"github.com/ollama/ollama/fs" "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
@ -60,7 +59,7 @@ func New(c fs.Config) (model.Model, error) {
return &m, nil return &m, nil
} }
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if len(m.VisionModel.Layers) < 1 { if len(m.VisionModel.Layers) < 1 {
return nil, model.ErrNoVisionModel return nil, model.ErrNoVisionModel
} }
@@ -100,70 +99,79 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
    visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
    visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3))
    projectedOutputs := m.Projector.Forward(ctx, visionOutputs)
-    return &chunks{Model: m, Tensor: projectedOutputs, aspectRatio: image.Point{ratioW, ratioH}}, nil
+    var multimodal []input.Multimodal
+
+    aspectRatio := image.Point{ratioW, ratioH}
+
+    var offset int
+    patchesPerChunk := projectedOutputs.Dim(1)
+    if aspectRatio.Y*aspectRatio.X > 1 {
+        patchesPerChunk = projectedOutputs.Dim(1) / (aspectRatio.X*aspectRatio.Y + 1)
+
+        for range aspectRatio.Y {
+            for x := range aspectRatio.X {
+                view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
+                    projectedOutputs.Dim(0), projectedOutputs.Stride(1),
+                    patchesPerChunk)
+
+                var separator separator
+                if x < aspectRatio.X-1 {
+                    separator.x = true // <|tile_x_separator|>
+                } else {
+                    separator.y = true // <|tile_y_separator|>
+                }
+
+                multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator})
+                offset += patchesPerChunk
+            }
+        }
+    }
+
+    view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
+        projectedOutputs.Dim(0), projectedOutputs.Stride(1),
+        patchesPerChunk)
+    multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator{}})
+
+    return multimodal, nil
}

-type chunks struct {
-    *Model
-    ml.Tensor
-    aspectRatio image.Point
-
-    dataOnce sync.Once
-    data     []float32
-}
-
-type chunk struct {
-    *chunks
-    s, n int
-}
-
-func (r *chunk) floats() []float32 {
-    r.dataOnce.Do(func() {
-        temp := r.Backend().NewContext()
-        defer temp.Close()
-        temp.Forward(r.Tensor).Compute(r.Tensor)
-        r.data = r.Floats()
-    })
-
-    return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)]
-}
+type separator struct {
+    x bool
+    y bool
+}

func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
    var result []input.Input

    for _, inp := range inputs {
-        if inp.Multimodal == nil {
+        if len(inp.Multimodal) == 0 {
            result = append(result, inp)
            continue
        }

-        t := inp.Multimodal.(*chunks)
-
        var imageInputs []input.Input
        imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>

-        var offset int
-        patchesPerChunk := t.Dim(1)
-        if t.aspectRatio.Y*t.aspectRatio.X > 1 {
-            patchesPerChunk = t.Dim(1) / (t.aspectRatio.X*t.aspectRatio.Y + 1)
-
-            for range t.aspectRatio.Y {
-                for x := range t.aspectRatio.X {
-                    imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
-                    imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
-                    if x < t.aspectRatio.X-1 {
-                        imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
-                    }
-                    offset += patchesPerChunk
-                }
-                imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
-            }
-        }
-
-        imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
-        imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
-        imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
-        imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
+        for i, mm := range inp.Multimodal {
+            patchesPerChunk := mm.Tensor.Dim(1)
+
+            if i < len(inp.Multimodal)-1 {
+                separator := mm.Data.(*separator)
+
+                imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
+                imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
+
+                if separator.x {
+                    imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
+                }
+                if separator.y {
+                    imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
+                }
+            } else {
+                imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
+                imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
+                imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
+                imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
+            }
+        }

        result = append(result, imageInputs...)
    }

View File

@@ -210,12 +210,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
    hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)

    for _, mi := range batch.Multimodal {
-        f32s := mi.Multimodal.(*chunk).floats()
-
-        img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize)
-        if err != nil {
-            panic(err)
-        }
-
+        img := mi.Multimodal[0].Tensor
        ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
    }

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"image" "image"
"slices" "slices"
"sync"
"github.com/ollama/ollama/fs" "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
@ -88,7 +87,7 @@ func newMultiModalProjector(c fs.Config) *MultiModalProjector {
} }
} }
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if len(m.VisionModel.Layers) == 0 { if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel return nil, model.ErrNoVisionModel
} }
@@ -112,37 +111,14 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
    features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)

    // split into patches to be sent to the text transformer
-    parent := imageFeatures{tensor: features}
-    rows := make([]*imageRow, size.Y)
+    rows := make([]input.Multimodal, size.Y)
    for i := range rows {
-        rows[i] = &imageRow{parent: &parent, s: i, shape: []int{features.Dim(0), size.X}}
+        rows[i].Tensor = features.View(ctx, features.Stride(1)*size.X*i, features.Dim(0), features.Stride(1), size.X)
    }

    return rows, nil
}

-type imageFeatures struct {
-    tensor ml.Tensor
-
-    dataOnce sync.Once
-    data     []float32
-}
-
-type imageRow struct {
-    parent *imageFeatures
-    s      int
-    shape  []int
-}
-
-func (r *imageRow) data() []float32 {
-    n := 1
-
-    for _, s := range r.shape {
-        n *= s
-    }
-
-    return r.parent.data[r.s*n : (r.s+1)*n]
-}
-
// PostTokenize arranges Mistral 3's inputs for the forward pass
// In Mistral 3 and Pixtral, the input patches are arranged as follows:
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
@@ -151,15 +127,14 @@ func (r *imageRow) data() []float32 {
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
    var result []input.Input
    for _, inp := range inputs {
-        if inp.Multimodal == nil {
+        if len(inp.Multimodal) == 0 {
            result = append(result, inp)
        } else {
-            inputMultimodal := inp.Multimodal.([]*imageRow)
-            for i, row := range inputMultimodal {
+            for i, row := range inp.Multimodal {
                // [IMG]
-                result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.shape[1]})
-                result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.shape[1]-1)...)
-                if i == len(inputMultimodal)-1 {
+                result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
+                result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
+                if i == len(inp.Multimodal)-1 {
                    // [IMG_END]
                    result = append(result, input.Input{Token: 13})
                } else {

View File

@@ -110,20 +110,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
    // image embeddings
    for _, image := range batch.Multimodal {
-        row := image.Multimodal.(*imageRow)
-        row.parent.dataOnce.Do(func() {
-            // use a new, throwaway context so the image tensor is not added to the graph
-            temp := m.Backend().NewContext()
-            temp.Forward(row.parent.tensor).Compute(row.parent.tensor)
-            row.parent.data = row.parent.tensor.Floats()
-            temp.Close()
-        })
-
-        imageFeature, err := ctx.Input().FromFloatSlice(row.data(), row.shape...)
-        if err != nil {
-            panic(err)
-        }
-
+        imageFeature := image.Multimodal[0].Tensor
        ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1))))
    }

View File

@@ -63,7 +63,7 @@ func New(c fs.Config) (model.Model, error) {
    return &m, nil
}

-func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
+func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
    if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
        return nil, model.ErrNoVisionModel
    }
@@ -95,7 +95,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
    positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32)
    crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
-    return m.Projector.Forward(ctx, crossAttentionStates), nil
+    return []input.Multimodal{{Tensor: m.Projector.Forward(ctx, crossAttentionStates)}}, nil
}

func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
@@ -103,12 +103,12 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
    fnvHash := fnv.New64a()

    for i := range inputs {
-        if inputs[i].Multimodal == nil {
+        if len(inputs[i].Multimodal) == 0 {
            if len(images) > 0 {
-                inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)}
+                inputs[i].Multimodal = images[0].Multimodal
                inputs[i].MultimodalHash = images[0].MultimodalHash
                for j := 1; j < len(images); j++ {
-                    inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[0].Multimodal.(ml.Tensor))
+                    inputs[i].Multimodal = append(inputs[i].Multimodal, images[j].Multimodal...)
                    fnvHash.Reset()
                    binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
                    binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
@@ -130,9 +130,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
    var crossAttentionStates ml.Tensor
    if len(batch.Multimodal) > 0 {
-        images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
+        images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal
        if len(images) > 0 {
-            crossAttentionStates = images[len(images)-1]
+            crossAttentionStates = images[len(images)-1].Tensor
        }
    }

View File

@@ -104,8 +104,8 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach
    slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
        "used", numPast, "remaining", len(prompt)-numPast)

+    slot.Inputs = prompt[:numPast]
    prompt = prompt[numPast:]
-    slot.Inputs = slot.Inputs[:numPast]

    return slot, prompt, nil
}

View File

@@ -136,8 +136,8 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
    slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
        "used", numPast, "remaining", int32(len(prompt))-numPast)

+    slot.Inputs = prompt[:numPast]
    prompt = prompt[numPast:]
-    slot.Inputs = slot.Inputs[:numPast]

    return slot, prompt, nil
}

View File

@@ -3,7 +3,6 @@ package ollamarunner
import (
    "errors"
    "fmt"
-    "image"
    "testing"
    "time"
@@ -12,10 +11,6 @@ import (
)

func TestCountCommon(t *testing.T) {
-    imgA := image.NewRGBA(image.Rect(0, 0, 100, 100))
-    imgB := image.NewRGBA(image.Rect(0, 0, 50, 50))
-    imgC := image.NewRGBA(image.Rect(50, 50, 100, 100))
-
    tests := []struct {
        name string
        t1   []input.Input
@@ -36,20 +31,20 @@ func TestCountCommon(t *testing.T) {
        },
        {
            name:     "Image Prefix",
-            t1:       []input.Input{{Multimodal: imgA, MultimodalHash: 1}},
-            t2:       []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
+            t1:       []input.Input{{MultimodalHash: 1}},
+            t2:       []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
            expected: 1,
        },
        {
            name:     "Mixed",
-            t1:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
-            t2:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
+            t1:       []input.Input{{Token: 1}, {MultimodalHash: 1}},
+            t2:       []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
            expected: 2,
        },
        {
            name:     "Mixed, Same Length",
-            t1:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
-            t2:       []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
+            t1:       []input.Input{{Token: 1}, {MultimodalHash: 1}},
+            t2:       []input.Input{{Token: 1}, {MultimodalHash: 2}},
            expected: 1,
        },
        {

View File

@@ -0,0 +1,116 @@
package ollamarunner

import (
    "errors"

    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/model/input"
)

// Tensors can't be used across multiple compute graphs. This is a problem
// if a single embedding is split across batches using views since all of
// the views will have the same source tensor. We also don't want to
// recompute the entire embedding for each batch.
//
// To avoid this, we compute all of the tensors for the embedding on the
// first use and then store the result in system memory. When we need
// additional tensors, we recreate them from the stored data.

// multimodalEntry represents the embeddings of a single object (such
// as an image).
type multimodalEntry struct {
    // mm is the original set of tensors created by EncodeMultimodal
    mm []input.Multimodal

    // data is the computed result of mm. Nil if not yet computed
    data [][]float32
}

// multimodalStore maps from an individual tensor (of which there
// may be many in a single multimodal object) to its parent embedding
type multimodalStore map[ml.Tensor]*multimodalEntry

func newMultimodalStore() multimodalStore {
    return make(multimodalStore)
}

// addMultimodal stores an embedding for later use in a compute graph
func (m multimodalStore) addMultimodal(embedding []input.Multimodal) {
    entry := &multimodalEntry{mm: embedding}

    for _, e := range embedding {
        if e.Tensor != nil {
            m[e.Tensor] = entry
        }
    }
}

// getMultimodal takes a source set of tensors (which may contain a whole or
// parts of one or more images) and returns the equivalent that can be used in
// the current context
func (m multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal, reserve bool) ([]input.Multimodal, error) {
    out := make([]input.Multimodal, len(in))
    for i := range out {
        if in[i].Tensor != nil {
            var err error
            out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor, reserve)
            if err != nil {
                return nil, err
            }
        }

        out[i].Data = in[i].Data
    }

    return out, nil
}

func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor, reserve bool) (ml.Tensor, error) {
    entry := m[in]

    if entry.data == nil {
        computeCtx := backend.NewContext()
        defer computeCtx.Close()

        var tensors []ml.Tensor
        for _, t := range entry.mm {
            if t.Tensor != nil {
                tensors = append(tensors, t.Tensor)
            }
        }

        if len(tensors) == 0 {
            return nil, nil
        }

        computeCtx.Forward(tensors...)
        entry.data = make([][]float32, len(entry.mm))

        if !reserve {
            computeCtx.Compute(tensors...)

            for i, t := range entry.mm {
                if t.Tensor != nil {
                    entry.data[i] = t.Tensor.Floats()
                }
            }
        } else {
            err := computeCtx.Reserve()
            if err != nil {
                return nil, err
            }
        }
    }

    for i, t := range entry.mm {
        if in == t.Tensor {
            if !reserve {
                return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
            } else {
                return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil
            }
        }
    }

    return nil, errors.New("multimodal tensor not found")
}

View File

@@ -1,12 +1,14 @@
package ollamarunner

import (
+    "bytes"
    "context"
    "encoding/json"
    "errors"
    "flag"
    "fmt"
    "hash/maphash"
+    "image"
    "log"
    "log/slog"
    "net"
@@ -20,6 +22,7 @@ import (
    "time"
    "unicode/utf8"

+    "golang.org/x/image/bmp"
    "golang.org/x/sync/semaphore"

    "github.com/ollama/ollama/api"
@@ -40,6 +43,9 @@ type Sequence struct {
    // multimodal embeddings
    ctxs []ml.Context

+    // mmStore holds multimodal embeddings to mange memory and enable splitting across batches
+    mmStore multimodalStore
+
    // batch index
    iBatch int
@@ -101,7 +107,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
    startTime := time.Now()

-    inputs, ctxs, err := s.inputs(prompt, images)
+    inputs, ctxs, mmStore, err := s.inputs(prompt, images)
    if err != nil {
        return nil, fmt.Errorf("failed to process inputs: %w", err)
    } else if len(inputs) == 0 {
@@ -156,6 +162,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
    return &Sequence{
        ctxs:                ctxs,
+        mmStore:             mmStore,
        inputs:              inputs,
        numPromptInputs:     len(inputs),
        startProcessingTime: startTime,
@@ -174,9 +181,10 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
-func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, error) {
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
    var inputs []input.Input
    var ctxs []ml.Context
+    var mmStore multimodalStore

    var parts []string
    var matches [][]string
@@ -187,6 +195,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
        re := regexp.MustCompile(`\[img-(\d+)\]`)
        parts = re.Split(prompt, -1)
        matches = re.FindAllStringSubmatch(prompt, -1)
+        mmStore = newMultimodalStore()
    } else {
        parts = []string{prompt}
    }
@@ -196,7 +205,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
        // text - tokenize
        tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
        if err != nil {
-            return nil, nil, err
+            return nil, nil, nil, err
        }

        for _, t := range tokens {
@@ -216,7 +225,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
            }

            if imageIndex < 0 {
-                return nil, nil, fmt.Errorf("invalid image index: %d", n)
+                return nil, nil, nil, fmt.Errorf("invalid image index: %d", n)
            }

            ctx := s.model.Backend().NewContext()
@@ -224,13 +233,15 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
            ctxs = append(ctxs, ctx)
            imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
            if err != nil {
-                return nil, nil, err
+                return nil, nil, nil, err
            }

            s.multimodalHash.Reset()
            _, _ = s.multimodalHash.Write(images[imageIndex].Data)
            imageHash := s.multimodalHash.Sum64()

+            mmStore.addMultimodal(imageEmbeddings)
+
            inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
            postTokenize = true
        }
@@ -240,11 +251,11 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
        var err error
        inputs, err = multimodalProcessor.PostTokenize(inputs)
        if err != nil {
-            return nil, nil, err
+            return nil, nil, nil, err
        }
    }

-    return inputs, ctxs, nil
+    return inputs, ctxs, mmStore, nil
}

type Server struct {
@@ -363,6 +374,9 @@ func (s *Server) processBatch() error {
    }
    defer s.mu.Unlock()

+    ctx := s.model.Backend().NewContext()
+    defer ctx.Close()
+
    var batchInputs []int32
    var batch input.Batch
@@ -433,7 +447,11 @@ func (s *Server) processBatch() error {
            batchInputs = append(batchInputs, inp.Token)
            if inp.Multimodal != nil {
-                batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
+                mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
+                if err != nil {
+                    return err
+                }
+                batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
            }

            batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
@@ -459,9 +477,6 @@ func (s *Server) processBatch() error {
        return nil
    }

-    ctx := s.model.Backend().NewContext()
-    defer ctx.Close()
-
    modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
    if err != nil {
        return fmt.Errorf("failed to decode batch: %w", err)
@@ -720,12 +735,71 @@ func (s *Server) reserveWorstCaseGraph() error {
    ctx := s.model.Backend().NewContext()
    defer ctx.Close()

+    var err error
+    inputs := make([]input.Input, s.batchSize)
+    mmStore := newMultimodalStore()
+
+    // Multimodal strategy:
+    // - Encode a 2048x2048 image. This assumes that a single image of this
+    //   size is sufficient to trigger the worst case. This is currently true
+    //   because for existing models, only a single image fits in a batch.
+    // - Add the embedding to a full batch of tokens - this is necessary because
+    //   the model may be looking for non-image data, such as <image> tags.
+    // - Run PostTokenize to execute any transformations between generated
+    //   embeddings and what the forward pass expects.
+    // - The result may now be larger than a batch (images may not fit in a
+    //   single batch), so trim based on what will fit and must be grouped together.
+    // - Fill out the rest of the space with text tokens.
+    if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); ok {
+        mmCtx := s.model.Backend().NewContext()
+        defer mmCtx.Close()
+
+        img := image.NewGray(image.Rect(0, 0, 2048, 2048))
+        var buf bytes.Buffer
+        bmp.Encode(&buf, img)
+
+        if inputs[0].Multimodal, err = multimodalProcessor.EncodeMultimodal(mmCtx, buf.Bytes()); err == nil {
+            mmStore.addMultimodal(inputs[0].Multimodal)
+
+            inputs, err = multimodalProcessor.PostTokenize(inputs)
+            if err != nil {
+                return err
+            }
+
+            for i, inp := range inputs {
+                minBatch := 1 + inp.SameBatch
+                if minBatch > s.batchSize {
+                    inputs = inputs[i:min(i+minBatch, len(inputs))]
+                    break
+                } else if i+minBatch > s.batchSize {
+                    inputs = inputs[:i]
+                    break
+                }
+            }
+
+            if len(inputs) < s.batchSize {
+                newInputs := make([]input.Input, s.batchSize)
+                copy(newInputs, inputs)
+                inputs = newInputs
+            }
+        }
+    }
+
    var batch input.Batch

-    inputs := make([]int32, s.batchSize)
+    batchInputs := make([]int32, len(inputs))
    batch.Positions = make([]int32, len(inputs))
    batch.Sequences = make([]int, len(inputs))
-    for i := range inputs {
+    for i, inp := range inputs {
+        batchInputs[i] = inp.Token
+        if inp.Multimodal != nil {
+            mm, err := mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, true)
+            if err != nil {
+                return err
+            }
+
+            batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: i, Multimodal: mm})
+        }
+
        batch.Positions[i] = int32(i)
    }
@@ -734,8 +808,7 @@ func (s *Server) reserveWorstCaseGraph() error {
        batch.Outputs[i] = int32(i)
    }

-    var err error
-    batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
+    batch.Inputs, err = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
    if err != nil {
        return err
    }