diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 31d20db80..6d20fa85b 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -267,6 +267,9 @@ type Server struct { // KV cache cache *InputCache + // next sequence for prompt processing to avoid starvation + nextSeq int + // multimodalHash generates hashes for comparing equality // of non-text data multimodalHash maphash.Hash @@ -351,14 +354,19 @@ func (s *Server) processBatch() error { var batchInputs []int32 var batch input.Batch - for i, seq := range s.seqs { + resumeSeq := -1 + seqIdx := s.nextSeq - 1 + for range s.seqs { + seqIdx = (seqIdx + 1) % len(s.seqs) + seq := s.seqs[seqIdx] + if seq == nil { continue } // if past the num predict limit if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { - s.removeSequence(i, "limit") + s.removeSequence(seqIdx, "limit") continue } @@ -369,16 +377,23 @@ func (s *Server) processBatch() error { batchSize := s.batchSize - for j, inp := range seq.inputs { + for i, inp := range seq.inputs { // If we are required to put following inputs into a single batch then extend the // batch size. Since we are only extending the size the minimum amount possible, this - // will cause a break if we have pending inputs. + // will cause a break if we have existing inputs. minBatch := 1 + inp.SameBatch if minBatch > batchSize { batchSize = minBatch } - if len(seq.pendingInputs)+minBatch > batchSize { + // Stop if the required batch would put us over the total batch size (including tokens + // added by other sequences). If we haven't been able to add anything yet then pick up + // here again for the next batch to avoid starvation, though we can opportunistically + // check if other sequences can still squeeze something in. + if len(batchInputs)+minBatch > batchSize { + if len(seq.pendingInputs) == 0 && resumeSeq == -1 { + resumeSeq = seqIdx + } break } @@ -405,7 +420,7 @@ func (s *Server) processBatch() error { batch.Sequences = append(batch.Sequences, seq.cache.Id) seq.iBatch = len(batch.Outputs) - if j+1 == len(seq.inputs) { + if i+1 == len(seq.inputs) { batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1)) } seq.pendingInputs = append(seq.pendingInputs, inp) @@ -414,6 +429,12 @@ func (s *Server) processBatch() error { seq.inputs = seq.inputs[len(seq.pendingInputs):] } + if resumeSeq != -1 { + s.nextSeq = resumeSeq + } else { + s.nextSeq = seqIdx + 1 + } + if len(batchInputs) == 0 { return nil }