From 5d097277ef8b08c86f354b54596976869998257d Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Thu, 27 Mar 2025 14:00:05 -0700
Subject: [PATCH] ollamarunner: Ensure batch size limits are not exceeded

With the llama runner, we can generate up to NUM_PARALLEL batches at
once, which will then get broken up into individual batches to get
executed by llama.cpp (i.e. we add up to 2048 tokens and this gets
split into 4 batches of 512 tokens at default settings).

This splitting can improve parallelism on multi-GPU systems because
the individual batches can move through the pipeline without blocking
on the first one to fully complete. However, we don't yet support
this in the Ollama runner, partially because it makes it hard to
enforce model-specified batch constraints, which didn't exist
previously.

The result is that we will try to execute the full, unsplit batch.
This can result in out-of-memory or insufficient KV cache space
errors.

This change triggers a batch break when the total number of inputs
from all sequences exceeds the batch size, rather than breaking per
sequence. To ensure fairness, it also reintroduces round-robinning
across sequences so that one busy sequence doesn't starve the others.
---
 runner/ollamarunner/runner.go | 33 +++++++++++++++++++++++++++------
 1 file changed, 27 insertions(+), 6 deletions(-)

diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 31d20db80..6d20fa85b 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -267,6 +267,9 @@ type Server struct {
 	// KV cache
 	cache *InputCache
 
+	// next sequence for prompt processing to avoid starvation
+	nextSeq int
+
 	// multimodalHash generates hashes for comparing equality
 	// of non-text data
 	multimodalHash maphash.Hash
@@ -351,14 +354,19 @@ func (s *Server) processBatch() error {
 	var batchInputs []int32
 	var batch input.Batch
 
-	for i, seq := range s.seqs {
+	resumeSeq := -1
+	seqIdx := s.nextSeq - 1
+	for range s.seqs {
+		seqIdx = (seqIdx + 1) % len(s.seqs)
+		seq := s.seqs[seqIdx]
+
 		if seq == nil {
 			continue
 		}
 
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-			s.removeSequence(i, "limit")
+			s.removeSequence(seqIdx, "limit")
 			continue
 		}
 
@@ -369,16 +377,23 @@ func (s *Server) processBatch() error {
 
 		batchSize := s.batchSize
 
-		for j, inp := range seq.inputs {
+		for i, inp := range seq.inputs {
 			// If we are required to put following inputs into a single batch then extend the
 			// batch size. Since we are only extending the size the minimum amount possible, this
-			// will cause a break if we have pending inputs.
+			// will cause a break if we have existing inputs.
 			minBatch := 1 + inp.SameBatch
 			if minBatch > batchSize {
 				batchSize = minBatch
 			}
 
-			if len(seq.pendingInputs)+minBatch > batchSize {
+			// Stop if the required batch would put us over the total batch size (including tokens
+			// added by other sequences). If we haven't been able to add anything yet then pick up
+			// here again for the next batch to avoid starvation, though we can opportunistically
+			// check if other sequences can still squeeze something in.
+			if len(batchInputs)+minBatch > batchSize {
+				if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
+					resumeSeq = seqIdx
+				}
 				break
 			}
 
@@ -405,7 +420,7 @@ func (s *Server) processBatch() error {
 			batch.Sequences = append(batch.Sequences, seq.cache.Id)
 
 			seq.iBatch = len(batch.Outputs)
-			if j+1 == len(seq.inputs) {
+			if i+1 == len(seq.inputs) {
 				batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
 			}
 			seq.pendingInputs = append(seq.pendingInputs, inp)
 		}
 
 		seq.inputs = seq.inputs[len(seq.pendingInputs):]
 	}
 
+	if resumeSeq != -1 {
+		s.nextSeq = resumeSeq
+	} else {
+		s.nextSeq = seqIdx + 1
+	}
+
 	if len(batchInputs) == 0 {
 		return nil
 	}
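
Note (not part of the patch): below is a minimal, self-contained Go sketch of the scheduling policy described in the commit message: a single token budget shared across all sequences, round-robin filling that starts at nextSeq, and a resume point recorded for the first sequence that could not place anything into the batch. The seq and scheduler types are simplified stand-ins invented for illustration; they are not the runner's Server, input.Batch, or cache types.

package main

import "fmt"

// seq is a simplified stand-in for a decoding sequence: just the tokens that
// still need to be scheduled and how many were placed into the current batch.
type seq struct {
	inputs  []int
	pending int
}

type scheduler struct {
	seqs      []*seq
	batchSize int // total token budget shared by all sequences
	nextSeq   int // where the next batch starts filling, to avoid starvation
}

// fillBatch mirrors the loop structure of the patched processBatch: one pass
// over all sequences in round-robin order, breaking on the total number of
// inputs in the batch rather than per sequence.
func (s *scheduler) fillBatch() []int {
	var batch []int
	resumeSeq := -1

	seqIdx := s.nextSeq - 1
	for range s.seqs {
		seqIdx = (seqIdx + 1) % len(s.seqs)
		sq := s.seqs[seqIdx]
		if sq == nil {
			continue
		}

		sq.pending = 0
		for _, tok := range sq.inputs {
			// Stop when adding one more token would exceed the shared budget,
			// counting tokens already added by other sequences in this pass.
			if len(batch)+1 > s.batchSize {
				// If this sequence got nothing, resume here next time.
				if sq.pending == 0 && resumeSeq == -1 {
					resumeSeq = seqIdx
				}
				break
			}
			batch = append(batch, tok)
			sq.pending++
		}
		sq.inputs = sq.inputs[sq.pending:]
	}

	if resumeSeq != -1 {
		s.nextSeq = resumeSeq
	} else {
		s.nextSeq = seqIdx + 1
	}
	return batch
}

func main() {
	s := &scheduler{
		seqs: []*seq{
			{inputs: []int{1, 2, 3, 4, 5, 6}},
			{inputs: []int{10, 20, 30}},
		},
		batchSize: 4,
	}
	for i := 0; i < 3; i++ {
		fmt.Println("batch:", s.fillBatch(), "next:", s.nextSeq)
	}
}

With a budget of 4 tokens, the first batch is filled entirely by the busy first sequence, and the second sequence, which could not place anything, becomes the resume point for the next pass. That is the fairness property the round-robin resume point in the patch is meant to provide.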