ollamarunner: Use a separate context per multimodal input

Currently there is a single context per sequence, shared by
all multimodal inputs. Since we build a vision encoder graph per
image, with a large number of inputs we can eventually hit the
maximum number of graph nodes per context.

This changes to use a separate context for each image, ensuring
that per-context resource limits stay the same regardless of the
number of images in a sequence.
Jesse Gross 2025-03-13 20:32:50 -07:00 committed by Jesse Gross
parent 9679f40146
commit 282bfaaa95
4 changed files with 33 additions and 19 deletions
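
As a rough sketch of the pattern this commit moves to, each image now gets its own context. Context and Backend below are simplified stand-ins for the real ml package types, not the actual API:

package contextsketch

// Context and Backend are simplified stand-ins for the ml package types.
type Context interface{ Close() }

type Backend interface{ NewContext() Context }

// encodeImages gives every image its own context, so the vision encoder graph
// built for one image never pushes a shared context toward its node limit.
// The caller keeps the returned contexts alive for the life of the sequence
// and closes them when it is done with the embeddings.
func encodeImages(b Backend, images [][]byte, encode func(Context, []byte)) []Context {
    ctxs := make([]Context, 0, len(images))
    for _, img := range images {
        ctx := b.NewContext()
        encode(ctx, img)
        ctxs = append(ctxs, ctx)
    }
    return ctxs
}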

View File

@@ -60,7 +60,7 @@ type MultimodalProcessor interface {
// This function is also responsible for updating MultimodalHash for any Multimodal
// that is modified to ensure that there is a unique hash value that accurately
// represents the contents.
-PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
+PostTokenize([]input.Input) ([]input.Input, error)
}
// Base implements the common fields and methods for all models

View File

@@ -111,7 +111,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return visionOutputs, nil
}
-func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {

View File

@@ -106,17 +106,17 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return m.Projector.Forward(ctx, crossAttentionStates), nil
}
-func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var images []input.Input
fnvHash := fnv.New64a()
for i := range inputs {
if inputs[i].Multimodal == nil {
if len(images) > 0 {
-inputs[i].Multimodal = images[0].Multimodal
+inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)}
inputs[i].MultimodalHash = images[0].MultimodalHash
for j := 1; j < len(images); j++ {
-inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
+inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[0].Multimodal.(ml.Tensor))
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
@@ -138,7 +138,10 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor
if len(opts.Multimodal) > 0 {
-crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
+images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
+if len(images) > 0 {
+crossAttentionStates = images[len(images)-1]
+}
}
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
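
Because each image's encoder output now lives in its own context, mllama can no longer Concat the states into a single tensor; it carries them as a []ml.Tensor and Forward uses the most recent one. Below is a minimal sketch of that selection logic, with stand-in types rather than the real ml.Tensor and input.Input:

package mllamasketch

// Tensor and entry are simplified stand-ins for ml.Tensor and input.Input.
type Tensor interface{}

type entry struct {
    multimodal any // holds a []Tensor once images have been folded in
}

// lastImageState mirrors the new Forward logic: pick the last image's encoder
// output as the cross-attention state, if any images are present.
func lastImageState(entries []entry) (Tensor, bool) {
    if len(entries) == 0 {
        return nil, false
    }
    images, ok := entries[len(entries)-1].multimodal.([]Tensor)
    if !ok || len(images) == 0 {
        return nil, false
    }
    return images[len(images)-1], true
}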

View File

@@ -34,10 +34,14 @@ import (
_ "github.com/ollama/ollama/model/models"
)
+type contextList struct {
+list []ml.Context
+}
type Sequence struct {
-// ctx for allocating tensors that last the lifetime of the sequence, such as
+// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
// multimodal embeddings
-ctx ml.Context
+ctxs *contextList
// batch index
iBatch int
@@ -99,9 +103,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
s.ready.Wait()
startTime := time.Now()
-ctx := s.model.Backend().NewContext()
-inputs, err := s.inputs(ctx, prompt, images)
+inputs, ctxs, err := s.inputs(prompt, images)
if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err)
} else if len(inputs) == 0 {
@@ -127,7 +130,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// TODO(jessegross): Ingest cached history for grammar
return &Sequence{
-ctx: ctx,
+ctxs: ctxs,
inputs: inputs,
numPromptInputs: len(inputs),
startProcessingTime: startTime,
@@ -146,7 +149,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
-func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ([]input.Input, error) {
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) {
var inputs []input.Input
var parts []string
var matches [][]string
@@ -161,12 +164,19 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) (
parts = []string{prompt}
}
+var contexts contextList
+runtime.AddCleanup(&contexts, func(ctxs []ml.Context) {
+for _, ctx := range ctxs {
+ctx.Close()
+}
+}, contexts.list)
postTokenize := false
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
if err != nil {
-return nil, err
+return nil, nil, err
}
for _, t := range tokens {
@@ -186,12 +196,14 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) (
}
if imageIndex < 0 {
-return nil, fmt.Errorf("invalid image index: %d", n)
+return nil, nil, fmt.Errorf("invalid image index: %d", n)
}
+ctx := s.model.Backend().NewContext()
+contexts.list = append(contexts.list, ctx)
imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
if err != nil {
-return nil, err
+return nil, nil, err
}
s.multimodalHash.Reset()
@@ -205,13 +217,13 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) (
if visionModel && postTokenize {
var err error
-inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
+inputs, err = multimodalProcessor.PostTokenize(inputs)
if err != nil {
-return nil, err
+return nil, nil, err
}
}
-return inputs, nil
+return inputs, &contexts, nil
}
type Server struct {
@@ -306,7 +318,6 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
close(seq.responses)
close(seq.embedding)
seq.cache.InUse = false
-seq.ctx.Close()
s.seqs[seqIndex] = nil
s.seqsSem.Release(1)
}
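
With the per-image contexts owned by the contextList, the explicit seq.ctx.Close() removed above is no longer needed: a runtime cleanup closes the contexts once the list is unreachable. A self-contained sketch of that pattern follows; runtime.AddCleanup requires Go 1.24+, and Context here stands in for ml.Context:

package cleanupsketch

import "runtime"

// Context stands in for ml.Context.
type Context struct{}

func (c *Context) Close() {}

type contextList struct {
    list []*Context
}

// newContextList ties the contexts' lifetime to the list itself: once the
// contextList becomes unreachable, the registered cleanup closes each context.
func newContextList(ctxs []*Context) *contextList {
    cl := &contextList{list: ctxs}
    runtime.AddCleanup(cl, func(ctxs []*Context) {
        for _, ctx := range ctxs {
            ctx.Close()
        }
    }, cl.list)
    return cl
}

One consequence of this pattern is that contexts created in inputs are eventually released even on error paths where no Sequence is ever constructed, since the cleanup fires once the contextList stops being reachable.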