From 198b1e6db9699ea67f4d51cc34644aba0bbf445d Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Tue, 1 Apr 2025 14:09:41 -0700
Subject: [PATCH] text model forward pass

---
 model/models/models.go         |   1 +
 model/models/qwen25vl/model.go | 187 +++++++++++++++++++++++++++++++++
 2 files changed, 188 insertions(+)
 create mode 100644 model/models/qwen25vl/model.go

diff --git a/model/models/models.go b/model/models/models.go
index 73b4c53a5..133e51761 100644
--- a/model/models/models.go
+++ b/model/models/models.go
@@ -7,4 +7,5 @@ import (
 	_ "github.com/ollama/ollama/model/models/llama4"
 	_ "github.com/ollama/ollama/model/models/mistral3"
 	_ "github.com/ollama/ollama/model/models/mllama"
+	_ "github.com/ollama/ollama/model/models/qwen25vl"
 )
diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go
new file mode 100644
index 000000000..e974e2a46
--- /dev/null
+++ b/model/models/qwen25vl/model.go
@@ -0,0 +1,187 @@
+package qwen25vl
+
+import (
+	"fmt"
+	"math"
+	"strings"
+
+	"github.com/ollama/ollama/kvcache"
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
+)
+
+type Options struct {
+	ctxLen, hiddenSize, numHeads, numKVHeads int
+	eps                                      float32
+	ropeConfig                               ml.RoPEConfig
+}
+
+type Model struct {
+	model.Base
+	model.BytePairEncoding
+
+	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
+	Layers         []Layer       `gguf:"blk"`
+	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
+	Output         *nn.Linear    `gguf:"output,alt:token_embd"`
+
+	*Options
+}
+
+func New(c ml.Config) (model.Model, error) {
+	if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
+		return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
+	}
+
+	m := Model{
+		BytePairEncoding: model.NewBytePairEncoding(
+			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
+			&model.Vocabulary{
+				Values: c.Strings("tokenizer.ggml.tokens"),
+				Types:  c.Uints("tokenizer.ggml.token_type"),
+				Merges: c.Strings("tokenizer.ggml.merges"),
+				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
+				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
+				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+			},
+		),
+		Layers: make([]Layer, c.Uint("block_count")),
+		Options: &Options{
+			ctxLen:     int(c.Uint("context_length")),
+			hiddenSize: int(c.Uint("embedding_length")),
+			numHeads:   int(c.Uint("attention.head_count")),
+			numKVHeads: int(c.Uint("attention.head_count_kv")),
+			eps:        c.Float("attention.layer_norm_rms_epsilon"),
+			ropeConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.freq_base"),
+				Scale:      c.Float("rope.freq_scale", 1),
+				Dim:        c.Uint("rope.dimension_count", 128),
+				Type:       ml.RopeTypeNeox,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 32768))),
+			},
+		},
+	}
+
+	m.Cache = kvcache.NewCausalCache(m.Shift)
+
+	return &m, nil
+}
+
+// SelfAttention implements the multi-head self-attention mechanism
+// with separate projections for query, key, value and output transformations
+type SelfAttention struct {
+	Query       *nn.Linear `gguf:"attn_q"`
+	Key         *nn.Linear `gguf:"attn_k"`
+	Value       *nn.Linear `gguf:"attn_v"`
+	Output      *nn.Linear `gguf:"attn_output"`
+	RopeFactors ml.Tensor  `gguf:"rope_freqs.weight"`
+}
+
+func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+	batchSize := hiddenState.Dim(1)
+	headDim := opts.hiddenSize / opts.numHeads
+
+	q := sa.Query.Forward(ctx, hiddenState)
+	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
+	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
+
+	k := sa.Key.Forward(ctx, hiddenState)
+	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
+	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
+
+	v := sa.Value.Forward(ctx, hiddenState)
+	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
+
+	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
+	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
+	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
+
+	return sa.Output.Forward(ctx, kqv)
+}
+
+// Shift applies rotary position embeddings to the key tensor for causal attention caching
+func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
+	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
+}
+
+// MLP implements the feed-forward network component with SwiGLU activation
+type MLP struct {
+	Up   *nn.Linear `gguf:"ffn_up"`
+	Down *nn.Linear `gguf:"ffn_down"`
+	Gate *nn.Linear `gguf:"ffn_gate"`
+}
+
+func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
+	// Apply SwiGLU activation gating
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	// Project back to hidden dimension
+	return mlp.Down.Forward(ctx, hiddenState)
+}
+
+// Layer represents a single transformer layer combining self-attention and feed-forward components
+type Layer struct {
+	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
+	SelfAttention *SelfAttention
+	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
+	MLP           *MLP
+}
+
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+	// Self-attention branch with residual connection
+	residual := hiddenState
+
+	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
+	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token
+	// positions we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
+	hiddenState = hiddenState.Add(ctx, residual)
+	// Feed-forward branch with residual connection
+	residual = hiddenState
+	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
+	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
+	return hiddenState.Add(ctx, residual)
+}
+
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	// Convert position values and output indices to tensors
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
+	if err != nil {
+		return nil, err
+	}
+
+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+	if err != nil {
+		return nil, err
+	}
+
+	// Initial token embedding
+	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
+
+	// Process through transformer layers
+	for i, layer := range m.Layers {
+		m.Cache.SetLayer(i)
+
+		var lastLayerOutputs ml.Tensor
+		if i == len(m.Layers)-1 {
+			lastLayerOutputs = outputs
+		}
+
+		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
+	}
+
+	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
+	return m.Output.Forward(ctx, hiddenState), nil
+}
+
+func init() {
+	model.Register("qwen2vl", New)
+}
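
Reviewer note (not part of the patch): the reshapes in SelfAttention.Forward set up grouped-query attention. The query keeps opts.numHeads heads while the key and value keep only opts.numKVHeads, all with headDim = hiddenSize / numHeads, and the score scale passed to nn.Attention is 1/sqrt(headDim). The standalone Go program below just walks through that bookkeeping; the hiddenSize, numHeads, numKVHeads, and batchSize values are illustrative assumptions, not read from any GGUF, and nothing here depends on the ollama packages.

package main

import (
	"fmt"
	"math"
)

func main() {
	// Example values only; the real ones come from the GGUF config keys
	// embedding_length, attention.head_count, attention.head_count_kv.
	hiddenSize := 3584
	numHeads := 28
	numKVHeads := 4
	batchSize := 5 // tokens in the batch

	// Same derivation as SelfAttention.Forward.
	headDim := hiddenSize / numHeads

	// Q keeps every head; K and V keep only the KV heads, so groups of
	// numHeads/numKVHeads query heads share a single KV head.
	fmt.Printf("q shape: [%d, %d, %d]\n", headDim, numHeads, batchSize)
	fmt.Printf("k shape: [%d, %d, %d]\n", headDim, numKVHeads, batchSize)
	fmt.Printf("v shape: [%d, %d, %d]\n", headDim, numKVHeads, batchSize)
	fmt.Printf("query heads per KV head: %d\n", numHeads/numKVHeads)

	// Score scale, matching scaleFactor in the patch.
	fmt.Printf("scale: 1/sqrt(%d) = %.6f\n", headDim, 1.0/math.Sqrt(float64(headDim)))
}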
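
Also not part of the patch: MLP.Forward computes SwiGLU, i.e. down(SiLU(gate(x)) * up(x)) with an elementwise product between the gated and up projections. The dependency-free sketch below spells that composition out over plain slices with tiny made-up weights, purely as a reference for the tensor expression in the patch.

package main

import (
	"fmt"
	"math"
)

// silu(x) = x * sigmoid(x)
func silu(x float64) float64 { return x / (1 + math.Exp(-x)) }

// matVec multiplies an (out x in) weight matrix by a vector.
func matVec(w [][]float64, x []float64) []float64 {
	out := make([]float64, len(w))
	for i, row := range w {
		for j, wij := range row {
			out[i] += wij * x[j]
		}
	}
	return out
}

func main() {
	// Toy dimensions (hidden=2, intermediate=3); weights are arbitrary examples.
	gate := [][]float64{{0.5, -0.2}, {0.1, 0.3}, {-0.4, 0.8}}
	up := [][]float64{{1.0, 0.0}, {0.0, 1.0}, {0.5, 0.5}}
	down := [][]float64{{0.2, -0.1, 0.3}, {0.7, 0.4, -0.6}}

	x := []float64{1.0, 2.0}

	// SwiGLU gating: SiLU(gate(x)) multiplied elementwise by up(x)...
	g := matVec(gate, x)
	u := matVec(up, x)
	h := make([]float64, len(g))
	for i := range g {
		h[i] = silu(g[i]) * u[i]
	}

	// ...then projected back to the hidden dimension, as in mlp.Down.Forward.
	fmt.Println(matVec(down, h))
}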