From 66b253923891d41a31d28531e9db5efccf53e1d0 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Mon, 31 Mar 2025 12:54:45 -0700 Subject: [PATCH] runner: clear cache when shift is not possible (#9433) Clear KV cache when shift operation is not supported by model. Added KvCacheCanShift() check to handle models that can't perform cache shifts, falling back to full cache clear while preserving logical token history to maintain expected behavior when context window fills up. --- llama/llama.go | 4 ++ runner/llamarunner/cache.go | 55 +++++++++++++++---- runner/llamarunner/runner.go | 10 +++- runner/ollamarunner/cache.go | 24 +++++++- runner/ollamarunner/cache_test.go | 91 +++++++++++++++++++++++++++++++ runner/ollamarunner/runner.go | 10 +++- 6 files changed, 180 insertions(+), 14 deletions(-) diff --git a/llama/llama.go b/llama/llama.go index a026bee24..e8cdafe7f 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -166,6 +166,10 @@ func (c *Context) KvCacheDefrag() { C.llama_kv_cache_defrag(c.c) } +func (c *Context) KvCacheCanShift() bool { + return bool(C.llama_kv_cache_can_shift(c.c)) +} + // Get the embeddings for a sequence id func (c *Context) GetEmbeddingsSeq(seqId int) []float32 { e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId))) diff --git a/runner/llamarunner/cache.go b/runner/llamarunner/cache.go index d29e94b6b..2e55b09dc 100644 --- a/runner/llamarunner/cache.go +++ b/runner/llamarunner/cache.go @@ -213,8 +213,16 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int { return discard } -// Frees up space in the KV cache by deleting the oldest half of history and shifting -// the newest half into that space (saving numKeep inputs at the beginning). +type ErrReprocessInputs struct { + Inputs []input +} + +func (e *ErrReprocessInputs) Error() string { + return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs)) +} + +// ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history +// and shifting the newest half into that space (saving numKeep inputs at the beginning). // // Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx) func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error { @@ -222,7 +230,8 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error { return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx) } - discard := c.ShiftDiscard(len(slot.Inputs), numKeep) + inputLen := len(slot.Inputs) + discard := c.ShiftDiscard(inputLen, numKeep) if discard <= 0 { return nil @@ -231,16 +240,42 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error { slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs), "keep", numKeep, "discard", discard) - // TODO (jessegross): KV cache removal can fail for certain types of models - if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) { - return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard) - } - c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard) + var shiftFailed bool - for i := numKeep + discard; i < len(slot.Inputs); i++ { + if c.lc.KvCacheCanShift() { + // For models that support shifting, attempt to shift the KV cache + if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) { + shiftFailed = true + slog.Debug("kv cache removal not supported, clearing cache and returning inputs for reprocessing", "id", slot.Id) + } else { + c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, inputLen, -discard) + } + } else { + // For models that don't support shifting + shiftFailed = true + slog.Debug("kv cache cannot shift, clearing cache and returning inputs for reprocessing", "id", slot.Id) + } + + if shiftFailed { + // Create new input slice with preserved tokens (numKeep + remaining tokens after discard) + newInputs := make([]input, numKeep+inputLen-(numKeep+discard)) + copy(newInputs[:numKeep], slot.Inputs[:numKeep]) + copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:]) + + // Clear the entire KV cache + _ = c.lc.KvCacheSeqRm(slot.Id, 0, -1) + // Reset the slot inputs since we've cleared the cache + slot.Inputs = []input{} + + // Return error with inputs that need to be reprocessed + return &ErrReprocessInputs{Inputs: newInputs} + } + + // Standard shift succeeded - update input array + for i := numKeep + discard; i < inputLen; i++ { slot.Inputs[i-discard] = slot.Inputs[i] } - slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard] + slot.Inputs = slot.Inputs[:inputLen-discard] return nil } diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index ee5d47f6e..a4264f5fc 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -389,7 +389,15 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) if len(seq.pendingInputs) == 0 { err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) if err != nil { - return err + var reprocess *ErrReprocessInputs + if errors.As(err, &reprocess) { + // Prepend these inputs to the sequence's inputs queue for reprocessing + seq.inputs = append(reprocess.Inputs, seq.inputs...) + // Continue processing as normal + continue + } else { + return err + } } } else { break diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index aa56c9822..af48ff22e 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -239,6 +239,14 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 { return discard } +type ErrReprocessInputs struct { + Inputs []input.Input +} + +func (e *ErrReprocessInputs) Error() string { + return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs)) +} + // Frees up space in the KV cache by deleting the oldest half of history and shifting // the newest half into that space (saving numKeep inputs at the beginning). // @@ -258,11 +266,23 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error { slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs), "keep", numKeep, "discard", discard) - // TODO (jessegross): KV cache removal can fail for certain types of models if c.cache != nil { err := c.cache.Remove(slot.Id, numKeep, numKeep+discard) if err != nil { - return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err) + slog.Debug("kv cache removal unsupported, clearing cache and returning inputs for reprocessing", + "id", slot.Id, "error", err) + + // Create new input slice with preserved tokens (numKeep + remaining tokens after discard) + newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard)) + copy(newInputs[:numKeep], slot.Inputs[:numKeep]) + copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:]) + + // Reset the cache + _ = c.cache.Remove(slot.Id, 0, -1) + slot.Inputs = []input.Input{} + + // Return error with inputs that need to be reprocessed + return &ErrReprocessInputs{Inputs: newInputs} } } diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index f8925d119..6a8d8a6a9 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -1,10 +1,13 @@ package ollamarunner import ( + "errors" + "fmt" "image" "testing" "time" + "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model/input" ) @@ -425,3 +428,91 @@ func TestLoadCacheSlot(t *testing.T) { }) } } + +// Mock implementation of the Cache interface +type mockCache struct { + shouldFail bool +} + +// Implement only the methods needed for the test +func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error { + if m.shouldFail { + return fmt.Errorf("mock cache removal error") + } + return nil +} + +// Stub implementations for other interface methods +func (m *mockCache) SetLayer(layer int) {} +func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil } +func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {} +func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {} +func (m *mockCache) Close() {} +func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil } +func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {} +func (m *mockCache) SetConfig(ml.CacheConfig) {} + +func TestShiftCacheSlot(t *testing.T) { + tests := []struct { + name string + numCtx int32 + inputs []input.Input + numKeep int32 + cacheErr bool + wantErr any + wantInputsLen int + }{ + { + name: "Normal shift", + numCtx: 10, + inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, + numKeep: 2, + cacheErr: false, // No error + wantErr: nil, + wantInputsLen: 6, // After discarding 4 tokens + }, + { + name: "Cache removal fails", + numCtx: 10, + inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, + numKeep: 2, + cacheErr: true, + wantErr: &ErrReprocessInputs{}, + wantInputsLen: 0, // Original inputs should be cleared + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock := &mockCache{shouldFail: tt.cacheErr} + c := InputCache{ + numCtx: tt.numCtx, + cache: mock, + } + slot := &InputCacheSlot{ + Id: 123, + Inputs: make([]input.Input, len(tt.inputs)), + } + copy(slot.Inputs, tt.inputs) + + err := c.ShiftCacheSlot(slot, tt.numKeep) + + if tt.wantErr != nil { + if err == nil { + t.Errorf("Expected error but got nil") + return + } + + if !errors.As(err, &tt.wantErr) { + t.Errorf("Expected error of type %T but got %T: %v", tt.wantErr, err, err) + } + } else if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if len(slot.Inputs) != tt.wantInputsLen { + t.Errorf("Slot inputs length after operation: got %v, want %v", len(slot.Inputs), tt.wantInputsLen) + } + }) + } +} diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index bc7a07ed6..458387184 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -407,7 +407,15 @@ func (s *Server) processBatch() error { err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) if err != nil { - return err + var reprocess *ErrReprocessInputs + if errors.As(err, &reprocess) { + // Prepend these inputs to the sequence's inputs queue for reprocessing + seq.inputs = append(reprocess.Inputs, seq.inputs...) + // Skip this sequence but continue processing the rest + continue + } else { + return err + } } }