ollamarunner: Improve multimodal input handling

Various vision models have different requirements for how they
receive their inputs. For example:
 - Mllama wants images together with text and the image embeddings
   don't themselves have positions or get stored in the main KV cache
 - Llava-style models feed in embeddings similar to tokens and
   images correspond to a varying number of tokens in the cache.

In addition, the strategy for providing inputs must support batching
and multiple sequences, which are managed by the runner. At the same
time, we want to keep data handling fully in the model so that new
architectures are not bottlenecked by runner code which does not
understand their particular requirements.

This provides a method for models to edit the input stream so that
it meets their needs while still being in a format that the runner
understands. This allows the runner to avoid special processing
for different models.

In addition, this fixes a regression where non-vision models may
try to incorrectly interpret images.
This commit is contained in:
Jesse Gross 2025-03-05 12:08:06 -08:00 committed by Jesse Gross
parent b70fc4d51e
commit a7e63b82be
5 changed files with 247 additions and 130 deletions

View File

@ -3,7 +3,6 @@ package model
import ( import (
"errors" "errors"
"fmt" "fmt"
"image"
_ "image/jpeg" _ "image/jpeg"
_ "image/png" _ "image/png"
"log/slog" "log/slog"
@ -22,14 +21,40 @@ import (
_ "github.com/ollama/ollama/ml/backend" _ "github.com/ollama/ollama/ml/backend"
) )
// Input represents one token in the input stream
type Input struct {
// Token is a single element of text.
Token int32
// Multimodal is opaque data representing a non-text
// element such as an image (or part of one if the image
// can be processed in pieces). It may be either together
// with Token or on its own.
Multimodal any
// MultimodalHash is a unique representation of the data
// stored in Multimodal, used for caching and comparing
// equality.
MultimodalHash uint64
}
// MultimodalIndex is a multimodal element (such as an image)
// together with an index into the slice of Inputs with the
// corresponding token. Note that the index is not the same
// as the position - to find that use the index with the
// Positions slice.
type MultimodalIndex struct {
Index int
Multimodal any
}
// Options contains the inputs for a model forward pass // Options contains the inputs for a model forward pass
type Options struct { type Options struct {
Inputs []int32 Inputs []int32
Multimodal []MultimodalIndex
Positions []int32 Positions []int32
Sequences []int Sequences []int
Outputs []int32 Outputs []int32
Images []image.Image
} }
type config struct { type config struct {
@ -59,6 +84,37 @@ type Model interface {
Config() config Config() config
} }
// MultimodalProcessor must be implemented by multimodal models.
type MultimodalProcessor interface {
// EncodeMultimodal processes a single input (such as an image) and
// generates an output (typically an embedding) that can be used by the model.
//
// The return value is most typically an ml.Tensor, however, different
// type are possible, such as an object containing a tensor plus
// additional metadata, a slice of tensors or even just the original input.
//
// The result may be cached by the runner.
EncodeMultimodal(ml.Context, []byte) (any, error)
// PostTokenize is called after tokenization to allow the model to edit the
// input stream to correctly arrange multimodal elements.
//
// The input is a slice of tokens with the results of EncodeMultimodal interleaved
// in the order that the user provided them. Each element of the slice will be
// either a single token or single multimodal object.
//
// The model must ensure that inputs are stored according to how they will be
// processed and stored in the cache. For example, Llava-style models should insert
// placeholder tokens equal to the feature size of the corresponding image with
// the image itself attached to and split across these tokens. When Forward is called
// a partial subset of these tokens may be submitted according to the batch size.
//
// This function is also responsible for updating MultimodalHash for any Multimodal
// that is modified to ensure that there is a unique hash value that accurately
// represents the contents.
PostTokenize(ml.Context, []Input) ([]Input, error)
}
var models = make(map[string]func(ml.Config) (Model, error)) var models = make(map[string]func(ml.Config) (Model, error))
// Register registers a model constructor for the given architecture // Register registers a model constructor for the given architecture

View File

@ -1,7 +1,12 @@
package mllama package mllama
import ( import (
"bytes"
"encoding/binary"
"fmt" "fmt"
"hash/fnv"
"image"
"slices"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
@ -56,10 +61,13 @@ func New(c ml.Config) (model.Model, error) {
return &m, nil return &m, nil
} }
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
var crossAttentionStates ml.Tensor image, _, err := image.Decode(bytes.NewReader(multimodalData))
if opts.Images != nil { if err != nil {
f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0]) return nil, err
}
f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(image)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -89,8 +97,43 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
return nil, err return nil, err
} }
crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio) crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates) return m.Projector.Forward(ctx, crossAttentionStates), nil
}
func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Input, error) {
var images []model.Input
fnvHash := fnv.New64a()
for i := range inputs {
if inputs[i].Multimodal == nil {
if len(images) > 0 {
inputs[i].Multimodal = images[0].Multimodal
inputs[i].MultimodalHash = images[0].MultimodalHash
for j := 1; j < len(images); j++ {
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
inputs[i].MultimodalHash = fnvHash.Sum64()
}
images = nil
}
} else {
images = append(images, inputs[i])
inputs[i].Token = -1
}
}
inputs = slices.DeleteFunc(inputs, func(input model.Input) bool { return input.Token == -1 })
return inputs, nil
}
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor
if opts.Multimodal != nil {
crossAttentionStates = opts.Multimodal[0].Multimodal.(ml.Tensor)
} }
inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs)) inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"math" "math"
"reflect"
"time" "time"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
@ -39,10 +38,7 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
slots := make([]InputCacheSlot, numSlots) slots := make([]InputCacheSlot, numSlots)
for i := range slots { for i := range slots {
slots[i] = InputCacheSlot{ slots[i] = InputCacheSlot{Id: i}
Id: i,
Inputs: make([]input, 0),
}
} }
cache := model.Config().Cache cache := model.Config().Cache
@ -83,7 +79,7 @@ type InputCacheSlot struct {
Id int Id int
// Inputs that are stored in the KV cache // Inputs that are stored in the KV cache
Inputs []input Inputs []model.Input
// is this cache actively being processed as part of a sequence? // is this cache actively being processed as part of a sequence?
InUse bool InUse bool
@ -92,7 +88,7 @@ type InputCacheSlot struct {
lastUsed time.Time lastUsed time.Time
} }
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) { func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*InputCacheSlot, []model.Input, error) {
var slot *InputCacheSlot var slot *InputCacheSlot
var numPast int32 var numPast int32
var err error var err error
@ -143,7 +139,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach
return slot, prompt, nil return slot, prompt, nil
} }
func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) { func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
longest := int32(-1) longest := int32(-1)
var longestSlot *InputCacheSlot var longestSlot *InputCacheSlot
@ -166,7 +162,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int3
return longestSlot, longest, nil return longestSlot, longest, nil
} }
func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) { func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
oldest := time.Now() oldest := time.Now()
var oldestSlot *InputCacheSlot var oldestSlot *InputCacheSlot
@ -202,7 +198,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32,
if longest > 0 && longestSlot != oldestSlot { if longest > 0 && longestSlot != oldestSlot {
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total", slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
len(longestSlot.Inputs)) len(longestSlot.Inputs))
oldestSlot.Inputs = make([]input, longest) oldestSlot.Inputs = make([]model.Input, longest)
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest]) copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
if c.cache != nil { if c.cache != nil {
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest) c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@ -212,7 +208,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32,
return oldestSlot, longest, nil return oldestSlot, longest, nil
} }
func countCommonPrefix(a []input, b []input) int32 { func countCommonPrefix(a []model.Input, b []model.Input) int32 {
var count int32 var count int32
for i := range a { for i := range a {
@ -220,7 +216,7 @@ func countCommonPrefix(a []input, b []input) int32 {
break break
} }
if !reflect.DeepEqual(a[i], b[i]) { if a[i].Token != b[i].Token || a[i].MultimodalHash != b[i].MultimodalHash {
break break
} }

View File

@ -4,6 +4,8 @@ import (
"image" "image"
"testing" "testing"
"time" "time"
"github.com/ollama/ollama/model"
) )
func TestCountCommon(t *testing.T) { func TestCountCommon(t *testing.T) {
@ -13,44 +15,50 @@ func TestCountCommon(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
t1 []input t1 []model.Input
t2 []input t2 []model.Input
expected int32 expected int32
}{ }{
{ {
name: "Equal", name: "Equal",
t1: []input{{token: 1}, {token: 2}, {token: 3}}, t1: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t2: []input{{token: 1}, {token: 2}, {token: 3}}, t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 3, expected: 3,
}, },
{ {
name: "Prefix", name: "Prefix",
t1: []input{{token: 1}}, t1: []model.Input{{Token: 1}},
t2: []input{{token: 1}, {token: 2}, {token: 3}}, t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 1, expected: 1,
}, },
{ {
name: "Image Prefix", name: "Image Prefix",
t1: []input{{image: imgA}}, t1: []model.Input{{Multimodal: imgA, MultimodalHash: 1}},
t2: []input{{image: imgA}, {image: imgB}, {image: imgC}}, t2: []model.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
expected: 1, expected: 1,
}, },
{ {
name: "Mixed", name: "Mixed",
t1: []input{{token: 1}, {image: imgA}}, t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
t2: []input{{token: 1}, {image: imgA}, {token: 5}}, t2: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
expected: 2, expected: 2,
}, },
{
name: "Mixed, Same Length",
t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
t2: []model.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
expected: 1,
},
{ {
name: "Empty", name: "Empty",
t1: []input{}, t1: []model.Input{},
t2: []input{{token: 1}, {token: 2}, {token: 3}}, t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 0, expected: 0,
}, },
{ {
name: "Both Empty", name: "Both Empty",
t1: []input{}, t1: []model.Input{},
t2: []input{}, t2: []model.Input{},
expected: 0, expected: 0,
}, },
} }
@ -74,7 +82,7 @@ func TestFindCacheSlot(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
cache InputCache cache InputCache
prompt []input prompt []model.Input
longest expected longest expected
best expected best expected
}{ }{
@ -83,18 +91,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input{}, Inputs: []model.Input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
{ {
Id: 1, Id: 1,
Inputs: []input{}, Inputs: []model.Input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
}}, }},
prompt: []input{{token: 1}}, prompt: []model.Input{{Token: 1}},
longest: expected{result: 0, len: 0}, longest: expected{result: 0, len: 0},
best: expected{result: 0, len: 0}, best: expected{result: 0, len: 0},
}, },
@ -103,18 +111,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input{{token: 1}}, Inputs: []model.Input{{Token: 1}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input{{token: 1}, {token: 2}}, Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}}, }},
prompt: []input{{token: 1}, {token: 2}}, prompt: []model.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 2}, longest: expected{result: 1, len: 2},
best: expected{result: 1, len: 2}, best: expected{result: 1, len: 2},
}, },
@ -123,18 +131,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input{{token: 1}, {token: 2}}, Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input{}, Inputs: []model.Input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
}}, }},
prompt: []input{{token: 2}}, prompt: []model.Input{{Token: 2}},
longest: expected{result: 0, len: 0}, longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0}, best: expected{result: 1, len: 0},
}, },
@ -144,19 +152,19 @@ func TestFindCacheSlot(t *testing.T) {
slots: []InputCacheSlot{ slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input{{token: 1}, {token: 2}}, Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input{}, Inputs: []model.Input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
}, },
}, },
prompt: []input{{token: 1}}, prompt: []model.Input{{Token: 1}},
longest: expected{result: 0, len: 1}, longest: expected{result: 0, len: 1},
best: expected{result: 1, len: 1}, best: expected{result: 1, len: 1},
}, },
@ -165,18 +173,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input{{token: 1}}, Inputs: []model.Input{{Token: 1}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input{{token: 1}, {token: 2}}, Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}}, }},
prompt: []input{{token: 2}, {token: 3}}, prompt: []model.Input{{Token: 2}, {Token: 3}},
longest: expected{result: 0, len: 0}, longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0}, best: expected{result: 1, len: 0},
}, },
@ -185,18 +193,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input{{token: 1}, {token: 2}}, Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: true, InUse: true,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input{{token: 1}}, Inputs: []model.Input{{Token: 1}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}}, }},
prompt: []input{{token: 1}, {token: 2}}, prompt: []model.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 1}, longest: expected{result: 1, len: 1},
best: expected{result: 1, len: 2}, best: expected{result: 1, len: 2},
}, },

