diff --git a/kvcache/cache.go b/kvcache/cache.go index 18aec8003..07015b9e0 100644 --- a/kvcache/cache.go +++ b/kvcache/cache.go @@ -62,6 +62,11 @@ type Cache interface { // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq CopyPrefix(srcSeq, dstSeq int, len int32) + // CanResume returns true if the cache can continue with the next token at + // the given position and sequence. Assumes that the caller has already + // verified the contents of the cache. + CanResume(seq int, pos int32) bool + // Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set // endIndex to math.MaxInt32 to remove everything starting at beginIndex. // diff --git a/kvcache/causal.go b/kvcache/causal.go index fb4f0f743..4fc18d88f 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -581,6 +581,35 @@ func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) { c.cellRanges[dstSeq] = seqRange } +func (c *Causal) CanResume(seq int, pos int32) bool { + if c.windowSize == math.MaxInt32 { + return true + } + + seqRange, ok := c.cellRanges[seq] + if !ok { + return false + } + + // for sliding window, check that the window of the new sequence is contained in + // the window of what we are storing + var last int32 = -1 + for i := seqRange.min; i <= seqRange.max; i++ { + if slices.Contains(c.cells[i].sequences, seq) { + last = max(last, c.cells[i].pos) + } + } + + if last == -1 { + return false + } + + lastWindowStart := max(0, last-c.windowSize) + posWindowStart := max(0, pos-c.windowSize) + + return posWindowStart >= lastWindowStart +} + func (c *Causal) shift(seq int, beginIndex, offset int32) error { if c.shiftFn == nil { return ErrNotSupported @@ -635,6 +664,12 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { } func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error { + // TODO(jessegross): We should check to see if removing the middle of the sequence will + // cause the sliding window to encompass tokens that we no longer have. If so, then we + // should return an error, which will trigger the runner to evaluate the full history and + // rebuild the window. However, if we have multimodal inputs in our history, this reuse + // results in use after free, so we don't do it for now. + var offset int32 if endIndex != math.MaxInt32 { offset = beginIndex - endIndex @@ -649,8 +684,7 @@ func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error { } else { if c.cells[i].pos >= endIndex { if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) { - // TODO(jessegross): Need to be careful about data shared between sequences - return errors.New("shifting on cells shared by multiple sequences not yet implemented") + return errors.New("shifting cells shared by multiple sequences not supported") } c.cells[i].pos += offset diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index b1dc7d779..bf98abef6 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -300,6 +300,77 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) } } +func TestCanResume(t *testing.T) { + backend := &testBackend{} + windowSize := int32(4) + cache := NewSWACache(windowSize, nil) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 1, 16, 16) + + context := backend.NewContext() + defer context.Close() + + err := cache.StartForward(context, input.Batch{ + Positions: []int32{0, 1, 2, 3}, + Sequences: []int{0, 0, 0, 0}, + }) + if err != nil { + t.Fatalf("StartForward failed: %v", err) + } + + cache.SetLayer(0) + tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4) + cache.Put(context, tensor, tensor) + + // with window size 4, nothing has slid out of the window yet + if !cache.CanResume(0, 0) { + t.Errorf("CanResume(0, 0) = false, want true (within window)") + } + if !cache.CanResume(0, 1) { + t.Errorf("CanResume(0, 1) = false, want true (within window)") + } + if !cache.CanResume(0, 2) { + t.Errorf("CanResume(0, 2) = false, want true (within window)") + } + if !cache.CanResume(0, 3) { + t.Errorf("CanResume(0, 3) = false, want true (latest position)") + } + + // shift window by adding position 4 + err = cache.StartForward(context, input.Batch{ + Positions: []int32{4, 5}, + Sequences: []int{0, 0}, + }) + if err != nil { + t.Fatalf("StartForward failed: %v", err) + } + + cache.SetLayer(0) + tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2) + cache.Put(context, tensor, tensor) + + // only the latest position has overlapping windows + if cache.CanResume(0, 0) { + t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)") + } + if cache.CanResume(0, 1) { + t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)") + } + if cache.CanResume(0, 2) { + t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)") + } + if cache.CanResume(0, 3) { + t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)") + } + if cache.CanResume(0, 4) { + t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)") + } + if !cache.CanResume(0, 5) { + t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)") + } +} + type testBackend struct{} func (b *testBackend) Config() ml.Config { diff --git a/kvcache/encoder.go b/kvcache/encoder.go index 07ff4291e..03d650a3f 100644 --- a/kvcache/encoder.go +++ b/kvcache/encoder.go @@ -134,6 +134,10 @@ func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) { panic("encoder cache does not support multiple sequences") } +func (c *EncoderCache) CanResume(seq int, pos int32) bool { + return true +} + func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error { if c.encoderPos >= beginIndex && c.encoderPos < endIndex { c.encoderCached = false diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go index 0e8ff1f32..926bc2d41 100644 --- a/kvcache/wrapper.go +++ b/kvcache/wrapper.go @@ -87,6 +87,16 @@ func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) { } } +func (c *WrapperCache) CanResume(seq int, pos int32) bool { + for _, cache := range c.caches { + if !cache.CanResume(seq, pos) { + return false + } + } + + return true +} + func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error { // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail for _, cache := range c.caches { diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index 30292f641..01f435e4b 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -118,6 +118,10 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp } if c.cache != nil { + if numPast > 0 && !c.cache.CanResume(slot.Id, numPast) { + numPast = 0 + } + err = c.cache.Remove(slot.Id, numPast, math.MaxInt32) if err != nil { // Some models don't support partial erasure diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index 6a8d8a6a9..543b4b2fa 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -451,6 +451,7 @@ func (m *mockCache) Close() func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil } func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {} func (m *mockCache) SetConfig(ml.CacheConfig) {} +func (m *mockCache) CanResume(seq int, pos int32) bool { return true } func TestShiftCacheSlot(t *testing.T) { tests := []struct {