ollamarunner: Improve multimodal input handling

Various vision models have different requirements for how they
receive their inputs. For example:
 - Mllama wants images together with text and the image embeddings
   don't themselves have positions or get stored in the main KV cache
 - Llava-style models feed in embeddings similar to tokens and
   images correspond to a varying number of tokens in the cache.

In addition, the strategy for providing inputs must support batching
and multiple sequences, which are managed by the runner. At the same
time, we want to keep data handling fully in the model so that new
architectures are not bottlenecked by runner code which does not
understand their particular requirements.

This provides a method for models to edit the input stream so that
it meets their needs while still being in a format that the runner
understands. This allows the runner to avoid special processing
for different models.

In addition, this fixes a regression where non-vision models may
try to incorrectly interpret images.
This commit is contained in:
Jesse Gross 2025-03-05 12:08:06 -08:00 committed by Jesse Gross
parent b70fc4d51e
commit a7e63b82be
5 changed files with 247 additions and 130 deletions

View File

@ -3,7 +3,6 @@ package model
import (
"errors"
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"log/slog"
@ -22,14 +21,40 @@ import (
_ "github.com/ollama/ollama/ml/backend"
)
// Input represents one token in the input stream
type Input struct {
// Token is a single element of text.
Token int32
// Multimodal is opaque data representing a non-text
// element such as an image (or part of one if the image
// can be processed in pieces). It may be either together
// with Token or on its own.
Multimodal any
// MultimodalHash is a unique representation of the data
// stored in Multimodal, used for caching and comparing
// equality.
MultimodalHash uint64
}
// MultimodalIndex is a multimodal element (such as an image)
// together with an index into the slice of Inputs with the
// corresponding token. Note that the index is not the same
// as the position - to find that use the index with the
// Positions slice.
type MultimodalIndex struct {
Index int
Multimodal any
}
// Options contains the inputs for a model forward pass
type Options struct {
Inputs []int32
Positions []int32
Sequences []int
Outputs []int32
Images []image.Image
Inputs []int32
Multimodal []MultimodalIndex
Positions []int32
Sequences []int
Outputs []int32
}
type config struct {
@ -59,6 +84,37 @@ type Model interface {
Config() config
}
// MultimodalProcessor must be implemented by multimodal models.
type MultimodalProcessor interface {
// EncodeMultimodal processes a single input (such as an image) and
// generates an output (typically an embedding) that can be used by the model.
//
// The return value is most typically an ml.Tensor, however, different
// type are possible, such as an object containing a tensor plus
// additional metadata, a slice of tensors or even just the original input.
//
// The result may be cached by the runner.
EncodeMultimodal(ml.Context, []byte) (any, error)
// PostTokenize is called after tokenization to allow the model to edit the
// input stream to correctly arrange multimodal elements.
//
// The input is a slice of tokens with the results of EncodeMultimodal interleaved
// in the order that the user provided them. Each element of the slice will be
// either a single token or single multimodal object.
//
// The model must ensure that inputs are stored according to how they will be
// processed and stored in the cache. For example, Llava-style models should insert
// placeholder tokens equal to the feature size of the corresponding image with
// the image itself attached to and split across these tokens. When Forward is called
// a partial subset of these tokens may be submitted according to the batch size.
//
// This function is also responsible for updating MultimodalHash for any Multimodal
// that is modified to ensure that there is a unique hash value that accurately
// represents the contents.
PostTokenize(ml.Context, []Input) ([]Input, error)
}
var models = make(map[string]func(ml.Config) (Model, error))
// Register registers a model constructor for the given architecture

View File

@ -1,7 +1,12 @@
package mllama
import (
"bytes"
"encoding/binary"
"fmt"
"hash/fnv"
"image"
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
@ -56,41 +61,79 @@ func New(c ml.Config) (model.Model, error) {
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues, err := ctx.FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
m.ImageProcessor.maxNumTiles,
)
if err != nil {
return nil, err
}
aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
if err != nil {
return nil, err
}
positions := make([]int32, 1601)
for i := range positions {
positions[i] = int32(i)
}
positionIDs, err := ctx.FromIntSlice(positions, len(positions))
if err != nil {
return nil, err
}
crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
return m.Projector.Forward(ctx, crossAttentionStates), nil
}
func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Input, error) {
var images []model.Input
fnvHash := fnv.New64a()
for i := range inputs {
if inputs[i].Multimodal == nil {
if len(images) > 0 {
inputs[i].Multimodal = images[0].Multimodal
inputs[i].MultimodalHash = images[0].MultimodalHash
for j := 1; j < len(images); j++ {
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
inputs[i].MultimodalHash = fnvHash.Sum64()
}
images = nil
}
} else {
images = append(images, inputs[i])
inputs[i].Token = -1
}
}
inputs = slices.DeleteFunc(inputs, func(input model.Input) bool { return input.Token == -1 })
return inputs, nil
}
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor
if opts.Images != nil {
f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0])
if err != nil {
return nil, err
}
pixelValues, err := ctx.FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
m.ImageProcessor.maxNumTiles,
)
if err != nil {
return nil, err
}
aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
if err != nil {
return nil, err
}
positions := make([]int32, 1601)
for i := range positions {
positions[i] = int32(i)
}
positionIDs, err := ctx.FromIntSlice(positions, len(positions))
if err != nil {
return nil, err
}
crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
if opts.Multimodal != nil {
crossAttentionStates = opts.Multimodal[0].Multimodal.(ml.Tensor)
}
inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))

