diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 98fcf2c04..9a1a549cd 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -115,6 +115,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		params.numKeep = int32(len(inputs))
 	}
 
+	// TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
+	// otherwise we might truncate or split the batch against the model's wishes
+
 	// Ensure that at least 1 input can be discarded during shift
 	params.numKeep = min(params.numKeep, s.cache.numCtx-1)
 
@@ -366,17 +369,6 @@ func (s *Server) processBatch() error {
 		batchSize := s.batchSize
 
 		for j, inp := range seq.inputs {
-			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
-				if len(seq.pendingInputs) == 0 {
-					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
-					if err != nil {
-						return err
-					}
-				} else {
-					break
-				}
-			}
-
 			// If we are required to put following inputs into a single batch then extend the
 			// batch size. Since we are only extending the size the minimum amount possible, this
 			// will cause a break if we have pending inputs.
@@ -389,6 +381,20 @@
 				break
 			}
 
+			// If the sum of our working set (already processed tokens, tokens we added to this
+			// batch, required following tokens) exceeds the context size, then trigger a shift
+			// now so we don't have to do one later when we can't break the batch.
+			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
+				if len(seq.pendingInputs) != 0 {
+					break
+				}
+
+				err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
+				if err != nil {
+					return err
+				}
+			}
+
 			options.Inputs = append(options.Inputs, inp.Token)
 			if inp.Multimodal != nil {
 				options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
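
For readers without the surrounding runner code, here is a minimal standalone sketch of the working-set check the new hunk performs. `seqState`, `shiftCache`, and `fitsOrShift` are hypothetical stand-ins for the runner's real sequence, cache slot, and `ShiftCacheSlot`; only the arithmetic of the context check mirrors the patch.

```go
package main

import "fmt"

// seqState is a hypothetical stand-in for the runner's per-sequence state.
type seqState struct {
	cacheInputs   []int // tokens already committed to the KV cache slot
	pendingInputs []int // tokens added to the current batch, not yet committed
	numKeep       int32 // prefix that must survive a shift
}

// shiftCache is a hypothetical placeholder for cache.ShiftCacheSlot: it
// discards roughly half of the evictable cached tokens past numKeep.
func shiftCache(s *seqState) {
	discard := (len(s.cacheInputs) - int(s.numKeep)) / 2
	s.cacheInputs = append(s.cacheInputs[:s.numKeep], s.cacheInputs[int(s.numKeep)+discard:]...)
}

// fitsOrShift applies the patch's rule for one prospective input group of
// size minBatch (1 + SameBatch in the real code). It reports whether the
// caller should break and flush the batch instead of adding more inputs.
func fitsOrShift(s *seqState, minBatch int, numCtx int32) bool {
	if int32(len(s.cacheInputs)+len(s.pendingInputs)+minBatch) > numCtx {
		if len(s.pendingInputs) != 0 {
			// Tokens are already queued in this batch: submit them first
			// and retry with an empty pending set on the next pass.
			return true
		}
		shiftCache(s) // safe: nothing pending, so the cache slot is consistent
	}
	return false
}

func main() {
	s := &seqState{cacheInputs: make([]int, 28), numKeep: 4}
	const numCtx = 32

	// A 6-token group (minBatch = 6) no longer fits in a 32-token context
	// with 28 tokens cached, so the shift happens now, before the batch
	// becomes unbreakable.
	fmt.Println("break?", fitsOrShift(s, 6, numCtx), "cached:", len(s.cacheInputs))
}
```

The design point this mirrors: the shift only runs while `pendingInputs` is empty, so a partially built batch is flushed first and a group that must stay in a single batch is never interrupted by a mid-batch shift.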