use non-causal mask only for image positions
This commit is contained in:
parent
9d2a20a763
commit
e95278932b
@ -21,9 +21,10 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
|
|||||||
type Causal struct {
|
type Causal struct {
|
||||||
DType ml.DType
|
DType ml.DType
|
||||||
Capacity int32
|
Capacity int32
|
||||||
causal bool
|
|
||||||
windowSize int32
|
windowSize int32
|
||||||
|
|
||||||
|
opts CausalOptions
|
||||||
|
|
||||||
// config controls mostly backend-specific optimizations
|
// config controls mostly backend-specific optimizations
|
||||||
config *ml.CacheConfig
|
config *ml.CacheConfig
|
||||||
|
|
||||||
@ -79,7 +80,6 @@ type cellRange struct {
|
|||||||
|
|
||||||
func NewCausalCache(shift shiftFn) *Causal {
|
func NewCausalCache(shift shiftFn) *Causal {
|
||||||
return &Causal{
|
return &Causal{
|
||||||
causal: true,
|
|
||||||
windowSize: math.MaxInt32,
|
windowSize: math.MaxInt32,
|
||||||
shiftFn: shift,
|
shiftFn: shift,
|
||||||
ctxs: make(map[int]ml.Context),
|
ctxs: make(map[int]ml.Context),
|
||||||
@ -90,7 +90,6 @@ func NewCausalCache(shift shiftFn) *Causal {
|
|||||||
|
|
||||||
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||||
return &Causal{
|
return &Causal{
|
||||||
causal: true,
|
|
||||||
windowSize: windowSize,
|
windowSize: windowSize,
|
||||||
shiftFn: shift,
|
shiftFn: shift,
|
||||||
ctxs: make(map[int]ml.Context),
|
ctxs: make(map[int]ml.Context),
|
||||||
@ -235,9 +234,10 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
|||||||
mask := make([]float32, batchSize*length)
|
mask := make([]float32, batchSize*length)
|
||||||
|
|
||||||
for i := range c.curBatchSize {
|
for i := range c.curBatchSize {
|
||||||
|
enabled := !slices.Contains(c.opts.Except, c.curPositions[i])
|
||||||
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||||
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
||||||
(c.causal && c.cells[j].pos > c.curPositions[i]) ||
|
(enabled && c.cells[j].pos > c.curPositions[i]) ||
|
||||||
c.cells[j].pos < c.curPositions[i]-c.windowSize {
|
c.cells[j].pos < c.curPositions[i]-c.windowSize {
|
||||||
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||||
}
|
}
|
||||||
@ -404,15 +404,19 @@ func (c *Causal) SetLayer(layer int) {
|
|||||||
c.curLayer = layer
|
c.curLayer = layer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CausalOptions struct {
|
||||||
|
// Enabled controls whether the causal mask is generated for a particular position.
|
||||||
|
Except []int32
|
||||||
|
}
|
||||||
|
|
||||||
// SetCausal enables or disables causal mask generation for subsequent calls to Get.
|
// SetCausal enables or disables causal mask generation for subsequent calls to Get.
|
||||||
// This state carries over to future forward passes. The default value is true.
|
// This state carries over to future forward passes. The default value is true.
|
||||||
//
|
//
|
||||||
// ctx may be set to nil if this is called from outside of a forward pass, for
|
// ctx may be set to nil if this is called from outside of a forward pass, for
|
||||||
// example, when initializing the cache.
|
// example, when initializing the cache.
|
||||||
func (c *Causal) SetCausal(ctx ml.Context, causal bool) {
|
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
|
||||||
if c.causal != causal {
|
if !slices.Equal(c.opts.Except, opts.Except) {
|
||||||
c.causal = causal
|
c.opts = opts
|
||||||
|
|
||||||
if ctx != nil {
|
if ctx != nil {
|
||||||
var err error
|
var err error
|
||||||
c.curMask, err = c.buildMask(ctx)
|
c.curMask, err = c.buildMask(ctx)
|
||||||
|
@ -183,8 +183,12 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
|
|||||||
hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1))
|
hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1))
|
||||||
|
|
||||||
if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok {
|
if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok {
|
||||||
causal.SetCausal(ctx, false)
|
except := make([]int32, visionOutputs.Dim(1))
|
||||||
defer causal.SetCausal(ctx, true)
|
for i := 0; i < visionOutputs.Dim(1); i++ {
|
||||||
|
except[i] = int32(offset + i)
|
||||||
|
}
|
||||||
|
|
||||||
|
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user