kvcache: Support non-causal attention
Models can disable causality for all or part of their processing while continuing to store data in the KV cache.
This commit is contained in:
parent
0daaaef8c9
commit
6da8b6a879
@ -20,6 +20,7 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
|
|||||||
type Causal struct {
|
type Causal struct {
|
||||||
DType ml.DType
|
DType ml.DType
|
||||||
Capacity int32
|
Capacity int32
|
||||||
|
causal bool
|
||||||
windowSize int32
|
windowSize int32
|
||||||
|
|
||||||
// config controls mostly backend-specific optimizations
|
// config controls mostly backend-specific optimizations
|
||||||
@ -42,6 +43,12 @@ type Causal struct {
|
|||||||
// locations in the cache that are needed for this batch
|
// locations in the cache that are needed for this batch
|
||||||
curCellRange cellRange
|
curCellRange cellRange
|
||||||
|
|
||||||
|
// curSequences is the sequences corresponding to this pass's entries in the cache
|
||||||
|
curSequences []int
|
||||||
|
|
||||||
|
// curPositions is the positions corresponding to this pass's entries in the cache
|
||||||
|
curPositions []int32
|
||||||
|
|
||||||
// ** cache metadata **
|
// ** cache metadata **
|
||||||
|
|
||||||
// for each possible location in the cache, stores the position and set of sequences
|
// for each possible location in the cache, stores the position and set of sequences
|
||||||
@ -71,6 +78,7 @@ type cellRange struct {
|
|||||||
|
|
||||||
func NewCausalCache(shift shiftFn) *Causal {
|
func NewCausalCache(shift shiftFn) *Causal {
|
||||||
return &Causal{
|
return &Causal{
|
||||||
|
causal: true,
|
||||||
windowSize: math.MaxInt32,
|
windowSize: math.MaxInt32,
|
||||||
shiftFn: shift,
|
shiftFn: shift,
|
||||||
ctxs: make(map[int]ml.Context),
|
ctxs: make(map[int]ml.Context),
|
||||||
@ -81,6 +89,7 @@ func NewCausalCache(shift shiftFn) *Causal {
|
|||||||
|
|
||||||
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||||
return &Causal{
|
return &Causal{
|
||||||
|
causal: true,
|
||||||
windowSize: windowSize,
|
windowSize: windowSize,
|
||||||
shiftFn: shift,
|
shiftFn: shift,
|
||||||
ctxs: make(map[int]ml.Context),
|
ctxs: make(map[int]ml.Context),
|
||||||
@ -133,6 +142,8 @@ func (c *Causal) Close() {
|
|||||||
|
|
||||||
func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
|
func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
|
||||||
c.curBatchSize = len(positions)
|
c.curBatchSize = len(positions)
|
||||||
|
c.curSequences = seqs
|
||||||
|
c.curPositions = positions
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
c.curLoc, err = c.findStartLoc()
|
c.curLoc, err = c.findStartLoc()
|
||||||
@ -171,7 +182,7 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err
|
|||||||
c.cellRanges[seq] = seqRange
|
c.cellRanges[seq] = seqRange
|
||||||
}
|
}
|
||||||
|
|
||||||
c.curMask, err = c.buildMask(ctx, positions, seqs)
|
c.curMask, err = c.buildMask(ctx)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -212,7 +223,7 @@ func roundUp(length, pad int) int {
|
|||||||
// Builds a mask of history x batch indicating whether for each token in the batch the
|
// Builds a mask of history x batch indicating whether for each token in the batch the
|
||||||
// token in the history should apply. This is based on both the sequence and causality (the
|
// token in the history should apply. This is based on both the sequence and causality (the
|
||||||
// position of the history is not ahead of the token in the batch).
|
// position of the history is not ahead of the token in the batch).
|
||||||
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
|
func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
||||||
// Align and pad the two dimensions as required by the backend
|
// Align and pad the two dimensions as required by the backend
|
||||||
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
||||||
|
|
||||||
@ -224,8 +235,9 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te
|
|||||||
|
|
||||||
for i := range c.curBatchSize {
|
for i := range c.curBatchSize {
|
||||||
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||||
if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
|
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
||||||
c.cells[j].pos < positions[i]-c.windowSize {
|
(c.causal && c.cells[j].pos > c.curPositions[i]) ||
|
||||||
|
c.cells[j].pos < c.curPositions[i]-c.windowSize {
|
||||||
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -391,6 +403,26 @@ func (c *Causal) SetLayer(layer int) {
|
|||||||
c.curLayer = layer
|
c.curLayer = layer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetCausal enables or disables causal mask generation for subsequent calls to Get.
|
||||||
|
// This state carries over to future forward passes. The default value is true.
|
||||||
|
//
|
||||||
|
// ctx may be set to nil if this is called from outside of a forward pass, for
|
||||||
|
// example, when initializing the cache.
|
||||||
|
func (c *Causal) SetCausal(ctx ml.Context, causal bool) {
|
||||||
|
if c.causal != causal {
|
||||||
|
c.causal = causal
|
||||||
|
|
||||||
|
if ctx != nil {
|
||||||
|
var err error
|
||||||
|
c.curMask, err = c.buildMask(ctx)
|
||||||
|
if err != nil {
|
||||||
|
// This error should never occur because we have previously built a mask with the same shape
|
||||||
|
panic(fmt.Errorf("SetCausal: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||||
key := c.keys[c.curLayer]
|
key := c.keys[c.curLayer]
|
||||||
value := c.values[c.curLayer]
|
value := c.values[c.curLayer]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user