Compare commits

...

4 Commits

Author SHA1 Message Date
Bruce MacDonald
2c0300073f standard repqFreq var names 2025-02-20 09:48:03 -08:00
Bruce MacDonald
f93bd92027 model: document qwen2 forward pass 2025-02-18 14:20:42 -08:00
Bruce MacDonald
9dc1fb8a91 model: add new engine support for qwen2 family 2025-02-18 14:20:34 -08:00
Bruce MacDonald
eb086514da ml: let model specify rope configuration
Add support for model-specific RoPE configuration parameters by:

1. Creating a new `RopeConfig` struct to encapsulate all RoPE parameters
2. Adding `RopeType` enum to specify different RoPE variants (Standard/NeoX)
3. Extracting original context length from model config
4. Refactoring `RoPE()` interface to use the new config struct
5. Updating llama and mllama models to use new RoPE configuration

This change allows models to specify their RoPE implementation type and
original context length, which is important for proper position embedding
calculation and model compatibility.
2025-02-14 14:18:51 -08:00
7 changed files with 327 additions and 28 deletions

View File

@ -430,7 +430,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0
panic("not implemented")
}
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
func (t *testTensor) RoPE(ctx ml.Context, rc ml.RopeConfig) ml.Tensor {
panic("not implemented")
}

View File