View File

@ -5,7 +5,6 @@ import (
"fmt"
"log/slog"
"math"
"reflect"
"time"
"github.com/ollama/ollama/kvcache"
@ -39,10 +38,7 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
slots := make([]InputCacheSlot, numSlots)
for i := range slots {
slots[i] = InputCacheSlot{
Id: i,
Inputs: make([]input, 0),
}
slots[i] = InputCacheSlot{Id: i}
}
cache := model.Config().Cache
@ -83,7 +79,7 @@ type InputCacheSlot struct {
Id int
// Inputs that are stored in the KV cache
Inputs []input
Inputs []model.Input
// is this cache actively being processed as part of a sequence?
InUse bool
@ -92,7 +88,7 @@ type InputCacheSlot struct {
lastUsed time.Time
}
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*InputCacheSlot, []model.Input, error) {
var slot *InputCacheSlot
var numPast int32
var err error
@ -143,7 +139,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach
return slot, prompt, nil
}
func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
longest := int32(-1)
var longestSlot *InputCacheSlot
@ -166,7 +162,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int3
return longestSlot, longest, nil
}
func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
oldest := time.Now()
var oldestSlot *InputCacheSlot
@ -202,7 +198,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32,
if longest > 0 && longestSlot != oldestSlot {
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
len(longestSlot.Inputs))
oldestSlot.Inputs = make([]input, longest)
oldestSlot.Inputs = make([]model.Input, longest)
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
if c.cache != nil {
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@ -212,7 +208,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32,
return oldestSlot, longest, nil
}
func countCommonPrefix(a []input, b []input) int32 {
func countCommonPrefix(a []model.Input, b []model.Input) int32 {
var count int32
for i := range a {
@ -220,7 +216,7 @@ func countCommonPrefix(a []input, b []input) int32 {
break
}
if !reflect.DeepEqual(a[i], b[i]) {
if a[i].Token != b[i].Token || a[i].MultimodalHash != b[i].MultimodalHash {
break
}

View File

@ -4,6 +4,8 @@ import (
"image"
"testing"
"time"
"github.com/ollama/ollama/model"
)
func TestCountCommon(t *testing.T) {
@ -13,44 +15,50 @@ func TestCountCommon(t *testing.T) {
tests := []struct {
name string
t1 []input
t2 []input
t1 []model.Input
t2 []model.Input
expected int32
}{
{
name: "Equal",
t1: []input{{token: 1}, {token: 2}, {token: 3}},
t2: []input{{token: 1}, {token: 2}, {token: 3}},
t1: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 3,
},
{
name: "Prefix",
t1: []input{{token: 1}},
t2: []input{{token: 1}, {token: 2}, {token: 3}},
t1: []model.Input{{Token: 1}},
t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 1,
},
{
name: "Image Prefix",
t1: []input{{image: imgA}},
t2: []input{{image: imgA}, {image: imgB}, {image: imgC}},
t1: []model.Input{{Multimodal: imgA, MultimodalHash: 1}},
t2: []model.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
expected: 1,
},
{
name: "Mixed",
t1: []input{{token: 1}, {image: imgA}},
t2: []input{{token: 1}, {image: imgA}, {token: 5}},
t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
t2: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
expected: 2,
},
{
name: "Mixed, Same Length",
t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
t2: []model.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
expected: 1,
},
{
name: "Empty",
t1: []input{},
t2: []input{{token: 1}, {token: 2}, {token: 3}},
t1: []model.Input{},
t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 0,
},
{
name: "Both Empty",
t1: []input{},
t2: []input{},
t1: []model.Input{},
t2: []model.Input{},
expected: 0,
},
}
@ -74,7 +82,7 @@ func TestFindCacheSlot(t *testing.T) {
tests := []struct {
name string
cache InputCache
prompt []input
prompt []model.Input
longest expected
best expected
}{
@ -83,18 +91,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{},
Inputs: []model.Input{},
InUse: false,
lastUsed: time.Time{},
},
{
Id: 1,
Inputs: []input{},
Inputs: []model.Input{},
InUse: false,
lastUsed: time.Time{},
},
}},
prompt: []input{{token: 1}},
prompt: []model.Input{{Token: 1}},
longest: expected{result: 0, len: 0},
best: expected{result: 0, len: 0},
},
@ -103,18 +111,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}},
Inputs: []model.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{{token: 1}, {token: 2}},
Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input{{token: 1}, {token: 2}},
prompt: []model.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 2},
best: expected{result: 1, len: 2},
},
@ -123,18 +131,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}, {token: 2}},
Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{},
Inputs: []model.Input{},
InUse: false,
lastUsed: time.Time{},
},
}},
prompt: []input{{token: 2}},
prompt: []model.Input{{Token: 2}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
@ -144,19 +152,19 @@ func TestFindCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}, {token: 2}},
Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{},
Inputs: []model.Input{},
InUse: false,
lastUsed: time.Time{},
},
},
},
prompt: []input{{token: 1}},
prompt: []model.Input{{Token: 1}},
longest: expected{result: 0, len: 1},
best: expected{result: 1, len: 1},
},
@ -165,18 +173,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}},
Inputs: []model.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{{token: 1}, {token: 2}},
Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input{{token: 2}, {token: 3}},
prompt: []model.Input{{Token: 2}, {Token: 3}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
@ -185,18 +193,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}, {token: 2}},
Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{{token: 1}},
Inputs: []model.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input{{token: 1}, {token: 2}},
prompt: []model.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 1},
best: expected{result: 1, len: 2},
},

