kvcache: create cache ctx per layer

Each cache layer now creates and maintains its own context instead of sharing one large context across all layers.
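To make the diff below easier to follow, here is a minimal, self-contained sketch of the pattern this commit introduces: per-layer storage moves from slices to maps, and each layer lazily creates its own context the first time Put touches it. Context, Cache, put, and close are simplified stand-ins for illustration, not ollama's real ml API.

package main

import "fmt"

// Context stands in for a per-layer ml.Context.
type Context struct{ layer int }

func (c *Context) Close() { fmt.Printf("closed ctx for layer %d\n", c.layer) }

// Cache mirrors the map-based bookkeeping the commit adds.
type Cache struct {
	ctxs map[int]*Context
	keys map[int]string // stands in for ml.Tensor storage
}

// put allocates the layer's context and storage on first use.
func (c *Cache) put(layer int, key string) {
	if _, ok := c.ctxs[layer]; !ok {
		c.ctxs[layer] = &Context{layer: layer} // one context per layer
	}
	if _, ok := c.keys[layer]; !ok {
		c.keys[layer] = key // allocated from that layer's context
	}
}

// close releases every per-layer context independently.
func (c *Cache) close() {
	for _, ctx := range c.ctxs {
		ctx.Close()
	}
}

func main() {
	c := &Cache{ctxs: make(map[int]*Context), keys: make(map[int]string)}
	c.put(0, "k0")
	c.put(1, "k1")
	c.close()
}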
parent bfce55db3d
commit 764e199d67
@@ -55,8 +55,8 @@ type Causal struct {
 	shiftFn shiftFn
 	backend ml.Backend
-	cacheCtx ml.Context
-	keys, values []ml.Tensor
+	ctxs map[int]ml.Context
+	keys, values map[int]ml.Tensor
 }

 type cacheCell struct {
@@ -70,11 +70,23 @@ type cellRange struct {
 }

 func NewCausalCache(shift shiftFn) *Causal {
-	return &Causal{windowSize: math.MaxInt32, shiftFn: shift}
+	return &Causal{
+		windowSize: math.MaxInt32,
+		shiftFn:    shift,
+		ctxs:       make(map[int]ml.Context),
+		keys:       make(map[int]ml.Tensor),
+		values:     make(map[int]ml.Tensor),
+	}
 }

 func NewSWACache(windowSize int32, shift shiftFn) *Causal {
-	return &Causal{windowSize: windowSize, shiftFn: shift}
+	return &Causal{
+		windowSize: windowSize,
+		shiftFn:    shift,
+		ctxs:       make(map[int]ml.Context),
+		keys:       make(map[int]ml.Tensor),
+		values:     make(map[int]ml.Tensor),
+	}
 }

 func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
@@ -103,7 +115,6 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
 	c.cells = make([]cacheCell, c.Capacity)
 	c.cellRanges = make(map[int]cellRange)
 	c.backend = backend
-	c.cacheCtx = backend.NewContext()
 }

 func (c *Causal) SetConfig(config ml.CacheConfig) {
@@ -115,7 +126,9 @@ func (c *Causal) SetConfig(config ml.CacheConfig) {
 }

 func (c *Causal) Close() {
-	c.cacheCtx.Close()
+	for _, ctx := range c.ctxs {
+		ctx.Close()
+	}
 }

 func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
@@ -239,13 +252,11 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
 }

 func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
-	for i := range c.keys {
-		if c.keys[i] == nil {
+	for i, key := range c.keys {
+		if key == nil {
 			continue
 		}

-		key := c.keys[i]
-
 		kHeadDim := key.Dim(0)
 		numKVHeads := key.Dim(1)
 		rowSize := key.Stride(2)
@@ -305,7 +316,7 @@ func (c *Causal) defrag() {
 			layers++
 		}

-	maxMoves := ctx.MaxTensors() / (6 * layers)
+	maxMoves := ctx.MaxGraphNodes() / (6 * layers)
 	moves := 0

 	var pendingSrc, pendingDst, pendingLen int
@@ -377,11 +388,6 @@ func (c *Causal) defrag() {
 }

 func (c *Causal) SetLayer(layer int) {
-	if layer >= len(c.keys) {
-		c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
-		c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
-	}
-
 	c.curLayer = layer
 }

@@ -433,13 +439,19 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
 		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
 	}

-	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
-		c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
+	if _, ok := c.ctxs[c.curLayer]; !ok {
+		c.ctxs[c.curLayer] = c.backend.NewContext()
+	}
+
+	if _, ok := c.keys[c.curLayer]; !ok {
+		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
+	}
+
+	if _, ok := c.values[c.curLayer]; !ok {
 		if c.config.PermutedV {
-			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
+			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
 		} else {
-			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
+			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
 		}
 	}
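With keys and values held in maps, SetLayer no longer has to grow slices to cover a new layer index: a map write simply inserts, and Put's comma-ok lookups detect first use. A tiny sketch of that property (names are illustrative):

package main

import "fmt"

func main() {
	keys := make(map[int]string)
	layer := 12 // sparse index; a map is never "grown" to reach it
	if _, ok := keys[layer]; !ok {
		keys[layer] = "allocated on first use"
	}
	fmt.Println(len(keys), keys[layer]) // 1 allocated on first use
}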
@@ -35,13 +35,17 @@ type EncoderCache struct {
 	encoderPos int32

 	// ** cache data storage **
-	cacheCtx ml.Context
-	keys, values []ml.Tensor
+	backend ml.Backend
+	ctxs map[int]ml.Context
+	keys, values map[int]ml.Tensor
 }

 func NewEncoderCache() *EncoderCache {
-	return &EncoderCache{}
+	return &EncoderCache{
+		ctxs:   make(map[int]ml.Context),
+		keys:   make(map[int]ml.Tensor),
+		values: make(map[int]ml.Tensor),
+	}
 }

 func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
@@ -57,7 +61,7 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
 		panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
 	}

-	c.cacheCtx = backend.NewContext()
+	c.backend = backend
 }

 func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
@@ -69,7 +73,9 @@ func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
 }

 func (c *EncoderCache) Close() {
-	c.cacheCtx.Close()
+	for _, ctx := range c.ctxs {
+		ctx.Close()
+	}
 }

 func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
@@ -80,11 +86,6 @@ func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
 }

 func (c *EncoderCache) SetLayer(layer int) {
-	if layer >= len(c.keys) {
-		c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
-		c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
-	}
-
 	c.curLayer = layer
 }

@@ -104,9 +105,16 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
 		value = value.Permute(ctx, 1, 2, 0, 3)
 	}

-	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
-		c.keys[c.curLayer] = c.cacheCtx.Empty(key.DType(), key.Shape()...)
-		c.values[c.curLayer] = c.cacheCtx.Empty(value.DType(), value.Shape()...)
+	if _, ok := c.ctxs[c.curLayer]; !ok {
+		c.ctxs[c.curLayer] = c.backend.NewContext()
+	}
+
+	if _, ok := c.keys[c.curLayer]; !ok {
+		c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
+	}
+
+	if _, ok := c.values[c.curLayer]; !ok {
+		c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
 	}

 	ctx.Forward(
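The encoder cache follows the same per-layer scheme; the one visible difference from the causal cache is that it allocates with Empty using the shape of the first key/value stored, rather than pre-sizing with Zeros at cache capacity. A small sketch of shape-driven first-use allocation (tensor and empty are stand-ins, not ollama's ml types):

package main

import "fmt"

type tensor struct{ shape []int }

// empty stands in for ctx.Empty(dtype, shape...): storage shaped like its input.
func empty(shape ...int) tensor { return tensor{shape: shape} }

func main() {
	values := make(map[int]tensor)
	incoming := tensor{shape: []int{8, 16, 64}}
	if _, ok := values[0]; !ok {
		// mirrors c.values[l] = c.ctxs[l].Empty(v.DType(), v.Shape()...)
		values[0] = empty(incoming.shape...)
	}
	fmt.Println(values[0].shape) // [8 16 64]
}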
@@ -99,7 +99,7 @@ type Context interface {

 	Forward(...Tensor) Context
 	Compute(...Tensor)
-	MaxTensors() int
+	MaxGraphNodes() int
 	Close()
 }
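Renaming MaxTensors to MaxGraphNodes matches what the value actually bounds: the node budget of a single compute graph. As the defrag hunk above shows, that budget is divided by roughly six graph nodes per move per layer to cap how much cell movement fits in one graph. A worked example with illustrative numbers (8192 is the floor the ggml backend uses below; the layer count is hypothetical):

package main

import "fmt"

func main() {
	maxGraphNodes := 8192 // floor from the ggml backend's NewContext
	layers := 32          // hypothetical model depth
	maxMoves := maxGraphNodes / (6 * layers)
	fmt.Println(maxMoves) // 42 moves per defrag graph
}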
@@ -339,14 +339,15 @@ func (b *Backend) Get(name string) ml.Tensor {
 }

 func (b *Backend) NewContext() ml.Context {
-	maxTensors := max(8192, len(b.meta.Tensors().Items())*5)
+	maxGraphNodes := max(8192, len(b.meta.Tensors().Items())*5)
 	return &Context{
 		b: b,
-		maxTensors: maxTensors,
 		ctx: C.ggml_init(C.struct_ggml_init_params{
-			mem_size: C.size_t(maxTensors)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(maxTensors), false),
+			mem_size: C.size_t(maxGraphNodes)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(maxGraphNodes), false),
 			no_alloc: true,
 		}),
+		backend: C.ggml_backend_sched_get_backend(b.sched, 0),
+		maxGraphNodes: maxGraphNodes,
 	}
 }

@@ -363,13 +364,14 @@ type Context struct {

 	ctx *C.struct_ggml_context
 	graph *C.struct_ggml_cgraph
+	backend *C.struct_ggml_backend

-	maxTensors int
+	maxGraphNodes int
 }

 func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
 	if c.graph == nil {
-		c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxTensors), false)
+		c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxGraphNodes), false)
 	}

 	for _, tensor := range tensors {
@@ -399,8 +401,8 @@ func (c *Context) Compute(tensors ...ml.Tensor) {
 	}
 }

-func (c *Context) MaxTensors() int {
-	return c.maxTensors
+func (c *Context) MaxGraphNodes() int {
+	return c.maxGraphNodes
 }

 func shapeToGGML(shape []int) *C.int64_t {
@@ -435,7 +437,7 @@ func newTensor(ctx Context, dtype ml.DType, shape []int) ml.Tensor {
 		panic("unsupported dtype")
 	}

-	b := C.ggml_backend_alloc_buffer(C.ggml_backend_sched_get_backend(ctx.b.sched, 0), C.ggml_nbytes(t))
+	b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
 	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
 	C.ggml_set_input(t)
 	return &Tensor{b: ctx.b, t: t}
@@ -469,7 +471,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
 	}

 	t := C.ggml_new_tensor(ctx.ctx, dtype, C.int(len(shape)), shapeToGGML(shape))
-	b := C.ggml_backend_alloc_buffer(C.ggml_backend_sched_get_backend(ctx.b.sched, 0), C.ggml_nbytes(t))
+	b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
 	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
 	C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
 	C.ggml_set_input(t)
@@ -484,8 +486,8 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
 	return fromSlice(c, s, shape, C.GGML_TYPE_I32)
 }

-func (c *Context) Close() {
-	if c != nil {
+func (c Context) Close() {
+	if c.ctx != nil {
 		C.ggml_free(c.ctx)
 	}
 }
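On the ggml side, NewContext now resolves the scheduler's first backend once via ggml_backend_sched_get_backend(b.sched, 0) and stores it on the Context, so newTensor and fromSlice allocate buffers from the cached pointer instead of re-querying the scheduler on every allocation. A minimal sketch of hoisting a repeated lookup into a constructor-set field (stand-in types, not the ggml API):

package main

import "fmt"

type scheduler struct{ backends []string }

// backendAt stands in for ggml_backend_sched_get_backend.
func (s *scheduler) backendAt(i int) string { return s.backends[i] }

type context struct {
	sched   *scheduler
	backend string // resolved once in the constructor
}

func newContext(s *scheduler) *context {
	return &context{sched: s, backend: s.backendAt(0)}
}

func (c *context) alloc(n int) string {
	return fmt.Sprintf("%d bytes on %s", n, c.backend) // no per-allocation lookup
}

func main() {
	ctx := newContext(&scheduler{backends: []string{"CUDA0", "CPU"}})
	fmt.Println(ctx.alloc(4096))
}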