View File

@ -1,13 +1,12 @@
package ollamarunner package ollamarunner
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"image" "hash/maphash"
"log" "log"
"log/slog" "log/slog"
"net" "net"
@ -33,22 +32,19 @@ import (
_ "github.com/ollama/ollama/model/models" _ "github.com/ollama/ollama/model/models"
) )
// input is an element of the prompt to process, either a token or an image
type input struct {
token int32
image image.Image
}
type Sequence struct { type Sequence struct {
// ctx for allocating tensors that last the lifetime of the sequence, such as
// multimodal embeddings
ctx ml.Context
// batch index // batch index
iBatch int iBatch int
// prompt inputs left to evaluate // prompt inputs left to evaluate
inputs []input inputs []model.Input
// inputs that have been added to a batch but not yet submitted to Forward // inputs that have been added to a batch but not yet submitted to Forward
pendingInputs []input pendingInputs []model.Input
// tokens that have been generated but not returned yet (e.g. for stop sequences) // tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string pendingResponses []string
@ -101,8 +97,9 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
s.ready.Wait() s.ready.Wait()
startTime := time.Now() startTime := time.Now()
ctx := s.model.Backend().NewContext()
inputs, err := s.inputs(prompt, images) inputs, err := s.inputs(ctx, prompt, images)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err) return nil, fmt.Errorf("failed to process inputs: %w", err)
} else if len(inputs) == 0 { } else if len(inputs) == 0 {
@ -128,6 +125,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// TODO(jessegross): Ingest cached history for grammar // TODO(jessegross): Ingest cached history for grammar
return &Sequence{ return &Sequence{
ctx: ctx,
inputs: inputs, inputs: inputs,
numPromptInputs: len(inputs), numPromptInputs: len(inputs),
startProcessingTime: startTime, startProcessingTime: startTime,
@ -146,19 +144,22 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// inputs processes the prompt and images into a list of inputs // inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and // by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images // decoding images
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) { func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]model.Input, error) {
var inputs []input var inputs []model.Input
var parts []string var parts []string
var matches [][]string var matches [][]string
// TODO(jessegross): This can sometimes trigger for matching text in the multimodalProcessor, visionModel := s.model.(model.MultimodalProcessor)
// user's prompt. We previously tried to avoid it by only looking for images
// on image models. We don't have a clear indication now but it would be better if visionModel {
// to properly escape it in any case.
re := regexp.MustCompile(`\[img-(\d+)\]`) re := regexp.MustCompile(`\[img-(\d+)\]`)
parts = re.Split(prompt, -1) parts = re.Split(prompt, -1)
matches = re.FindAllStringSubmatch(prompt, -1) matches = re.FindAllStringSubmatch(prompt, -1)
} else {
parts = []string{prompt}
}
postTokenize := false
for i, part := range parts { for i, part := range parts {
// text - tokenize // text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0) tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
@ -167,7 +168,7 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
} }
for _, t := range tokens { for _, t := range tokens {
inputs = append(inputs, input{token: t}) inputs = append(inputs, model.Input{Token: t})
} }
// image - decode and store // image - decode and store
@ -186,12 +187,25 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
return nil, fmt.Errorf("invalid image index: %d", n) return nil, fmt.Errorf("invalid image index: %d", n)
} }
image, _, err := image.Decode(bytes.NewReader(images[imageIndex].Data)) imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
inputs = append(inputs, input{image: image}) s.multimodalHash.Reset()
_, _ = s.multimodalHash.Write(images[imageIndex].Data)
imageHash := s.multimodalHash.Sum64()
inputs = append(inputs, model.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
postTokenize = true
}
}
if visionModel && postTokenize {
var err error
inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
if err != nil {
return nil, err
} }
} }
@ -238,6 +252,10 @@ type Server struct {
// next sequence for prompt processing to avoid starvation // next sequence for prompt processing to avoid starvation
nextSeq int nextSeq int
// multimodalHash generates hashes for comparing equality
// of non-text data
multimodalHash maphash.Hash
} }
func (s *Server) allNil() bool { func (s *Server) allNil() bool {
@ -283,6 +301,7 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
close(seq.responses) close(seq.responses)
close(seq.embedding) close(seq.embedding)
seq.cache.InUse = false seq.cache.InUse = false
seq.ctx.Close()
s.seqs[seqIndex] = nil s.seqs[seqIndex] = nil
s.seqsSem.Release(1) s.seqsSem.Release(1)
} }
@ -311,7 +330,6 @@ func (s *Server) processBatch() error {
defer s.mu.Unlock() defer s.mu.Unlock()
var options model.Options var options model.Options
imgSeq := -1
seqIdx := s.nextSeq - 1 seqIdx := s.nextSeq - 1
for range s.seqs { for range s.seqs {
@ -330,7 +348,7 @@ func (s *Server) processBatch() error {
if !s.cache.enabled { if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...) seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []input{} seq.cache.Inputs = []model.Input{}
} }
for i, input := range seq.inputs { for i, input := range seq.inputs {
@ -349,25 +367,21 @@ func (s *Server) processBatch() error {
break break
} }
// TODO(jessegross): Image inputs need to be rethought - it's // TODO(jessegross): This is a workaround for generating an attention mask and also providing a hint
// it doesn't work well for different types of models or multiple sequences // to the encoder cache.
if input.image != nil { //
if len(seq.pendingInputs) != len(options.Images) { // Break the batch when switching from text to images so that images are always at the beginning.
break if input.Multimodal != nil && !(len(seq.pendingInputs) == 0 ||
} (len(options.Multimodal) > 0 && options.Multimodal[len(options.Multimodal)-1].Index == len(options.Inputs)-1)) {
if imgSeq != seqIdx && imgSeq != -1 {
s.nextSeq = seqIdx s.nextSeq = seqIdx
break break
} }
imgSeq = seqIdx options.Inputs = append(options.Inputs, input.Token)
options.Images = append(options.Images, input.image) if input.Multimodal != nil {
seq.pendingInputs = append(seq.pendingInputs, input) options.Multimodal = append(options.Multimodal, model.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: input.Multimodal})
continue
} }
options.Inputs = append(options.Inputs, input.token)
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
options.Sequences = append(options.Sequences, seq.cache.Id) options.Sequences = append(options.Sequences, seq.cache.Id)
@ -403,7 +417,7 @@ func (s *Server) processBatch() error {
// After calling Forward, pending inputs are now in the cache // After calling Forward, pending inputs are now in the cache
if len(seq.pendingInputs) > 0 { if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...) seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
seq.pendingInputs = []input{} seq.pendingInputs = []model.Input{}
} }
// don't sample prompt processing // don't sample prompt processing
@ -449,7 +463,7 @@ func (s *Server) processBatch() error {
return err return err
} }
seq.inputs = []input{{token: token}} seq.inputs = []model.Input{{Token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece) seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "") sequence := strings.Join(seq.pendingResponses, "")