diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index b2c5c2c7b..e90631fb8 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -120,11 +120,19 @@ type Layer struct {
 	MLP *MLP
 }
 
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token positions
+	// we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
@@ -144,22 +152,26 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
-
-	for i, layer := range m.Layers {
-		m.Cache.SetLayer(i)
-		hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
-	}
-
-	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
-	hiddenState = m.Output.Forward(ctx, hiddenState)
-
 	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}
 
-	return hiddenState.Rows(ctx, outputs), nil
+	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+
+	for i, layer := range m.Layers {
+		m.Cache.SetLayer(i)
+
+		var lastLayerOutputs ml.Tensor
+		if i == len(m.Layers)-1 {
+			lastLayerOutputs = outputs
+		}
+
+		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
+	}
+
+	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
+	return m.Output.Forward(ctx, hiddenState), nil
 }
 
 func init() {
diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go
index a1460d940..f5521ce5c 100644
--- a/model/models/mllama/model.go
+++ b/model/models/mllama/model.go
@@ -93,15 +93,13 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	// TODO: attention mask, cross attention mask
-	hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache))
-
 	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}
 
-	return hiddenState.Rows(ctx, outputs), nil
+	// TODO: attention mask, cross attention mask
+	return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
 }
 
 func init() {
diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go
index 1e48086a3..8ad804cf9 100644
--- a/model/models/mllama/model_text.go
+++ b/model/models/mllama/model_text.go
@@ -74,11 +74,19 @@ type TextSelfAttentionDecoderLayer struct {
 	MLP *TextMLP
 }
 
-func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
+func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, outputs, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token positions
+	// we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
@@ -145,7 +153,7 @@ type TextCrossAttentionDecoderLayer struct {
 	MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
 }
 
-func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
+func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -161,14 +169,14 @@ func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _,
 }
 
 type TextDecoderLayer interface {
-	Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
+	Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
 }
 
 type TextDecoder struct {
 	Layers []TextDecoderLayer
 }
 
-func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
+func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	for i, layer := range d.Layers {
 		layerType := selfAttentionLayer
 		if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
@@ -179,7 +187,12 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
 
 		cache.SetLayerType(layerType)
 		if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
-			hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
+			var lastLayerOutputs ml.Tensor
+			if i == len(d.Layers)-1 {
+				lastLayerOutputs = outputs
+			}
+
+			hiddenState = layer.Forward(ctx, hiddenState, positionIDs, lastLayerOutputs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
 		}
 	}
 
@@ -205,9 +218,9 @@ type TextModel struct {
 	*TextModelOptions
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
+func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
-	hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
+	hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
 	return m.Output.Forward(ctx, hiddenState)
 }
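
Both models follow the same pattern: instead of running the output norm and vocabulary projection over every token in the batch and only then selecting the rows listed in opts.Outputs, the Rows selection is pushed into the final decoder layer, so everything downstream of the last attention call only touches the positions that actually need logits. The residual tensor is pruned alongside hiddenState so that both operands of the final Add agree in shape. The snippet below is a minimal, self-contained sketch of that row-pruning idea using plain Go slices rather than the ml.Tensor API; selectRows and project are hypothetical stand-ins, not functions from this repository.

package main

import "fmt"

// selectRows keeps only the hidden-state rows whose indices appear in outputs,
// mirroring what Tensor.Rows does in the diff above (hypothetical helper).
func selectRows(hidden [][]float32, outputs []int) [][]float32 {
	pruned := make([][]float32, 0, len(outputs))
	for _, idx := range outputs {
		pruned = append(pruned, hidden[idx])
	}
	return pruned
}

// project stands in for the output norm plus the vocabulary projection; its
// cost grows with the number of rows it is handed (hypothetical stand-in).
func project(hidden [][]float32) [][]float32 {
	return hidden
}

func main() {
	// A prompt of four tokens with hidden size three; only the last position
	// needs logits, which is the common case while processing a prompt.
	hidden := [][]float32{
		{0.1, 0.2, 0.3},
		{0.4, 0.5, 0.6},
		{0.7, 0.8, 0.9},
		{1.0, 1.1, 1.2},
	}
	outputs := []int{3}

	// Before this change: project all four rows, then keep one.
	// After this change: prune to one row first, then project.
	logits := project(selectRows(hidden, outputs))
	fmt.Printf("projected %d row(s) instead of %d\n", len(logits), len(hidden))
}

Since usually only the last position needs logits, pruning before the projection avoids a vocabulary-sized matrix multiply for every other token in the batch.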