models: Prune unused outputs earlier in the forward pass

Currently Rows is called as the last step in a model computation
to get the values for the output tokens. However, if we move it
earlier in the process then we can trim out computations that
never get used. This is similar to how models are defined in
llama.cpp.

Changing the model definition in this way improves token generation
performance by approximately 8%.
This commit is contained in:
Jesse Gross 2025-02-18 17:16:57 -08:00 committed by Jesse Gross
parent e5bcc51ae1
commit 5c5535c064
3 changed files with 46 additions and 23 deletions

View File

@ -120,11 +120,19 @@ type Layer struct {
MLP *MLP MLP *MLP
} }
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenState residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts) hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual) hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState residual = hiddenState
@ -144,22 +152,26 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
return nil, err return nil, err
} }
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = m.Output.Forward(ctx, hiddenState)
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs)) outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return hiddenState.Rows(ctx, outputs), nil hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState), nil
} }
func init() { func init() {

View File

@ -93,15 +93,13 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
return nil, err return nil, err
} }
// TODO: attention mask, cross attention mask
hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache))
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs)) outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return hiddenState.Rows(ctx, outputs), nil // TODO: attention mask, cross attention mask
return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
} }
func init() { func init() {

View File

@ -74,11 +74,19 @@ type TextSelfAttentionDecoderLayer struct {
MLP *TextMLP MLP *TextMLP
} }
func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, outputs, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
residual := hiddenState residual := hiddenState
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps) hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts) hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual) hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState residual = hiddenState
@ -145,7 +153,7 @@ type TextCrossAttentionDecoderLayer struct {
MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"` MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
} }
func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
residual := hiddenState residual := hiddenState
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps) hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@ -161,14 +169,14 @@ func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _,
} }
type TextDecoderLayer interface { type TextDecoderLayer interface {
Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
} }
type TextDecoder struct { type TextDecoder struct {
Layers []TextDecoderLayer Layers []TextDecoderLayer
} }
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
for i, layer := range d.Layers { for i, layer := range d.Layers {
layerType := selfAttentionLayer layerType := selfAttentionLayer
if slices.Contains(opts.crossAttentionLayers, uint32(i)) { if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
@ -179,7 +187,12 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
cache.SetLayerType(layerType) cache.SetLayerType(layerType)
if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() { if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts) var lastLayerOutputs ml.Tensor
if i == len(d.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, lastLayerOutputs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
} }
} }
@ -205,9 +218,9 @@ type TextModel struct {
*TextModelOptions *TextModelOptions
} }
func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor { func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs) hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions) hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState) return m.Output.Forward(ctx, hiddenState)
} }