ollamarunner: allocate one backend context per image, release them via GC cleanup

PostTokenize no longer takes an ml.Context. mllama keeps its images as a
[]ml.Tensor instead of concatenating them, and Forward uses the last one.
Server.inputs creates a context per image, collects them in a contextList and
registers a runtime.AddCleanup that closes them once the list is unreachable;
removeSequence therefore no longer closes a per-sequence context.

Review fixes folded into this revision (hunk line counts are unchanged):
  * runner: contextList.list is now *[]ml.Context. The cleanup argument is
    evaluated when AddCleanup is called, so passing the still-nil slice value
    meant the cleanup iterated over nothing and no context was ever closed.
  * mllama: append images[j], not images[0], when gathering several images.
  * mllama: fold images[j].MultimodalHash (not inputs[j]) into the combined
    hash; j indexes images, not inputs.

The "index" blob ids below are those of the original submission.

diff --git a/model/model.go b/model/model.go
index fadea3246..53e47add9 100644
--- a/model/model.go
+++ b/model/model.go
@@ -60,7 +60,7 @@ type MultimodalProcessor interface {
 	// This function is also responsible for updating MultimodalHash for any Multimodal
 	// that is modified to ensure that there is a unique hash value that accurately
 	// represents the contents.
-	PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
+	PostTokenize([]input.Input) ([]input.Input, error)
 }
 
 // Base implements the common fields and methods for all models
diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go
index ccc7567c5..32ad80f43 100644
--- a/model/models/gemma3/model.go
+++ b/model/models/gemma3/model.go
@@ -111,7 +111,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	return visionOutputs, nil
 }
 
-func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	var result []input.Input
 
 	for _, inp := range inputs {
diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go
index 071d77ac7..fa4d570ca 100644
--- a/model/models/mllama/model.go
+++ b/model/models/mllama/model.go
@@ -106,17 +106,17 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	return m.Projector.Forward(ctx, crossAttentionStates), nil
 }
 
-func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	var images []input.Input
 	fnvHash := fnv.New64a()
 
 	for i := range inputs {
 		if inputs[i].Multimodal == nil {
 			if len(images) > 0 {
-				inputs[i].Multimodal = images[0].Multimodal
+				inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)}
 				inputs[i].MultimodalHash = images[0].MultimodalHash
 				for j := 1; j < len(images); j++ {
-					inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
+					inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[j].Multimodal.(ml.Tensor))
 					fnvHash.Reset()
 					binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
-					binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
+					binary.Write(fnvHash, binary.NativeEndian, images[j].MultimodalHash)
@@ -138,7 +138,10 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
 func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 	var crossAttentionStates ml.Tensor
 	if len(opts.Multimodal) > 0 {
-		crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
+		images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
+		if len(images) > 0 {
+			crossAttentionStates = images[len(images)-1]
+		}
 	}
 
 	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 916ad45da..d4c24556c 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -34,10 +34,14 @@ import (
 	_ "github.com/ollama/ollama/model/models"
 )
 
+type contextList struct {
+	list *[]ml.Context // pointer is shared with the GC cleanup so it sees every appended context
+}
+
 type Sequence struct {
-	// ctx for allocating tensors that last the lifetime of the sequence, such as
+	// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
 	// multimodal embeddings
-	ctx ml.Context
+	ctxs *contextList
 
 	// batch index
 	iBatch int
@@ -99,9 +103,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 	s.ready.Wait()
 
 	startTime := time.Now()
-	ctx := s.model.Backend().NewContext()
 
-	inputs, err := s.inputs(ctx, prompt, images)
+	inputs, ctxs, err := s.inputs(prompt, images)
 	if err != nil {
 		return nil, fmt.Errorf("failed to process inputs: %w", err)
 	} else if len(inputs) == 0 {
@@ -127,7 +130,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 	// TODO(jessegross): Ingest cached history for grammar
 
 	return &Sequence{
-		ctx:                 ctx,
+		ctxs:                ctxs,
 		inputs:              inputs,
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
@@ -146,7 +149,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ([]input.Input, error) {
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) {
 	var inputs []input.Input
 	var parts []string
 	var matches [][]string
@@ -161,12 +164,19 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) (
 		parts = []string{prompt}
 	}
 
+	contexts := contextList{list: new([]ml.Context)}
+	runtime.AddCleanup(&contexts, func(ctxs *[]ml.Context) {
+		for _, ctx := range *ctxs {
+			ctx.Close()
+		}
+	}, contexts.list) // arg is evaluated now: pass the slice pointer, a slice value would stay nil
+
 	postTokenize := false
 	for i, part := range parts {
 		// text - tokenize
 		tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
 		if err != nil {
-			return nil, err
+			return nil, nil, err
 		}
 
 		for _, t := range tokens {
@@ -186,12 +196,14 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) (
 			}
 
 			if imageIndex < 0 {
-				return nil, fmt.Errorf("invalid image index: %d", n)
+				return nil, nil, fmt.Errorf("invalid image index: %d", n)
 			}
 
+			ctx := s.model.Backend().NewContext()
+			*contexts.list = append(*contexts.list, ctx)
 			imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
 			if err != nil {
-				return nil, err
+				return nil, nil, err
 			}
 
 			s.multimodalHash.Reset()
@@ -205,13 +217,13 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) (
 
 	if visionModel && postTokenize {
 		var err error
-		inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
+		inputs, err = multimodalProcessor.PostTokenize(inputs)
 		if err != nil {
-			return nil, err
+			return nil, nil, err
 		}
 	}
 
-	return inputs, nil
+	return inputs, &contexts, nil
 }
 
 type Server struct {
@@ -306,7 +318,6 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
 	close(seq.responses)
 	close(seq.embedding)
 	seq.cache.InUse = false
-	seq.ctx.Close()
 	s.seqs[seqIndex] = nil
 	s.seqsSem.Release(1)
 }