ollamarunner: Ensure batch size limits are not exceeded
With the llama runner, we can generate up to NUM_PARALLEL batches at once, which will then get broken up to into individual batches to get executed by llama.cpp (i.e. we add up to 2048 tokens and this gets split into 4 batches of 512 tokens at default settings). This splitting can improve parallelism on multi-GPU systems because the individual batches can move though the pipeline without blocking on the first one to fully complete. However, we don't yet support this in the Ollama runner, partially because it makes it hard to enforce model-specified batch constraints, which didn't exist previously. The result is that we will try to execute the full, unsplit batch. This could result in out of memory or insufficient KV cache space errors. This triggers batch breaking when the total inputs from all sequences exceeds the batch size, rather than per-sequence. In order to ensure fairness, it also reintroduces round-robinning around sequences so that we don't let one busy sequence starve the others.
This commit is contained in:
parent
071a9872cb
commit
5d097277ef
@ -267,6 +267,9 @@ type Server struct {
|
||||
// KV cache
|
||||
cache *InputCache
|
||||
|
||||
// next sequence for prompt processing to avoid starvation
|
||||
nextSeq int
|
||||
|
||||
// multimodalHash generates hashes for comparing equality
|
||||
// of non-text data
|
||||
multimodalHash maphash.Hash
|
||||
@ -351,14 +354,19 @@ func (s *Server) processBatch() error {
|
||||
var batchInputs []int32
|
||||
var batch input.Batch
|
||||
|
||||
for i, seq := range s.seqs {
|
||||
resumeSeq := -1
|
||||
seqIdx := s.nextSeq - 1
|
||||
for range s.seqs {
|
||||
seqIdx = (seqIdx + 1) % len(s.seqs)
|
||||
seq := s.seqs[seqIdx]
|
||||
|
||||
if seq == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(i, "limit")
|
||||
s.removeSequence(seqIdx, "limit")
|
||||
continue
|
||||
}
|
||||
|
||||
@ -369,16 +377,23 @@ func (s *Server) processBatch() error {
|
||||
|
||||
batchSize := s.batchSize
|
||||
|
||||
for j, inp := range seq.inputs {
|
||||
for i, inp := range seq.inputs {
|
||||
// If we are required to put following inputs into a single batch then extend the
|
||||
// batch size. Since we are only extending the size the minimum amount possible, this
|
||||
// will cause a break if we have pending inputs.
|
||||
// will cause a break if we have existing inputs.
|
||||
minBatch := 1 + inp.SameBatch
|
||||
if minBatch > batchSize {
|
||||
batchSize = minBatch
|
||||
}
|
||||
|
||||
if len(seq.pendingInputs)+minBatch > batchSize {
|
||||
// Stop if the required batch would put us over the total batch size (including tokens
|
||||
// added by other sequences). If we haven't been able to add anything yet then pick up
|
||||
// here again for the next batch to avoid starvation, though we can opportunistically
|
||||
// check if other sequences can still squeeze something in.
|
||||
if len(batchInputs)+minBatch > batchSize {
|
||||
if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
|
||||
resumeSeq = seqIdx
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
@ -405,7 +420,7 @@ func (s *Server) processBatch() error {
|
||||
batch.Sequences = append(batch.Sequences, seq.cache.Id)
|
||||
|
||||
seq.iBatch = len(batch.Outputs)
|
||||
if j+1 == len(seq.inputs) {
|
||||
if i+1 == len(seq.inputs) {
|
||||
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
||||
}
|
||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
||||
@ -414,6 +429,12 @@ func (s *Server) processBatch() error {
|
||||
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
||||
}
|
||||
|
||||
if resumeSeq != -1 {
|
||||
s.nextSeq = resumeSeq
|
||||
} else {
|
||||
s.nextSeq = seqIdx + 1
|
||||
}
|
||||
|
||||
if len(batchInputs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user