diff --git a/llama/llama.go b/llama/llama.go index a026bee24..e8cdafe7f 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -166,6 +166,10 @@ func (c *Context) KvCacheDefrag() { C.llama_kv_cache_defrag(c.c) } +func (c *Context) KvCacheCanShift() bool { + return bool(C.llama_kv_cache_can_shift(c.c)) +} + // Get the embeddings for a sequence id func (c *Context) GetEmbeddingsSeq(seqId int) []float32 { e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId))) diff --git a/runner/llamarunner/cache.go b/runner/llamarunner/cache.go index d29e94b6b..2e55b09dc 100644 --- a/runner/llamarunner/cache.go +++ b/runner/llamarunner/cache.go @@ -213,8 +213,16 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int { return discard } -// Frees up space in the KV cache by deleting the oldest half of history and shifting -// the newest half into that space (saving numKeep inputs at the beginning). +type ErrReprocessInputs struct { + Inputs []input +} + +func (e *ErrReprocessInputs) Error() string { + return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs)) +} + +// ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history +// and shifting the newest half into that space (saving numKeep inputs at the beginning). // // Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx) func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error { @@ -222,7 +230,8 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error { return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx) } - discard := c.ShiftDiscard(len(slot.Inputs), numKeep) + inputLen := len(slot.Inputs) + discard := c.ShiftDiscard(inputLen, numKeep) if discard <= 0 { return nil @@ -231,16 +240,42 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error { slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs), "keep", numKeep, "discard", discard) - // TODO (jessegross): KV cache removal can fail for certain types of models - if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) { - return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard) - } - c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard) + var shiftFailed bool - for i := numKeep + discard; i < len(slot.Inputs); i++ { + if c.lc.KvCacheCanShift() { + // For models that support shifting, attempt to shift the KV cache + if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) { + shiftFailed = true + slog.Debug("kv cache removal not supported, clearing cache and returning inputs for reprocessing", "id", slot.Id) + } else { + c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, inputLen, -discard) + } + } else { + // For models that don't support shifting + shiftFailed = true + slog.Debug("kv cache cannot shift, clearing cache and returning inputs for reprocessing", "id", slot.Id) + } + + if shiftFailed { + // Create new input slice with preserved tokens (numKeep + remaining tokens after discard) + newInputs := make([]input, numKeep+inputLen-(numKeep+discard)) + copy(newInputs[:numKeep], slot.Inputs[:numKeep]) + copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:]) + + // Clear the entire KV cache + _ = c.lc.KvCacheSeqRm(slot.Id, 0, -1) + // Reset the slot inputs since we've cleared the cache + slot.Inputs = []input{} + + // Return error with inputs that need to be reprocessed + return &ErrReprocessInputs{Inputs: newInputs} + } + + // Standard shift succeeded - update input array + for i := numKeep + discard; i < inputLen; i++ { slot.Inputs[i-discard] = slot.Inputs[i] } - slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard] + slot.Inputs = slot.Inputs[:inputLen-discard] return nil } diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index ee5d47f6e..a4264f5fc 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -389,7 +389,15 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) if len(seq.pendingInputs) == 0 { err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) if err != nil { - return err + var reprocess *ErrReprocessInputs + if errors.As(err, &reprocess) { + // Prepend these inputs to the sequence's inputs queue for reprocessing + seq.inputs = append(reprocess.Inputs, seq.inputs...) + // Continue processing as normal + continue + } else { + return err + } } } else { break diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index aa56c9822..af48ff22e 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -239,6 +239,14 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 { return discard } +type ErrReprocessInputs struct { + Inputs []input.Input +} + +func (e *ErrReprocessInputs) Error() string { + return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs)) +} + // Frees up space in the KV cache by deleting the oldest half of history and shifting // the newest half into that space (saving numKeep inputs at the beginning). // @@ -258,11 +266,23 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error { slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs), "keep", numKeep, "discard", discard) - // TODO (jessegross): KV cache removal can fail for certain types of models if c.cache != nil { err := c.cache.Remove(slot.Id, numKeep, numKeep+discard) if err != nil { - return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err) + slog.Debug("kv cache removal unsupported, clearing cache and returning inputs for reprocessing", + "id", slot.Id, "error", err) + + // Create new input slice with preserved tokens (numKeep + remaining tokens after discard) + newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard)) + copy(newInputs[:numKeep], slot.Inputs[:numKeep]) + copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:]) + + // Reset the cache + _ = c.cache.Remove(slot.Id, 0, -1) + slot.Inputs = []input.Input{} + + // Return error with inputs that need to be reprocessed + return &ErrReprocessInputs{Inputs: newInputs} } } diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index f8925d119..6a8d8a6a9 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -1,10 +1,13 @@ package ollamarunner import ( + "errors" + "fmt" "image" "testing" "time" + "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model/input" ) @@ -425,3 +428,91 @@ func TestLoadCacheSlot(t *testing.T) { }) } } + +// Mock implementation of the Cache interface +type mockCache struct { + shouldFail bool +} + +// Implement only the methods needed for the test +func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error { + if m.shouldFail { + return fmt.Errorf("mock cache removal error") + } + return nil +} + +// Stub implementations for other interface methods +func (m *mockCache) SetLayer(layer int) {} +func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil } +func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {} +func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {} +func (m *mockCache) Close() {} +func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil } +func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {} +func (m *mockCache) SetConfig(ml.CacheConfig) {} + +func TestShiftCacheSlot(t *testing.T) { + tests := []struct { + name string + numCtx int32 + inputs []input.Input + numKeep int32 + cacheErr bool + wantErr any + wantInputsLen int + }{ + { + name: "Normal shift", + numCtx: 10, + inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, + numKeep: 2, + cacheErr: false, // No error + wantErr: nil, + wantInputsLen: 6, // After discarding 4 tokens + }, + { + name: "Cache removal fails", + numCtx: 10, + inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, + numKeep: 2, + cacheErr: true, + wantErr: &ErrReprocessInputs{}, + wantInputsLen: 0, // Original inputs should be cleared + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock := &mockCache{shouldFail: tt.cacheErr} + c := InputCache{ + numCtx: tt.numCtx, + cache: mock, + } + slot := &InputCacheSlot{ + Id: 123, + Inputs: make([]input.Input, len(tt.inputs)), + } + copy(slot.Inputs, tt.inputs) + + err := c.ShiftCacheSlot(slot, tt.numKeep) + + if tt.wantErr != nil { + if err == nil { + t.Errorf("Expected error but got nil") + return + } + + if !errors.As(err, &tt.wantErr) { + t.Errorf("Expected error of type %T but got %T: %v", tt.wantErr, err, err) + } + } else if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if len(slot.Inputs) != tt.wantInputsLen { + t.Errorf("Slot inputs length after operation: got %v, want %v", len(slot.Inputs), tt.wantInputsLen) + } + }) + } +} diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index bc7a07ed6..458387184 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -407,7 +407,15 @@ func (s *Server) processBatch() error { err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) if err != nil { - return err + var reprocess *ErrReprocessInputs + if errors.As(err, &reprocess) { + // Prepend these inputs to the sequence's inputs queue for reprocessing + seq.inputs = append(reprocess.Inputs, seq.inputs...) + // Skip this sequence but continue processing the rest + continue + } else { + return err + } } }