chunked attention
This commit is contained in:
parent
470af8ab89
commit
8bf11b84c1
@ -19,6 +19,7 @@ type llama4Model struct {
|
|||||||
InterleaveMOELayerStep uint32 `json:"interleave_moe_layer_step"`
|
InterleaveMOELayerStep uint32 `json:"interleave_moe_layer_step"`
|
||||||
UseQKNorm bool `json:"use_qk_norm"`
|
UseQKNorm bool `json:"use_qk_norm"`
|
||||||
IntermediateSizeMLP uint32 `json:"intermediate_size_mlp"`
|
IntermediateSizeMLP uint32 `json:"intermediate_size_mlp"`
|
||||||
|
AttentionChunkSize uint32 `json:"attention_chunk_size"`
|
||||||
} `json:"text_config"`
|
} `json:"text_config"`
|
||||||
VisionModel struct {
|
VisionModel struct {
|
||||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||||
@ -51,6 +52,7 @@ func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
|
|||||||
kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken
|
kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken
|
||||||
kv["llama4.interleave_moe_layer_step"] = p.TextModel.InterleaveMOELayerStep
|
kv["llama4.interleave_moe_layer_step"] = p.TextModel.InterleaveMOELayerStep
|
||||||
kv["llama4.use_qk_norm"] = p.TextModel.UseQKNorm
|
kv["llama4.use_qk_norm"] = p.TextModel.UseQKNorm
|
||||||
|
kv["llama4.attention.chunk_size"] = p.TextModel.AttentionChunkSize
|
||||||
|
|
||||||
kv["llama4.vision.block_count"] = p.VisionModel.NumHiddenLayers
|
kv["llama4.vision.block_count"] = p.VisionModel.NumHiddenLayers
|
||||||
kv["llama4.vision.embedding_length"] = p.VisionModel.HiddenSize
|
kv["llama4.vision.embedding_length"] = p.VisionModel.HiddenSize
|
||||||
|
@ -21,6 +21,7 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
|
|||||||
type Causal struct {
|
type Causal struct {
|
||||||
DType ml.DType
|
DType ml.DType
|
||||||
windowSize int32
|
windowSize int32
|
||||||
|
chunkSize int32
|
||||||
|
|
||||||
opts CausalOptions
|
opts CausalOptions
|
||||||
|
|
||||||
@ -97,6 +98,17 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
|
||||||
|
return &Causal{
|
||||||
|
windowSize: math.MaxInt32,
|
||||||
|
chunkSize: chunkSize,
|
||||||
|
shiftFn: shift,
|
||||||
|
ctxs: make(map[int]ml.Context),
|
||||||
|
keys: make(map[int]ml.Tensor),
|
||||||
|
values: make(map[int]ml.Tensor),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||||
if c.config == nil {
|
if c.config == nil {
|
||||||
var config ml.CacheConfig
|
var config ml.CacheConfig
|
||||||
@ -300,6 +312,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
|||||||
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||||
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
||||||
(enabled && c.cells[j].pos > c.curPositions[i]) ||
|
(enabled && c.cells[j].pos > c.curPositions[i]) ||
|
||||||
|
c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
|
||||||
c.cells[j].pos < c.curPositions[i]-c.windowSize {
|
c.cells[j].pos < c.curPositions[i]-c.windowSize {
|
||||||
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||||
}
|
}
|
||||||
|
@ -86,6 +86,64 @@ func TestSWA(t *testing.T) {
|
|||||||
testCache(t, backend, cache, tests)
|
testCache(t, backend, cache, tests)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestChunkedAttention(t *testing.T) {
|
||||||
|
cache := NewChunkedAttentionCache(2, nil)
|
||||||
|
defer cache.Close()
|
||||||
|
|
||||||
|
var b testBackend
|
||||||
|
cache.Init(&b, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
x := float32(math.Inf(-1))
|
||||||
|
|
||||||
|
testCache(
|
||||||
|
t, &b, cache,
|
||||||
|
[]testCase{
|
||||||
|
{
|
||||||
|
name: "FirstBatch",
|
||||||
|
in: []float32{1, 2, 3, 4},
|
||||||
|
inShape: []int{1, 1, 4},
|
||||||
|
seqs: []int{0, 0, 0, 0},
|
||||||
|
pos: []int32{0, 1, 2, 3},
|
||||||
|
expected: []float32{1, 2, 3, 4},
|
||||||
|
expectedShape: []int{1, 1, 4},
|
||||||
|
expectedMask: []float32{
|
||||||
|
0, x, x, x,
|
||||||
|
0, 0, x, x,
|
||||||
|
x, x, 0, x,
|
||||||
|
x, x, 0, 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SecondBatch",
|
||||||
|
in: []float32{5, 6, 7},
|
||||||
|
inShape: []int{1, 1, 3},
|
||||||
|
seqs: []int{0, 0, 0},
|
||||||
|
pos: []int32{4, 5, 6},
|
||||||
|
expected: []float32{1, 2, 3, 4, 5, 6, 7},
|
||||||
|
expectedShape: []int{1, 1, 7},
|
||||||
|
expectedMask: []float32{
|
||||||
|
x, x, x, x, 0, x, x,
|
||||||
|
x, x, x, x, 0, 0, x,
|
||||||
|
x, x, x, x, x, x, 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ThirdBatch",
|
||||||
|
in: []float32{8, 9},
|
||||||
|
inShape: []int{1, 1, 2},
|
||||||
|
seqs: []int{0, 0},
|
||||||
|
pos: []int32{7, 8},
|
||||||
|
expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
|
||||||
|
expectedShape: []int{1, 1, 9},
|
||||||
|
expectedMask: []float32{
|
||||||
|
x, x, x, x, x, x, 0, 0, x,
|
||||||
|
x, x, x, x, x, x, x, x, 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
func TestSequences(t *testing.T) {
|
func TestSequences(t *testing.T) {
|
||||||
backend := &testBackend{}
|
backend := &testBackend{}
|
||||||
cache := NewCausalCache(nil)
|
cache := NewCausalCache(nil)
|
||||||
@ -293,8 +351,16 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
|
|||||||
|
|
||||||
context.Forward(out, mask).Compute(out, mask)
|
context.Forward(out, mask).Compute(out, mask)
|
||||||
|
|
||||||
if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
|
if !slices.Equal(out.Floats(), test.expected) {
|
||||||
t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
|
t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !slices.Equal(out.Shape(), test.expectedShape) {
|
||||||
|
t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !slices.Equal(mask.Floats(), test.expectedMask) {
|
||||||
|
t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -52,8 +52,7 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.Cache = kvcache.NewWrapperCache(
|
m.Cache = kvcache.NewWrapperCache(
|
||||||
// TODO: pretend this is chunked attention for now
|
kvcache.NewChunkedAttentionCache(int32(c.Uint("attention.chunk_size")), m.Shift),
|
||||||
kvcache.NewSWACache(8192, m.Shift),
|
|
||||||
kvcache.NewCausalCache(m.Shift),
|
kvcache.NewCausalCache(m.Shift),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user