diff --git a/model/input/input.go b/model/input/input.go
index d66f52a0d..bd9b53ec6 100644
--- a/model/input/input.go
+++ b/model/input/input.go
@@ -2,16 +2,30 @@ package input
 
 import "github.com/ollama/ollama/ml"
 
+// Multimodal is a multimodal embedding or a component of one.
+// For example, it could be a row of an image that can be processed
+// independently.
+type Multimodal struct {
+	// Tensor is the embedding data. Implementations may choose what to
+	// store here or it may be nil if not needed. However, any ml.Tensor
+	// objects must be stored here and not in Data.
+	Tensor ml.Tensor
+
+	// Data is implementation-specific opaque data, such as metadata on how
+	// to layout Tensor. It may be nil if not needed. It may also store larger
+	// objects such as complete images if they are to be processed later.
+	Data any
+}
+
 // Input represents one token in the input stream
 type Input struct {
 	// Token is a single element of text.
 	Token int32
 
-	// Multimodal is opaque data representing a non-text
-	// element such as an image (or part of one if the image
-	// can be processed in pieces). It may be either together
-	// with Token or on its own.
-	Multimodal any
+	// Multimodal represents a non-text element such as an
+	// image (or part of one if the image can be processed in pieces).
+	// It may be used either together with Token or on its own.
+	Multimodal []Multimodal
 
 	// MultimodalHash is a unique representation of the data
 	// stored in Multimodal, used for caching and comparing
@@ -32,7 +46,7 @@ type Input struct {
 // Positions slice.
 type MultimodalIndex struct {
 	Index      int
-	Multimodal any
+	Multimodal []Multimodal
 }
 
 // Batch contains the inputs for a model forward pass
diff --git a/model/model.go b/model/model.go
index 7883b8517..98381c904 100644
--- a/model/model.go
+++ b/model/model.go
@@ -40,12 +40,13 @@ type MultimodalProcessor interface {
 	// EncodeMultimodal processes a single input (such as an image) and
 	// generates an output (typically an embedding) that can be used by the model.
 	//
-	// The return value is most typically an ml.Tensor, however, different
-	// types are possible, such as an object containing a tensor plus
-	// additional metadata, a slice of tensors or even just the original input.
+	// The return value is one or more tensors, each with optional model-specific
+	// opaque metadata. Typically, the tensors might be views into an embedding
+	// with each view representing a chunk of data that can be processed independently
+	// in different batches.
 	//
 	// The result may be cached by the runner.
-	EncodeMultimodal(ml.Context, []byte) (any, error)
+	EncodeMultimodal(ml.Context, []byte) ([]input.Multimodal, error)
 
 	// PostTokenize is called after tokenization to allow the model to edit the
 	// input stream to correctly arrange multimodal elements.
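To make the new contract concrete: a minimal single-tensor model now wraps its embedding in a one-element slice. The sketch below is illustrative only and is not part of this diff; `toyModel` and its `encodeImage` hook are hypothetical stand-ins for a real vision tower.

```go
package toy

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// toyModel is a hypothetical model used only to illustrate the updated
// MultimodalProcessor contract.
type toyModel struct {
	encodeImage func(ml.Context, []byte) (ml.Tensor, error) // stand-in for a vision model
}

func (m *toyModel) EncodeMultimodal(ctx ml.Context, data []byte) ([]input.Multimodal, error) {
	embedding, err := m.encodeImage(ctx, data)
	if err != nil {
		return nil, err
	}
	// Any ml.Tensor must be stored in Tensor, never in Data, so the
	// runner can find, compute, and cache it.
	return []input.Multimodal{{Tensor: embedding}}, nil
}
```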
diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go
index bf396b6a0..d53eb6ccc 100644
--- a/model/models/gemma3/model.go
+++ b/model/models/gemma3/model.go
@@ -82,7 +82,7 @@ func New(c fs.Config) (model.Model, error) {
 	return &m, nil
 }
 
-func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
+func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
 	if len(m.VisionModel.Layers) == 0 {
 		return nil, model.ErrNoVisionModel
 	}
@@ -108,22 +108,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
 	visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
-	return visionOutputs, nil
+	return []input.Multimodal{{Tensor: visionOutputs}}, nil
 }
 
 func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	var result []input.Input
 
 	for _, inp := range inputs {
-		if inp.Multimodal == nil {
+		if len(inp.Multimodal) == 0 {
 			result = append(result, inp)
 		} else {
-			inputMultimodal := inp.Multimodal.(ml.Tensor)
+			inputMultimodal := inp.Multimodal[0].Tensor
 
 			result = append(result,
-				input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3},               // "\n\n"
-				input.Input{Token: 255999},                                                   // "<start_of_image>"
-				input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
+				input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
+				input.Input{Token: 255999}, // "<start_of_image>"
+				input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
 			)
 
 			// add image token placeholders
diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index c1e843d8f..7200089fd 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -178,7 +178,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 	// set image embeddings
 	var except []int
 	for _, image := range batch.Multimodal {
-		visionOutputs := image.Multimodal.(ml.Tensor)
+		visionOutputs := image.Multimodal[0].Tensor
 		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
 
 		for i := range visionOutputs.Dim(1) {
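The scatter in gemma3's Forward depends on the layout PostTokenize sets up: the embedding occupies Dim(1) consecutive placeholder positions starting at the index recorded in the batch. A minimal restatement of that copy, as a sketch with an invented helper name (`placeImageEmbedding` is not part of this diff):

```go
package sketch

import "github.com/ollama/ollama/ml"

// placeImageEmbedding restates the scatter in gemma3's Forward: a
// [hiddenDim, numTokens] embedding overwrites the hidden-state rows of its
// placeholder tokens, starting at the index recorded in the batch.
func placeImageEmbedding(ctx ml.Context, hiddenState, embedding ml.Tensor, index int) {
	dst := hiddenState.View(ctx, index*hiddenState.Stride(1), // start of the first placeholder's row
		embedding.Dim(0)*embedding.Dim(1)) // hiddenDim * numTokens elements
	ctx.Forward(embedding.Copy(ctx, dst))
}
```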
diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go
index 632d313ec..37646ee80 100644
--- a/model/models/llama4/model.go
+++ b/model/models/llama4/model.go
@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"image"
 	"slices"
-	"sync"
 
 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/kvcache"
@@ -60,7 +59,7 @@ func New(c fs.Config) (model.Model, error) {
 	return &m, nil
 }
 
-func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
+func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
 	if len(m.VisionModel.Layers) < 1 {
 		return nil, model.ErrNoVisionModel
 	}
@@ -100,70 +99,79 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
 	visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3))
 	projectedOutputs := m.Projector.Forward(ctx, visionOutputs)
-	return &chunks{Model: m, Tensor: projectedOutputs, aspectRatio: image.Point{ratioW, ratioH}}, nil
+
+	var multimodal []input.Multimodal
+	aspectRatio := image.Point{ratioW, ratioH}
+
+	var offset int
+	patchesPerChunk := projectedOutputs.Dim(1)
+	if aspectRatio.Y*aspectRatio.X > 1 {
+		patchesPerChunk = projectedOutputs.Dim(1) / (aspectRatio.X*aspectRatio.Y + 1)
+
+		for range aspectRatio.Y {
+			for x := range aspectRatio.X {
+				view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
+					projectedOutputs.Dim(0), projectedOutputs.Stride(1),
+					patchesPerChunk)
+				var separator separator
+				if x < aspectRatio.X-1 {
+					separator.x = true // <|tile_x_separator|>
+				} else {
+					separator.y = true // <|tile_y_separator|>
+				}
+				multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator})
+				offset += patchesPerChunk
+			}
+		}
+	}
+
+	view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
+		projectedOutputs.Dim(0), projectedOutputs.Stride(1),
+		patchesPerChunk)
+	multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator{}})
+
+	return multimodal, nil
 }
 
-type chunks struct {
-	*Model
-	ml.Tensor
-	aspectRatio image.Point
-
-	dataOnce sync.Once
-	data     []float32
-}
-
-type chunk struct {
-	*chunks
-	s, n int
-}
-
-func (r *chunk) floats() []float32 {
-	r.dataOnce.Do(func() {
-		temp := r.Backend().NewContext()
-		defer temp.Close()
-		temp.Forward(r.Tensor).Compute(r.Tensor)
-		r.data = r.Floats()
-	})
-
-	return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)]
+type separator struct {
+	x bool
+	y bool
 }
 
 func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	var result []input.Input
 
 	for _, inp := range inputs {
-		if inp.Multimodal == nil {
+		if len(inp.Multimodal) == 0 {
 			result = append(result, inp)
 			continue
 		}
 
-		t := inp.Multimodal.(*chunks)
 		var imageInputs []input.Input
 		imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>
 
-		var offset int
-		patchesPerChunk := t.Dim(1)
-		if t.aspectRatio.Y*t.aspectRatio.X > 1 {
-			patchesPerChunk = t.Dim(1) / (t.aspectRatio.X*t.aspectRatio.Y + 1)
+		for i, mm := range inp.Multimodal {
+			patchesPerChunk := mm.Tensor.Dim(1)
 
-			for range t.aspectRatio.Y {
-				for x := range t.aspectRatio.X {
-					imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
-					imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
-					if x < t.aspectRatio.X-1 {
-						imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
-					}
-					offset += patchesPerChunk
+			if i < len(inp.Multimodal)-1 {
+				separator := mm.Data.(*separator)
+
+				imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
+				imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
+
+				if separator.x {
+					imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
 				}
-
-				imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
+				if separator.y {
+					imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
+				}
+			} else {
+				imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
+				imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
+				imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
+				imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
 			}
 		}
 
-		imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
-		imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
-		imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
-		imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
-
 		result = append(result, imageInputs...)
 	}
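The EncodeMultimodal rewrite above splits one projected tensor into per-tile views instead of copying. A simplified sketch of the same view arithmetic (`splitIntoChunks` is an invented name, not part of this diff), assuming a 2D tensor whose columns are laid out Stride(1) apart:

```go
package sketch

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// splitIntoChunks carves a [dim0, nChunks*patchesPerChunk] tensor into
// contiguous column-range views. The views share storage with src, which
// is why the runner must compute them together in a single graph (see
// multimodal.go later in this diff).
func splitIntoChunks(ctx ml.Context, src ml.Tensor, patchesPerChunk int) []input.Multimodal {
	var out []input.Multimodal
	for offset := 0; offset+patchesPerChunk <= src.Dim(1); offset += patchesPerChunk {
		view := src.View(ctx, src.Stride(1)*offset, // start of the chunk's first column, in Stride units
			src.Dim(0), src.Stride(1), // rows and stride between columns
			patchesPerChunk) // columns in this chunk
		out = append(out, input.Multimodal{Tensor: view})
	}
	return out
}
```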
diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go
index 3f9f578f1..d98587bd0 100644
--- a/model/models/llama4/model_text.go
+++ b/model/models/llama4/model_text.go
@@ -210,12 +210,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 	hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
 
 	for _, mi := range batch.Multimodal {
-		f32s := mi.Multimodal.(*chunk).floats()
-		img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize)
-		if err != nil {
-			panic(err)
-		}
-
+		img := mi.Multimodal[0].Tensor
 		ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
 	}
 
diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go
index f749fdcd2..5fc99fb22 100644
--- a/model/models/mistral3/model.go
+++ b/model/models/mistral3/model.go
@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"image"
 	"slices"
-	"sync"
 
 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/kvcache"
@@ -88,7 +87,7 @@ func newMultiModalProjector(c fs.Config) *MultiModalProjector {
 	}
 }
 
-func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
+func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
 	if len(m.VisionModel.Layers) == 0 {
 		return nil, model.ErrNoVisionModel
 	}
@@ -112,37 +111,14 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)
 
 	// split into patches to be sent to the text transformer
-	parent := imageFeatures{tensor: features}
-	rows := make([]*imageRow, size.Y)
+	rows := make([]input.Multimodal, size.Y)
 	for i := range rows {
-		rows[i] = &imageRow{parent: &parent, s: i, shape: []int{features.Dim(0), size.X}}
+		rows[i].Tensor = features.View(ctx, features.Stride(1)*size.X*i, features.Dim(0), features.Stride(1), size.X)
 	}
 
 	return rows, nil
 }
 
-type imageFeatures struct {
-	tensor ml.Tensor
-
-	dataOnce sync.Once
-	data     []float32
-}
-
-type imageRow struct {
-	parent *imageFeatures
-	s      int
-	shape  []int
-}
-
-func (r *imageRow) data() []float32 {
-	n := 1
-	for _, s := range r.shape {
-		n *= s
-	}
-
-	return r.parent.data[r.s*n : (r.s+1)*n]
-}
-
 // PostTokenize arranges Mistral 3's inputs for the forward pass
 // In Mistral 3 and Pixtral, the input patches are arranged as follows:
 // [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
@@ -151,15 +127,14 @@ func (r *imageRow) data() []float32 {
 func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	var result []input.Input
 
 	for _, inp := range inputs {
-		if inp.Multimodal == nil {
+		if len(inp.Multimodal) == 0 {
 			result = append(result, inp)
 		} else {
-			inputMultimodal := inp.Multimodal.([]*imageRow)
-			for i, row := range inputMultimodal {
+			for i, row := range inp.Multimodal {
 				// [IMG]
-				result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.shape[1]})
-				result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.shape[1]-1)...)
-				if i == len(inputMultimodal)-1 {
+				result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
+				result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
+				if i == len(inp.Multimodal)-1 {
 					// [IMG_END]
 					result = append(result, input.Input{Token: 13})
 				} else {
diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go
index 1bf72acd8..457e639db 100644
--- a/model/models/mistral3/model_text.go
+++ b/model/models/mistral3/model_text.go
@@ -110,20 +110,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 	// image embeddings
 	for _, image := range batch.Multimodal {
-		row := image.Multimodal.(*imageRow)
-		row.parent.dataOnce.Do(func() {
-			// use a new, throwaway context so the image tensor is not added to the graph
-			temp := m.Backend().NewContext()
-			temp.Forward(row.parent.tensor).Compute(row.parent.tensor)
-			row.parent.data = row.parent.tensor.Floats()
-			temp.Close()
-		})
-
-		imageFeature, err := ctx.Input().FromFloatSlice(row.data(), row.shape...)
-		if err != nil {
-			panic(err)
-		}
-
+		imageFeature := image.Multimodal[0].Tensor
 		ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1))))
 	}
 
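The per-row expansion above follows a fixed pattern: one [IMG] token per patch, with the row's tensor view and a SameBatch window attached to the first token so the whole row is scheduled into a single batch. A sketch of just that step (`appendRowPlaceholders` is an invented helper; token ID 10 for [IMG] comes from the diff):

```go
package sketch

import (
	"slices"

	"github.com/ollama/ollama/model/input"
)

// appendRowPlaceholders emits one [IMG] token per patch in the row. Only
// the first token carries the tensor view; SameBatch keeps the remaining
// width-1 placeholders in the same batch as the data.
func appendRowPlaceholders(result []input.Input, row input.Multimodal, hash uint64) []input.Input {
	width := row.Tensor.Dim(1) // number of patches in this row
	result = append(result, input.Input{
		Token:          10, // [IMG]
		Multimodal:     []input.Multimodal{{Tensor: row.Tensor}},
		MultimodalHash: hash,
		SameBatch:      width,
	})
	// the remaining patches are plain placeholder tokens
	return append(result, slices.Repeat([]input.Input{{Token: 10}}, width-1)...)
}
```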
diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go
index 149876c9c..357051ccb 100644
--- a/model/models/mllama/model.go
+++ b/model/models/mllama/model.go
@@ -63,7 +63,7 @@ func New(c fs.Config) (model.Model, error) {
 	return &m, nil
 }
 
-func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
+func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
 	if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
 		return nil, model.ErrNoVisionModel
 	}
@@ -95,7 +95,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32)
 	crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
-	return m.Projector.Forward(ctx, crossAttentionStates), nil
+	return []input.Multimodal{{Tensor: m.Projector.Forward(ctx, crossAttentionStates)}}, nil
 }
 
 func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
@@ -103,12 +103,12 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	fnvHash := fnv.New64a()
 
 	for i := range inputs {
-		if inputs[i].Multimodal == nil {
+		if len(inputs[i].Multimodal) == 0 {
 			if len(images) > 0 {
-				inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)}
+				inputs[i].Multimodal = images[0].Multimodal
 				inputs[i].MultimodalHash = images[0].MultimodalHash
 				for j := 1; j < len(images); j++ {
-					inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[0].Multimodal.(ml.Tensor))
+					inputs[i].Multimodal = append(inputs[i].Multimodal, images[j].Multimodal...)
 					fnvHash.Reset()
 					binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
 					binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
@@ -130,9 +130,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	var crossAttentionStates ml.Tensor
 	if len(batch.Multimodal) > 0 {
-		images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
+		images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal
 		if len(images) > 0 {
-			crossAttentionStates = images[len(images)-1]
+			crossAttentionStates = images[len(images)-1].Tensor
 		}
 	}
 
diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go
index 062b654cf..6897b5e46 100644
--- a/runner/ollamarunner/cache_test.go
+++ b/runner/ollamarunner/cache_test.go
@@ -3,7 +3,6 @@ package ollamarunner
 import (
 	"errors"
 	"fmt"
-	"image"
 	"testing"
 	"time"
 
@@ -12,10 +11,6 @@ import (
 )
 
 func TestCountCommon(t *testing.T) {
-	imgA := image.NewRGBA(image.Rect(0, 0, 100, 100))
-	imgB := image.NewRGBA(image.Rect(0, 0, 50, 50))
-	imgC := image.NewRGBA(image.Rect(50, 50, 100, 100))
-
 	tests := []struct {
 		name     string
 		t1       []input.Input
 		t2       []input.Input
@@ -36,20 +31,20 @@ func TestCountCommon(t *testing.T) {
 		},
 		{
 			name:     "Image Prefix",
-			t1:       []input.Input{{Multimodal: imgA, MultimodalHash: 1}},
-			t2:       []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
+			t1:       []input.Input{{MultimodalHash: 1}},
+			t2:       []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
 			expected: 1,
 		},
 		{
 			name:     "Mixed",
-			t1:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
-			t2:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
+			t1:       []input.Input{{Token: 1}, {MultimodalHash: 1}},
+			t2:       []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
 			expected: 2,
 		},
 		{
 			name:     "Mixed, Same Length",
-			t1:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
-			t2:       []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
+			t1:       []input.Input{{Token: 1}, {MultimodalHash: 1}},
+			t2:       []input.Input{{Token: 1}, {MultimodalHash: 2}},
 			expected: 1,
 		},
 		{
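The test updates reflect that prefix matching keys on Token and MultimodalHash alone; the multimodal payload itself is never compared, so the test fixtures no longer need real images. A sketch of that comparison rule (an illustrative restatement in the style of the helper these tests cover, not its actual source):

```go
package ollamarunner

import "github.com/ollama/ollama/model/input"

// sameInput restates the matching rule the updated tests exercise: inputs
// compare by token and multimodal hash; tensors are never inspected.
func sameInput(a, b input.Input) bool {
	return a.Token == b.Token && a.MultimodalHash == b.MultimodalHash
}

// commonPrefix counts matching leading inputs.
func commonPrefix(t1, t2 []input.Input) int32 {
	var n int32
	for i := 0; i < len(t1) && i < len(t2); i++ {
		if !sameInput(t1[i], t2[i]) {
			break
		}
		n++
	}
	return n
}
```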
diff --git a/runner/ollamarunner/multimodal.go b/runner/ollamarunner/multimodal.go
new file mode 100644
index 000000000..16d359219
--- /dev/null
+++ b/runner/ollamarunner/multimodal.go
@@ -0,0 +1,103 @@
+package ollamarunner
+
+import (
+	"errors"
+
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
+)
+
+// Tensors can't be used across multiple compute graphs. This is a problem
+// if a single embedding is split across batches using views since all of
+// the views will have the same source tensor. We also don't want to
+// recompute the entire embedding for each batch.
+//
+// To avoid this, we compute all of the tensors for the embedding on the
+// first use and then store the result in system memory. When we need
+// additional tensors, we recreate them from the stored data.
+
+// multimodalEntry represents the embeddings of a single object (such
+// as an image).
+type multimodalEntry struct {
+	// mm is the original set of tensors created by EncodeMultimodal
+	mm []input.Multimodal
+
+	// data is the computed result of mm. Nil if not yet computed.
+	data [][]float32
+}
+
+// multimodalStore maps from an individual tensor (of which there
+// may be many in a single multimodal object) to its parent embedding
+type multimodalStore map[ml.Tensor]*multimodalEntry
+
+func newMultimodalStore() multimodalStore {
+	return make(multimodalStore)
+}
+
+// addMultimodal stores an embedding for later use in a compute graph
+func (m multimodalStore) addMultimodal(embedding []input.Multimodal) {
+	entry := &multimodalEntry{mm: embedding}
+
+	for _, e := range embedding {
+		if e.Tensor != nil {
+			m[e.Tensor] = entry
+		}
+	}
+}
+
+// getMultimodal takes a source set of tensors (which may contain a whole or
+// parts of one or more images) and returns the equivalent that can be used in
+// the current context
+func (m multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal) ([]input.Multimodal, error) {
+	out := make([]input.Multimodal, len(in))
+	for i := range out {
+		if in[i].Tensor != nil {
+			var err error
+			out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor)
+			if err != nil {
+				return nil, err
+			}
+		}
+
+		out[i].Data = in[i].Data
+	}
+
+	return out, nil
+}
+
+func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor) (ml.Tensor, error) {
+	entry := m[in]
+
+	if entry.data == nil {
+		computeCtx := backend.NewContext()
+		defer computeCtx.Close()
+
+		var tensors []ml.Tensor
+		for _, t := range entry.mm {
+			if t.Tensor != nil {
+				tensors = append(tensors, t.Tensor)
+			}
+		}
+
+		if len(tensors) == 0 {
+			return nil, nil
+		}
+
+		computeCtx.Forward(tensors...).Compute(tensors...)
+
+		entry.data = make([][]float32, len(entry.mm))
+		for i, t := range entry.mm {
+			if t.Tensor != nil {
+				entry.data[i] = t.Tensor.Floats()
+			}
+		}
+	}
+
+	for i, t := range entry.mm {
+		if in == t.Tensor {
+			return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
+		}
+	}
+
+	return nil, errors.New("multimodal tensor not found")
+}
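Tying the new file together, the intended lifecycle is: register the encoder output once, then fetch per batch. The sketch below assumes `enc` holds at least two chunks from EncodeMultimodal and that `ctx1`/`ctx2` are the compute contexts of two successive batches (`demoStoreLifecycle` itself is illustrative, not part of this diff):

```go
package ollamarunner

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// demoStoreLifecycle sketches how the runner uses the store: add once at
// prompt-processing time, then fetch per batch.
func demoStoreLifecycle(backend ml.Backend, ctx1, ctx2 ml.Context, enc []input.Multimodal) error {
	store := newMultimodalStore()
	store.addMultimodal(enc)

	// First use: every tensor in the entry is computed in one throwaway
	// graph and the float data is cached in system memory.
	if _, err := store.getMultimodal(backend, ctx1, enc[:1]); err != nil {
		return err
	}

	// Later batch: the chunk is rebuilt in the new context from the cached
	// floats rather than recomputed from the vision model.
	if _, err := store.getMultimodal(backend, ctx2, enc[1:]); err != nil {
		return err
	}
	return nil
}
```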
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 9a5222236..4e203b7be 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -40,6 +40,9 @@ type Sequence struct {
 	// multimodal embeddings
 	ctxs []ml.Context
 
+	// mmStore holds multimodal embeddings to manage memory and enable splitting across batches
+	mmStore multimodalStore
+
 	// batch index
 	iBatch int
 
@@ -101,7 +104,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 	startTime := time.Now()
 
-	inputs, ctxs, err := s.inputs(prompt, images)
+	inputs, ctxs, mmStore, err := s.inputs(prompt, images)
 	if err != nil {
 		return nil, fmt.Errorf("failed to process inputs: %w", err)
 	} else if len(inputs) == 0 {
@@ -156,6 +159,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 	return &Sequence{
 		ctxs:                ctxs,
+		mmStore:             mmStore,
 		inputs:              inputs,
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
@@ -174,9 +178,10 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, error) {
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
 	var inputs []input.Input
 	var ctxs []ml.Context
+	var mmStore multimodalStore
 
 	var parts []string
 	var matches [][]string
@@ -187,6 +192,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
 		re := regexp.MustCompile(`\[img-(\d+)\]`)
 		parts = re.Split(prompt, -1)
 		matches = re.FindAllStringSubmatch(prompt, -1)
+		mmStore = newMultimodalStore()
 	} else {
 		parts = []string{prompt}
 	}
@@ -196,7 +202,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
 		// text - tokenize
 		tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, nil, err
 		}
 
 		for _, t := range tokens {
@@ -216,7 +222,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
 			}
 
 			if imageIndex < 0 {
-				return nil, nil, fmt.Errorf("invalid image index: %d", n)
+				return nil, nil, nil, fmt.Errorf("invalid image index: %d", n)
 			}
 
 			ctx := s.model.Backend().NewContext()
@@ -224,13 +230,15 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
 			ctxs = append(ctxs, ctx)
 			imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
 			if err != nil {
-				return nil, nil, err
+				return nil, nil, nil, err
 			}
 
 			s.multimodalHash.Reset()
 			_, _ = s.multimodalHash.Write(images[imageIndex].Data)
 			imageHash := s.multimodalHash.Sum64()
 
+			mmStore.addMultimodal(imageEmbeddings)
+
 			inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
 			postTokenize = true
 		}
@@ -240,11 +248,11 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
 		var err error
 		inputs, err = multimodalProcessor.PostTokenize(inputs)
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, nil, err
 		}
 	}
 
-	return inputs, ctxs, nil
+	return inputs, ctxs, mmStore, nil
 }
 
 type Server struct {
@@ -363,6 +371,9 @@ func (s *Server) processBatch() error {
 	}
 	defer s.mu.Unlock()
 
+	ctx := s.model.Backend().NewContext()
+	defer ctx.Close()
+
 	var batchInputs []int32
 	var batch input.Batch
 
@@ -433,7 +444,11 @@ func (s *Server) processBatch() error {
 			batchInputs = append(batchInputs, inp.Token)
 			if inp.Multimodal != nil {
-				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
+				mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal)
+				if err != nil {
+					return err
+				}
+				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
 			}
 
 			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
@@ -459,9 +474,6 @@ func (s *Server) processBatch() error {
 		return nil
 	}
 
-	ctx := s.model.Backend().NewContext()
-	defer ctx.Close()
-
 	modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
 	if err != nil {
 		return fmt.Errorf("failed to decode batch: %w", err)
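The context creation moves to the top of processBatch because getMultimodal now needs the batch's compute context while inputs are being assembled, not only at decode time. An abridged restatement of the resulting flow (sequence bookkeeping omitted; this mirrors the diff rather than adding behavior, and `sketchProcessBatch` is an invented name):

```go
package ollamarunner

import (
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

// sketchProcessBatch abridges the flow above: one context serves both the
// multimodal rebuilds during input assembly and the final forward pass.
func (s *Server) sketchProcessBatch(seq *Sequence, pending []input.Input) error {
	ctx := s.model.Backend().NewContext()
	defer ctx.Close()

	var batchInputs []int32
	var batch input.Batch
	for _, inp := range pending {
		batchInputs = append(batchInputs, inp.Token)
		if inp.Multimodal != nil {
			// Rebuilding cached tensors requires the batch's context here...
			mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal)
			if err != nil {
				return err
			}
			batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
		}
	}

	// ...and the same context is reused for decoding.
	_, err := model.Forward(ctx, s.model, batchInputs, batch)
	return err
}
```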