View File

@ -1,13 +1,12 @@
package ollamarunner
import (
"bytes"
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"image"
"hash/maphash"
"log"
"log/slog"
"net"
@ -33,22 +32,19 @@ import (
_ "github.com/ollama/ollama/model/models"
)
// input is an element of the prompt to process, either a token or an image
type input struct {
token int32
image image.Image
}
type Sequence struct {
// ctx for allocating tensors that last the lifetime of the sequence, such as
// multimodal embeddings
ctx ml.Context
// batch index
iBatch int
// prompt inputs left to evaluate
inputs []input
inputs []model.Input
// inputs that have been added to a batch but not yet submitted to Forward
pendingInputs []input
pendingInputs []model.Input
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
@ -101,8 +97,9 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
s.ready.Wait()
startTime := time.Now()
ctx := s.model.Backend().NewContext()
inputs, err := s.inputs(prompt, images)
inputs, err := s.inputs(ctx, prompt, images)
if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err)
} else if len(inputs) == 0 {
@ -128,6 +125,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// TODO(jessegross): Ingest cached history for grammar
return &Sequence{
ctx: ctx,
inputs: inputs,
numPromptInputs: len(inputs),
startProcessingTime: startTime,
@ -146,19 +144,22 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
var inputs []input
func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]model.Input, error) {
var inputs []model.Input
var parts []string
var matches [][]string
// TODO(jessegross): This can sometimes trigger for matching text in the
// user's prompt. We previously tried to avoid it by only looking for images
// on image models. We don't have a clear indication now but it would be better
// to properly escape it in any case.
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts = re.Split(prompt, -1)
matches = re.FindAllStringSubmatch(prompt, -1)
multimodalProcessor, visionModel := s.model.(model.MultimodalProcessor)
if visionModel {
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts = re.Split(prompt, -1)
matches = re.FindAllStringSubmatch(prompt, -1)
} else {
parts = []string{prompt}
}
postTokenize := false
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
@ -167,7 +168,7 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
}
for _, t := range tokens {
inputs = append(inputs, input{token: t})
inputs = append(inputs, model.Input{Token: t})
}
// image - decode and store
@ -186,12 +187,25 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
return nil, fmt.Errorf("invalid image index: %d", n)
}
image, _, err := image.Decode(bytes.NewReader(images[imageIndex].Data))
imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
if err != nil {
return nil, err
}
inputs = append(inputs, input{image: image})
s.multimodalHash.Reset()
_, _ = s.multimodalHash.Write(images[imageIndex].Data)
imageHash := s.multimodalHash.Sum64()
inputs = append(inputs, model.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
postTokenize = true
}
}
if visionModel && postTokenize {
var err error
inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
if err != nil {
return nil, err
}
}
@ -238,6 +252,10 @@ type Server struct {
// next sequence for prompt processing to avoid starvation
nextSeq int
// multimodalHash generates hashes for comparing equality
// of non-text data
multimodalHash maphash.Hash
}
func (s *Server) allNil() bool {
@ -283,6 +301,7 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
close(seq.responses)
close(seq.embedding)
seq.cache.InUse = false
seq.ctx.Close()
s.seqs[seqIndex] = nil
s.seqsSem.Release(1)
}
@ -311,7 +330,6 @@ func (s *Server) processBatch() error {
defer s.mu.Unlock()
var options model.Options
imgSeq := -1
seqIdx := s.nextSeq - 1
for range s.seqs {
@ -330,7 +348,7 @@ func (s *Server) processBatch() error {
if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []input{}
seq.cache.Inputs = []model.Input{}
}
for i, input := range seq.inputs {
@ -349,25 +367,21 @@ func (s *Server) processBatch() error {
break
}
// TODO(jessegross): Image inputs need to be rethought - it's
// it doesn't work well for different types of models or multiple sequences
if input.image != nil {
if len(seq.pendingInputs) != len(options.Images) {
break
}
if imgSeq != seqIdx && imgSeq != -1 {
s.nextSeq = seqIdx
break
}
imgSeq = seqIdx
options.Images = append(options.Images, input.image)
seq.pendingInputs = append(seq.pendingInputs, input)
continue
// TODO(jessegross): This is a workaround for generating an attention mask and also providing a hint
// to the encoder cache.
//
// Break the batch when switching from text to images so that images are always at the beginning.
if input.Multimodal != nil && !(len(seq.pendingInputs) == 0 ||
(len(options.Multimodal) > 0 && options.Multimodal[len(options.Multimodal)-1].Index == len(options.Inputs)-1)) {
s.nextSeq = seqIdx
break
}
options.Inputs = append(options.Inputs, input.Token)
if input.Multimodal != nil {
options.Multimodal = append(options.Multimodal, model.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: input.Multimodal})
}
options.Inputs = append(options.Inputs, input.token)
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
options.Sequences = append(options.Sequences, seq.cache.Id)
@ -403,7 +417,7 @@ func (s *Server) processBatch() error {
// After calling Forward, pending inputs are now in the cache
if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
seq.pendingInputs = []input{}
seq.pendingInputs = []model.Input{}
}
// don't sample prompt processing
@ -449,7 +463,7 @@ func (s *Server) processBatch() error {
return err
}
seq.inputs = []input{{token: token}}
seq.inputs = []model.Input{{Token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")