ollamarunner: Ensure batch size limits are not exceeded

With the llama runner, we can generate up to NUM_PARALLEL batches
at once, which then get broken up into individual batches that
llama.cpp executes (i.e. we add up to 2048 tokens and this gets
split into 4 batches of 512 tokens at default settings).
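
As a rough illustration of that arithmetic (a hypothetical sketch, not the runner's
actual code; the function name and constants below are made up for this example):

package main

import "fmt"

// splitIntoSubBatches cuts a queue of tokens into fixed-size sub-batches.
// With nBatch = 512 and NUM_PARALLEL = 4, up to 2048 queued tokens become
// at most 4 sub-batches that are executed one after another.
func splitIntoSubBatches(tokens []int32, nBatch int) [][]int32 {
	var out [][]int32
	for len(tokens) > 0 {
		n := min(nBatch, len(tokens))
		out = append(out, tokens[:n])
		tokens = tokens[n:]
	}
	return out
}

func main() {
	queued := make([]int32, 2048) // 4 parallel sequences contributing 512 tokens each
	fmt.Println(len(splitIntoSubBatches(queued, 512))) // prints 4
}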

This splitting can improve parallelism on multi-GPU systems because
the individual batches can move through the pipeline without blocking
on the first one to fully complete. However, we don't yet support
this in the Ollama runner, partially because it makes it hard to
enforce model-specified batch constraints, which didn't exist
previously.

The result is that we will try to execute the full, unsplit batch.
This could result in out of memory or insufficient KV cache space
errors.

This change triggers batch breaking when the total number of inputs
from all sequences exceeds the batch size, rather than checking
per sequence. To ensure fairness, it also reintroduces round-robin
scheduling across sequences so that one busy sequence does not
starve the others.
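
The diff below implements this against the runner's real data structures; the
following is only a minimal standalone sketch of the same idea (the function,
names, and types here are illustrative, not taken from the runner):

package main

import "fmt"

// fillBatch walks the sequences round-robin starting at nextSeq and appends
// each sequence's queued tokens to one shared batch. It stops as soon as the
// shared batch is full and returns the index of the sequence it stopped at,
// so the next call can resume there instead of starving that sequence.
func fillBatch(seqs [][]int32, nextSeq, batchSize int) (batch []int32, resume int) {
	seqIdx := nextSeq - 1
	for range seqs {
		seqIdx = (seqIdx + 1) % len(seqs)
		for _, tok := range seqs[seqIdx] {
			if len(batch) >= batchSize { // limit is shared across all sequences
				return batch, seqIdx
			}
			batch = append(batch, tok)
		}
	}
	return batch, (seqIdx + 1) % len(seqs)
}

func main() {
	seqs := [][]int32{make([]int32, 300), make([]int32, 300), make([]int32, 300)}
	batch, resume := fillBatch(seqs, 0, 512)
	fmt.Println(len(batch), resume) // 512 1: sequence 1 was cut short, so resume there
}

The actual change is slightly smarter than this sketch: it only records a resume
point when a sequence could not add anything at all, and it keeps checking whether
other sequences can still squeeze tokens into the remaining space.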
Author: Jesse Gross, 2025-03-27 14:00:05 -07:00 (committed by Jesse Gross)
Parent: 071a9872cb
Commit: 5d097277ef

@@ -267,6 +267,9 @@ type Server struct {
 	// KV cache
 	cache *InputCache
 
+	// next sequence for prompt processing to avoid starvation
+	nextSeq int
+
 	// multimodalHash generates hashes for comparing equality
 	// of non-text data
 	multimodalHash maphash.Hash
@@ -351,14 +354,19 @@ func (s *Server) processBatch() error {
 	var batchInputs []int32
 	var batch input.Batch
 
-	for i, seq := range s.seqs {
+	resumeSeq := -1
+	seqIdx := s.nextSeq - 1
+	for range s.seqs {
+		seqIdx = (seqIdx + 1) % len(s.seqs)
+		seq := s.seqs[seqIdx]
+
 		if seq == nil {
 			continue
 		}
 
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-			s.removeSequence(i, "limit")
+			s.removeSequence(seqIdx, "limit")
 			continue
 		}
 
@@ -369,16 +377,23 @@ func (s *Server) processBatch() error {
 
 		batchSize := s.batchSize
 
-		for j, inp := range seq.inputs {
+		for i, inp := range seq.inputs {
 			// If we are required to put following inputs into a single batch then extend the
 			// batch size. Since we are only extending the size the minimum amount possible, this
-			// will cause a break if we have pending inputs.
+			// will cause a break if we have existing inputs.
 			minBatch := 1 + inp.SameBatch
 			if minBatch > batchSize {
 				batchSize = minBatch
 			}
 
-			if len(seq.pendingInputs)+minBatch > batchSize {
+			// Stop if the required batch would put us over the total batch size (including tokens
+			// added by other sequences). If we haven't been able to add anything yet then pick up
+			// here again for the next batch to avoid starvation, though we can opportunistically
+			// check if other sequences can still squeeze something in.
+			if len(batchInputs)+minBatch > batchSize {
+				if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
+					resumeSeq = seqIdx
+				}
 				break
 			}
 
@@ -405,7 +420,7 @@ func (s *Server) processBatch() error {
 			batch.Sequences = append(batch.Sequences, seq.cache.Id)
 
 			seq.iBatch = len(batch.Outputs)
-			if j+1 == len(seq.inputs) {
+			if i+1 == len(seq.inputs) {
 				batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
 			}
 			seq.pendingInputs = append(seq.pendingInputs, inp)
@@ -414,6 +429,12 @@ func (s *Server) processBatch() error {
 		seq.inputs = seq.inputs[len(seq.pendingInputs):]
 	}
 
+	if resumeSeq != -1 {
+		s.nextSeq = resumeSeq
+	} else {
+		s.nextSeq = seqIdx + 1
+	}
+
 	if len(batchInputs) == 0 {
 		return nil
 	}