ollamarunner: Use a separate context per multimodal input
Currently there is a single context per sequence, shared by all multimodal inputs. Since we build a vision encoder graph per image, a sequence with a large number of images can eventually hit the maximum number of graph nodes per context. This changes the runner to use a separate context for each image, ensuring that the available resource limits are consistent regardless of how many images a sequence contains.
parent 9679f40146
commit 282bfaaa95
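The failure mode and the fix are easy to demonstrate in isolation. Below is a minimal, self-contained Go sketch, not the real ml package: Context, encodeImage, and maxNodesPerContext are illustrative stand-ins for a backend that caps graph nodes per context.

    package main

    import "fmt"

    // Stand-in for a backend context that accumulates graph nodes
    // (hypothetical; the real ml.Context interface is much larger).
    type Context struct{ nodes int }

    func (c *Context) Close() { fmt.Printf("freed context holding %d nodes\n", c.nodes) }

    // Illustrative limit only; the real cap depends on the backend.
    const maxNodesPerContext = 8192

    // encodeImage pretends each vision encoder graph costs ~3000 nodes.
    func encodeImage(c *Context) error {
        c.nodes += 3000
        if c.nodes > maxNodesPerContext {
            return fmt.Errorf("graph too large: %d nodes > %d", c.nodes, maxNodesPerContext)
        }
        return nil
    }

    func main() {
        // Before: one context shared by the whole sequence; the third
        // image pushes the shared graph over the limit.
        shared := &Context{}
        for i := range 3 {
            if err := encodeImage(shared); err != nil {
                fmt.Println("image", i, "failed:", err)
            }
        }
        shared.Close()

        // After: one context per image, so each graph starts from zero
        // nodes and the limit no longer depends on how many images came before.
        var ctxs []*Context
        for range 3 {
            c := &Context{}
            ctxs = append(ctxs, c)
            _ = encodeImage(c)
        }
        for _, c := range ctxs {
            c.Close() // the runner defers this until the sequence is released
        }
    }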
@@ -60,7 +60,7 @@ type MultimodalProcessor interface {
     // This function is also responsible for updating MultimodalHash for any Multimodal
     // that is modified to ensure that there is a unique hash value that accurately
     // represents the contents.
-    PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
+    PostTokenize([]input.Input) ([]input.Input, error)
 }

 // Base implements the common fields and methods for all models
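Removing the ml.Context parameter also tightens the contract: without a context, a PostTokenize implementation cannot create tensors or graph nodes at prompt-processing time; it can only regroup inputs and update hashes, deferring tensor work to EncodeMultimodal (per-image context) or Forward (per-batch context). A sketch of the shape such an implementation takes, using simplified stand-ins for input.Input and ml.Tensor (assumed; the real types carry more fields):

    // Simplified stand-ins for illustration only.
    type Tensor interface{}

    type Input struct {
        Token          int32
        Multimodal     any // nil for plain text tokens
        MultimodalHash uint64
    }

    // PostTokenize under the new signature: no context, so the method only
    // regroups pending image tensors onto the next text input as a []Tensor,
    // leaving any concatenation or projection to Forward.
    func PostTokenize(inputs []Input) ([]Input, error) {
        var pending []Tensor
        result := make([]Input, 0, len(inputs))
        for _, inp := range inputs {
            if inp.Multimodal != nil {
                pending = append(pending, inp.Multimodal.(Tensor))
                continue
            }
            if len(pending) > 0 {
                inp.Multimodal = pending
                pending = nil
            }
            result = append(result, inp)
        }
        return result, nil
    }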
@@ -111,7 +111,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
     return visionOutputs, nil
 }

-func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
     var result []input.Input

     for _, inp := range inputs {
@@ -106,17 +106,17 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
     return m.Projector.Forward(ctx, crossAttentionStates), nil
 }

-func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
     var images []input.Input
     fnvHash := fnv.New64a()

     for i := range inputs {
         if inputs[i].Multimodal == nil {
             if len(images) > 0 {
-                inputs[i].Multimodal = images[0].Multimodal
+                inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)}
                 inputs[i].MultimodalHash = images[0].MultimodalHash
                 for j := 1; j < len(images); j++ {
-                    inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
+                    inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[j].Multimodal.(ml.Tensor))
                     fnvHash.Reset()
                     binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
                     binary.Write(fnvHash, binary.NativeEndian, images[j].MultimodalHash)
@@ -138,7 +138,10 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
 func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
     var crossAttentionStates ml.Tensor
     if len(opts.Multimodal) > 0 {
-        crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
+        images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
+        if len(images) > 0 {
+            crossAttentionStates = images[len(images)-1]
+        }
     }

     inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
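Taken together, the two hunks above move all graph construction out of prompt processing: PostTokenize now just collects the per-image embeddings into a []ml.Tensor, and Forward, which runs with a regular per-batch context, picks out what it needs. Note the behavioral detail visible in the diff: the states fed to cross attention change from the concatenation of all collected images to the most recent image's embedding. A condensed view of the two halves (lines as in the diff):

    // Prompt processing (no graph context): accumulate embeddings.
    inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[j].Multimodal.(ml.Tensor))

    // Forward (per-batch ctx): consume them; only the last image's
    // embedding feeds cross attention.
    images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
    crossAttentionStates = images[len(images)-1]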
@@ -34,10 +34,14 @@ import (
     _ "github.com/ollama/ollama/model/models"
 )

+type contextList struct {
+    list []ml.Context
+}
+
 type Sequence struct {
-    // ctx for allocating tensors that last the lifetime of the sequence, such as
+    // ctxs are used for allocating tensors that last the lifetime of the sequence, such as
     // multimodal embeddings
-    ctx ml.Context
+    ctxs *contextList

     // batch index
     iBatch int
@@ -99,9 +103,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
     s.ready.Wait()

     startTime := time.Now()
-    ctx := s.model.Backend().NewContext()

-    inputs, err := s.inputs(ctx, prompt, images)
+    inputs, ctxs, err := s.inputs(prompt, images)
     if err != nil {
         return nil, fmt.Errorf("failed to process inputs: %w", err)
     } else if len(inputs) == 0 {
@@ -127,7 +130,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
     // TODO(jessegross): Ingest cached history for grammar

     return &Sequence{
-        ctx:                 ctx,
+        ctxs:                ctxs,
         inputs:              inputs,
         numPromptInputs:     len(inputs),
         startProcessingTime: startTime,
@@ -146,7 +149,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ([]input.Input, error) {
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) {
     var inputs []input.Input
     var parts []string
     var matches [][]string
@@ -161,12 +164,19 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ([]input.Input, error) {
         parts = []string{prompt}
     }

+    var contexts contextList
+    runtime.AddCleanup(&contexts, func(ctxs []ml.Context) {
+        for _, ctx := range ctxs {
+            ctx.Close()
+        }
+    }, contexts.list)
+
     postTokenize := false
     for i, part := range parts {
         // text - tokenize
         tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
         if err != nil {
-            return nil, err
+            return nil, nil, err
         }

         for _, t := range tokens {
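runtime.AddCleanup is new in Go 1.24: it attaches a function that the garbage collector runs sometime after the watched object becomes unreachable, as a replacement for finalizer patterns. Two properties matter here: the third argument is evaluated and captured at the moment AddCleanup is called (a Go slice header is passed by value), and that argument must not keep the watched object reachable, or the cleanup never fires. A minimal standalone demonstration; the names are illustrative, not from the runner:

    package main

    import (
        "fmt"
        "runtime"
        "time"
    )

    type resource struct{ id int }

    func (r *resource) close() { fmt.Println("closed resource", r.id) }

    type owner struct{ held []*resource }

    func main() {
        o := &owner{held: []*resource{{id: 1}, {id: 2}}}

        // The slice value o.held is captured here; the cleanup runs on a
        // runtime-managed goroutine after *o becomes unreachable.
        runtime.AddCleanup(o, func(rs []*resource) {
            for _, r := range rs {
                r.close()
            }
        }, o.held)

        o = nil      // drop the last reference
        runtime.GC() // encourage collection (demo only; timing is up to the GC)
        time.Sleep(100 * time.Millisecond)
    }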
@@ -186,12 +196,14 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ([]input.Input, error) {
             }

             if imageIndex < 0 {
-                return nil, fmt.Errorf("invalid image index: %d", n)
+                return nil, nil, fmt.Errorf("invalid image index: %d", n)
             }

+            ctx := s.model.Backend().NewContext()
+            contexts.list = append(contexts.list, ctx)
             imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
             if err != nil {
-                return nil, err
+                return nil, nil, err
             }

             s.multimodalHash.Reset()
@@ -205,13 +217,13 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ([]input.Input, error) {

     if visionModel && postTokenize {
         var err error
-        inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
+        inputs, err = multimodalProcessor.PostTokenize(inputs)
         if err != nil {
-            return nil, err
+            return nil, nil, err
         }
     }

-    return inputs, nil
+    return inputs, &contexts, nil
 }

 type Server struct {
@@ -306,7 +318,6 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
     close(seq.responses)
     close(seq.embedding)
     seq.cache.InUse = false
-    seq.ctx.Close()
     s.seqs[seqIndex] = nil
     s.seqsSem.Release(1)
 }
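With ownership moved into contextList, removeSequence no longer closes anything by hand: once the Sequence is dropped from s.seqs and becomes unreachable, the cleanup registered in inputs is responsible for closing the per-image contexts. The contexts' lifetime is thus tied to the reachability of the sequence rather than to an explicit close call.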
|
Loading…
x
Reference in New Issue
Block a user