diff --git a/kvcache/causal.go b/kvcache/causal.go index b2e7b3ab0..6a927cb80 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -55,8 +55,8 @@ type Causal struct { shiftFn shiftFn backend ml.Backend - cacheCtx ml.Context - keys, values []ml.Tensor + ctxs map[int]ml.Context + keys, values map[int]ml.Tensor } type cacheCell struct { @@ -70,11 +70,23 @@ type cellRange struct { } func NewCausalCache(shift shiftFn) *Causal { - return &Causal{windowSize: math.MaxInt32, shiftFn: shift} + return &Causal{ + windowSize: math.MaxInt32, + shiftFn: shift, + ctxs: make(map[int]ml.Context), + keys: make(map[int]ml.Tensor), + values: make(map[int]ml.Tensor), + } } func NewSWACache(windowSize int32, shift shiftFn) *Causal { - return &Causal{windowSize: windowSize, shiftFn: shift} + return &Causal{ + windowSize: windowSize, + shiftFn: shift, + ctxs: make(map[int]ml.Context), + keys: make(map[int]ml.Tensor), + values: make(map[int]ml.Tensor), + } } func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) { @@ -103,7 +115,6 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) { c.cells = make([]cacheCell, c.Capacity) c.cellRanges = make(map[int]cellRange) c.backend = backend - c.cacheCtx = backend.NewContext() } func (c *Causal) SetConfig(config ml.CacheConfig) { @@ -115,7 +126,9 @@ func (c *Causal) SetConfig(config ml.CacheConfig) { } func (c *Causal) Close() { - c.cacheCtx.Close() + for _, ctx := range c.ctxs { + ctx.Close() + } } func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error { @@ -239,13 +252,11 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te } func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) { - for i := range c.keys { - if c.keys[i] == nil { + for i, key := range c.keys { + if key == nil { continue } - key := c.keys[i] - kHeadDim := key.Dim(0) numKVHeads := key.Dim(1) rowSize := key.Stride(2) @@ -305,7 +316,7 @@ func (c *Causal) defrag() { layers++ } - maxMoves := ctx.MaxTensors() / (6 * layers) + maxMoves := ctx.MaxGraphNodes() / (6 * layers) moves := 0 var pendingSrc, pendingDst, pendingLen int @@ -377,11 +388,6 @@ func (c *Causal) defrag() { } func (c *Causal) SetLayer(layer int) { - if layer >= len(c.keys) { - c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...) - c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...) - } - c.curLayer = layer } @@ -433,13 +439,19 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize)) } - if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil { - c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity)) + if _, ok := c.ctxs[c.curLayer]; !ok { + c.ctxs[c.curLayer] = c.backend.NewContext() + } + if _, ok := c.keys[c.curLayer]; !ok { + c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity)) + } + + if _, ok := c.values[c.curLayer]; !ok { if c.config.PermutedV { - c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads) + c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads) } else { - c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity)) + c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity)) } } diff --git a/kvcache/encoder.go b/kvcache/encoder.go index 39b4cdfb6..6a24e867e 100644 --- a/kvcache/encoder.go +++ b/kvcache/encoder.go @@ -35,13 +35,17 @@ type EncoderCache struct { encoderPos int32 // ** cache data storage ** - - cacheCtx ml.Context - keys, values []ml.Tensor + backend ml.Backend + ctxs map[int]ml.Context + keys, values map[int]ml.Tensor } func NewEncoderCache() *EncoderCache { - return &EncoderCache{} + return &EncoderCache{ + ctxs: make(map[int]ml.Context), + keys: make(map[int]ml.Tensor), + values: make(map[int]ml.Tensor), + } } func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) { @@ -57,7 +61,7 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding)) } - c.cacheCtx = backend.NewContext() + c.backend = backend } func (c *EncoderCache) SetConfig(config ml.CacheConfig) { @@ -69,7 +73,9 @@ func (c *EncoderCache) SetConfig(config ml.CacheConfig) { } func (c *EncoderCache) Close() { - c.cacheCtx.Close() + for _, ctx := range c.ctxs { + ctx.Close() + } } func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error { @@ -80,11 +86,6 @@ func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []in } func (c *EncoderCache) SetLayer(layer int) { - if layer >= len(c.keys) { - c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...) - c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...) - } - c.curLayer = layer } @@ -104,9 +105,16 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) { value = value.Permute(ctx, 1, 2, 0, 3) } - if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil { - c.keys[c.curLayer] = c.cacheCtx.Empty(key.DType(), key.Shape()...) - c.values[c.curLayer] = c.cacheCtx.Empty(value.DType(), value.Shape()...) + if _, ok := c.ctxs[c.curLayer]; !ok { + c.ctxs[c.curLayer] = c.backend.NewContext() + } + + if _, ok := c.keys[c.curLayer]; !ok { + c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...) + } + + if _, ok := c.values[c.curLayer]; !ok { + c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...) } ctx.Forward( diff --git a/ml/backend.go b/ml/backend.go index 3ef8a1ac2..1eeb635b7 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -99,7 +99,7 @@ type Context interface { Forward(...Tensor) Context Compute(...Tensor) - MaxTensors() int + MaxGraphNodes() int Close() } diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index e909f53cf..1a2722567 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -339,14 +339,15 @@ func (b *Backend) Get(name string) ml.Tensor { } func (b *Backend) NewContext() ml.Context { - maxTensors := max(8192, len(b.meta.Tensors().Items())*5) + maxGraphNodes := max(8192, len(b.meta.Tensors().Items())*5) return &Context{ b: b, - maxTensors: maxTensors, ctx: C.ggml_init(C.struct_ggml_init_params{ - mem_size: C.size_t(maxTensors)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(maxTensors), false), + mem_size: C.size_t(maxGraphNodes)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(maxGraphNodes), false), no_alloc: true, }), + backend: C.ggml_backend_sched_get_backend(b.sched, 0), + maxGraphNodes: maxGraphNodes, } } @@ -363,13 +364,14 @@ type Context struct { ctx *C.struct_ggml_context graph *C.struct_ggml_cgraph + backend *C.struct_ggml_backend - maxTensors int + maxGraphNodes int } func (c *Context) Forward(tensors ...ml.Tensor) ml.Context { if c.graph == nil { - c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxTensors), false) + c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxGraphNodes), false) } for _, tensor := range tensors { @@ -399,8 +401,8 @@ func (c *Context) Compute(tensors ...ml.Tensor) { } } -func (c *Context) MaxTensors() int { - return c.maxTensors +func (c *Context) MaxGraphNodes() int { + return c.maxGraphNodes } func shapeToGGML(shape []int) *C.int64_t { @@ -435,7 +437,7 @@ func newTensor(ctx Context, dtype ml.DType, shape []int) ml.Tensor { panic("unsupported dtype") } - b := C.ggml_backend_alloc_buffer(C.ggml_backend_sched_get_backend(ctx.b.sched, 0), C.ggml_nbytes(t)) + b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t)) C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b)) C.ggml_set_input(t) return &Tensor{b: ctx.b, t: t} @@ -469,7 +471,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u } t := C.ggml_new_tensor(ctx.ctx, dtype, C.int(len(shape)), shapeToGGML(shape)) - b := C.ggml_backend_alloc_buffer(C.ggml_backend_sched_get_backend(ctx.b.sched, 0), C.ggml_nbytes(t)) + b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t)) C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b)) C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t)) C.ggml_set_input(t) @@ -484,8 +486,8 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { return fromSlice(c, s, shape, C.GGML_TYPE_I32) } -func (c *Context) Close() { - if c != nil { +func (c Context) Close() { + if c.ctx != nil { C.ggml_free(c.ctx) } }