@ -43,6 +43,42 @@ func NewBackend(f *os.File) (Backend, error) {
return nil, fmt.Errorf("unsupported backend")
}
// RopeType specifies the type of RoPE (Rotary Position Embedding) to use, these types are implemented in the backend
type RopeType int
const (
RopeTypeStandard RopeType = iota
_ // not yet used
RopeTypeNeoX
)
// RopeConfig contains all configuration for the RoPE (Rotary Position Embedding) operation
type RopeConfig struct {
// PositionIDs contains the position indices for each token in the sequence
// These indices are used to calculate the rotary embeddings
PositionIDs Tensor
// RopeFactors is an optional tensor containing pre-computed rotation factors
RopeFactors Tensor
// RopeDim specifies the dimension size for the rotary embeddings
RopeDim uint32
// RopeType indicates which RoPE variant to use (e.g. normal or neox)
RopeType RopeType
// OrigCtxLen stores the original context length the model was trained with
OrigCtxLen int
// RopeBase is the base value used in the frequency calculation
RopeBase float32
// RopeScale is a scaling factor applied to position indices
RopeScale float32
// YaRN parameters can be added here if they need to be configurable
}
type Context interface {
Zeros(dtype DType, shape ...int) Tensor
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
@ -75,7 +111,7 @@ type Tensor interface {
Scale(ctx Context, s float64) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
RoPE(ctx Context, rc RopeConfig) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor

View File

@ -579,13 +579,9 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
}
const (
ropeTypeNorm C.int = iota
)
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
if ropeFactors == nil {
ropeFactors = &Tensor{}
func (t *Tensor) RoPE(ctx ml.Context, rc ml.RopeConfig) ml.Tensor {
if rc.RopeFactors == nil {
rc.RopeFactors = &Tensor{}
}
dequant := t.t
@ -595,12 +591,15 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
return &Tensor{
t: C.ggml_rope_ext(
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
C.int(ropeDim),
131072, // YaRN n_ctx_train
ropeTypeNorm, // ROPE_TYPE_NORM
C.float(ropeBase),
C.float(ropeScale),
ctx.(*Context).ctx,
dequant,
rc.PositionIDs.(*Tensor).t,
rc.RopeFactors.(*Tensor).t,
C.int(rc.RopeDim),
C.int(rc.RopeType),
C.int(rc.OrigCtxLen),
C.float(rc.RopeBase),
C.float(rc.RopeScale),
0., // YaRN ext_factor
1., // YaRN attn_factor
32., // YaRN beta_fast

View File

@ -10,10 +10,10 @@ import (
)
type Options struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
ctxLen, hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
}
type Model struct {
@ -46,6 +46,7 @@ func New(c ml.Config) (model.Model, error) {
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ctxLen: int(c.Uint("context_length")),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
@ -67,14 +68,23 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
rc := ml.RopeConfig{
PositionIDs: positionIDs,
RopeFactors: opts.RopeFactors,
RopeDim: opts.ropeDim,
RopeType: ml.RopeTypeStandard,
OrigCtxLen: opts.ctxLen,
RopeBase: opts.ropeBase,
RopeScale: opts.ropeScale,
}
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
q = q.RoPE(ctx, rc)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
k = k.RoPE(ctx, rc)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -99,7 +109,18 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
return key.RoPE(
ctx,
ml.RopeConfig{
PositionIDs: shift,
RopeFactors: m.Options.RopeFactors,
RopeDim: m.Options.ropeDim,
RopeType: ml.RopeTypeStandard,
OrigCtxLen: m.Options.ctxLen,
RopeBase: m.Options.ropeBase,
RopeScale: m.Options.ropeScale,
},
), nil
}
type MLP struct {

View File

@ -19,14 +19,23 @@ type TextSelfAttention struct {
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
rc := ml.RopeConfig{
PositionIDs: positions,
RopeFactors: opts.RopeFactors,
RopeDim: opts.ropeDim,
RopeType: ml.RopeTypeStandard,
OrigCtxLen: opts.ctxLen,
RopeBase: opts.ropeBase,
RopeScale: opts.ropeScale,
}
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
query = query.RoPE(ctx, rc)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, rc)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -52,7 +61,18 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
return key.RoPE(
ctx,
ml.RopeConfig{
PositionIDs: shift,
RopeFactors: m.RopeFactors,
RopeDim: m.ropeDim,
RopeType: ml.RopeTypeStandard,
OrigCtxLen: m.ctxLen,
RopeBase: m.ropeBase,
RopeScale: m.ropeScale,
},
), nil
}
type TextMLP struct {
@ -189,9 +209,9 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
type TextModelOptions struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
ctxLen, hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
crossAttentionLayers []uint32
}

View File

@ -3,4 +3,5 @@ package models
import (
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/mllama"
_ "github.com/ollama/ollama/model/models/qwen2"
)

222
model/models/qwen2/model.go Normal file
View File

@ -0,0 +1,222 @@
package qwen2
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
)
type Options struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
contextLength int
hiddenSize int
numAttnHeads int
numKVHeads int
modelEpsilon float32
ropeFreqBase float32
ropeFreqScale float32
ropeDimensions uint32
}
type Model struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*Options
}
func New(c ml.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
),
Layers: make([]Layer, c.Uint("block_count")),
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numAttnHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
modelEpsilon: c.Float("attention.layer_norm_rms_epsilon"),
contextLength: int(c.Uint("context_length")),
ropeFreqBase: c.Float("rope.freq_base"),
ropeFreqScale: c.Float("rope.freq_scale", 1),
ropeDimensions: c.Uint("rope.dimension_count", 64),
},
}
m.Cache = kvcache.NewCausalCache(m.Shift)
return m, nil
}
// Shift applies rotary position embeddings to the key tensor for causal attention caching
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(
ctx,
ml.RopeConfig{
PositionIDs: shift,
RopeFactors: m.Options.RopeFactors,
RopeDim: m.Options.ropeDimensions,
RopeType: ml.RopeTypeNeoX,
OrigCtxLen: m.Options.contextLength,
RopeBase: m.Options.ropeFreqBase,
RopeScale: m.Options.ropeFreqScale,
},
), nil
}
// SelfAttention implements the multi-head self-attention mechanism
// with separate projections for query, key, value and output transformations
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, inputPositions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
// Initialize dimensions and configuration
batchSize := hiddenState.Dim(1)
headDimension := opts.hiddenSize / opts.numAttnHeads
ropeConfig := ml.RopeConfig{
PositionIDs: inputPositions,
RopeFactors: nil,
RopeDim: opts.ropeDimensions,
RopeType: ml.RopeTypeNeoX,
OrigCtxLen: opts.contextLength,
RopeBase: opts.ropeFreqBase,
RopeScale: opts.ropeFreqScale,
}
// Project and reshape query states with rotary embeddings
queryStates := sa.Query.Forward(ctx, hiddenState)
queryStates = queryStates.Reshape(ctx, headDimension, opts.numAttnHeads, batchSize)
queryStates = queryStates.RoPE(ctx, ropeConfig)
// Project and reshape key states with rotary embeddings
keyStates := sa.Key.Forward(ctx, hiddenState)
keyStates = keyStates.Reshape(ctx, headDimension, opts.numKVHeads, batchSize)
keyStates = keyStates.RoPE(ctx, ropeConfig)
// Project and reshape value states
valueStates := sa.Value.Forward(ctx, hiddenState)
valueStates = valueStates.Reshape(ctx, headDimension, opts.numKVHeads, batchSize)
// Update and retrieve from KV cache
cache.Put(ctx, keyStates, valueStates)
keyStates, valueStates, attentionMask := cache.Get(ctx)
// Prepare tensors for attention computation
queryStates = queryStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
keyStates = keyStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
valueStates = valueStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
// Apply scaling and attention mask to scores
attentionScores := keyStates.MulmatFullPrec(ctx, queryStates)
attentionScores = attentionScores.Scale(ctx, 1.0/math.Sqrt(float64(headDimension)))
attentionScores = attentionScores.Add(ctx, attentionMask)
// Compute scaled dot-product attention
attentionProbs := attentionScores.Softmax(ctx)
// Apply attention weights and reshape
weightedStates := valueStates.Mulmat(ctx, attentionProbs)
weightedStates = weightedStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
weightedStates = weightedStates.Reshape(ctx, opts.hiddenSize, batchSize)
// Project to output dimension
return sa.Output.Forward(ctx, weightedStates)
}
// MLP implements the feed-forward network component with SwiGLU activation
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
// Apply SwiGLU activation gating
gateActivation := mlp.Gate.Forward(ctx, hiddenState).SILU(ctx)
upProjection := mlp.Up.Forward(ctx, hiddenState)
intermediateStates := gateActivation.Mul(ctx, upProjection)
// Project back to hidden dimension
return mlp.Down.Forward(ctx, intermediateStates)
}
// Layer represents a single transformer layer combining self-attention and feed-forward components
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
// Self-attention branch with residual connection
residual := hiddenState
normalizedAttention := l.AttentionNorm.Forward(ctx, hiddenState, opts.modelEpsilon)
attentionOutput := l.SelfAttention.Forward(ctx, normalizedAttention, positionIDs, cache, opts)
hiddenState = attentionOutput.Add(ctx, residual)
// Feed-forward branch with residual connection
residual = hiddenState
normalizedMLP := l.MLPNorm.Forward(ctx, hiddenState, opts.modelEpsilon)
mlpOutput := l.MLP.Forward(ctx, normalizedMLP, opts)
output := mlpOutput.Add(ctx, residual)
return output
}
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
// Convert input tokens and positions to tensors
inputTensor, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
positionsTensor, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
// Initial token embedding
hiddenStates := m.TokenEmbedding.Forward(ctx, inputTensor)
// Process through transformer layers
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
hiddenStates = layer.Forward(ctx, hiddenStates, positionsTensor, m.Cache, m.Options)
}
// Final layer normalization and output projection
normalizedOutput := m.OutputNorm.Forward(ctx, hiddenStates, m.modelEpsilon)
logits := m.Output.Forward(ctx, normalizedOutput)
// Extract requested output token positions
outputsTensor, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
return logits.Rows(ctx, outputsTensor), nil
}
func init() {
model.Register("qwen2", New)
}