From c1f9bcb4ddf7d3cb4e69dd0a4ededd96fabac559 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Wed, 2 Apr 2025 10:41:51 -0700
Subject: [PATCH] restructure image processing

Restructure image processing to add Qwen 2.5 VL support:

- add a converter for Qwen2_5_VLForConditionalGeneration models
  (convert_qwen25vl.go)
- split the qwen25vl model into separate text and vision components,
  including a vision model scaffold, a patch merger, and split patch
  embeddings
- add image processing (smart resize, normalization, patch extraction)
  and prompt processing, including post-tokenize handling of vision
  tokens
- add RoPEMulti (mrope) support to the ggml backend
- fix gguf padding
- server: do not attempt to parse the offset file as gguf. This logic
  was causing issues when importing a gguf that had some padding at the
  end of the file: the valid gguf would be read, but then the offset
  would be parsed as a separate gguf file, which does not seem right.
---
 convert/convert.go                          |   2 +
 convert/convert_qwen25vl.go                 | 188 +++++++++++++
 fs/ggml/gguf.go                             |   4 +
 kvcache/causal_test.go                      |   8 +-
 ml/backend.go                               |   1 +
 ml/backend/ggml/ggml.go                     |  80 ++++--
 model/model_external_test.go                |  51 ++++
 model/models/gemma3/model_text.go           |   2 +-
 model/models/mistral3/model_text.go         |  18 +-
 model/models/qwen25vl/model.go              | 282 ++++++++++----------
 model/models/qwen25vl/model_test.go         |  59 ++++
 model/models/qwen25vl/model_text.go         | 165 ++++++++++++
 model/models/qwen25vl/model_vision.go       | 260 ++++++++++++++++++
 model/models/qwen25vl/process_image.go      | 196 ++++++++++++++
 model/models/qwen25vl/process_image_test.go |  47 ++++
 model/models/qwen2vl/imageproc.go           |   2 +-
 server/create.go                            |  37 +--
 17 files changed, 1194 insertions(+), 208 deletions(-)
 create mode 100644 convert/convert_qwen25vl.go
 create mode 100644 model/model_external_test.go
 create mode 100644 model/models/qwen25vl/model_test.go
 create mode 100644 model/models/qwen25vl/model_text.go
 create mode 100644 model/models/qwen25vl/model_vision.go
 create mode 100644 model/models/qwen25vl/process_image.go
 create mode 100644 model/models/qwen25vl/process_image_test.go

diff --git a/convert/convert.go b/convert/convert.go
index 249ec8077..f4a428479 100644
--- a/convert/convert.go
+++ b/convert/convert.go
@@ -189,6 +189,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
 		conv = &phi3Model{}
 	case "Qwen2ForCausalLM":
 		conv = &qwen2Model{}
+	case "Qwen2_5_VLForConditionalGeneration":
+		conv = &qwen25vlModel{}
 	case "BertModel":
 		conv = &bertModel{}
 	case "CohereForCausalLM":
diff --git a/convert/convert_qwen25vl.go b/convert/convert_qwen25vl.go
new file mode 100644
index 000000000..48fa9f5fd
--- /dev/null
+++ b/convert/convert_qwen25vl.go
@@ -0,0 +1,188 @@
+package convert
+
+import (
+	"bytes"
+	"encoding/binary"
+	"io"
+	"log/slog"
+	"strings"
+
+	"github.com/ollama/ollama/fs/ggml"
+	"github.com/pdevine/tensor"
+	"github.com/pdevine/tensor/native"
+	"github.com/x448/float16"
+)
+
+type qwen25vlModel struct {
+	ModelParameters
+	HiddenSize            uint32  `json:"hidden_size"`
+	IntermediateSize      uint32  `json:"intermediate_size"`
+	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
+	NumAttentionHeads     uint32  `json:"num_attention_heads"`
+	HiddenLayers          uint32  `json:"num_hidden_layers"`
+	RopeTheta             float32 `json:"rope_theta"`
+	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
+	RMSNormEPS            float32 `json:"rms_norm_eps"`
+
+	VisionModel struct {
+		PatchSize uint32 `json:"patch_size"`
+		//HeadDim uint32 `json:"num_heads"`
+		//RopeTheta float32 `json:"rope_theta"`
+		HiddenSize       uint32 `json:"hidden_size"`
+		IntermediateSize uint32 `json:"intermediate_size"`
+		WindowSize       uint32 
`json:"window_size"` + } `json:"vision_config"` +} + +var _ ModelConverter = (*qwen25vlModel)(nil) + +func (q *qwen25vlModel) KV(t *Tokenizer) ggml.KV { + kv := q.ModelParameters.KV(t) + kv["general.architecture"] = "qwen25vl" + kv["qwen25vl.block_count"] = q.HiddenLayers + kv["qwen25vl.context_length"] = q.MaxPositionEmbeddings + kv["qwen25vl.embedding_length"] = q.HiddenSize + kv["qwen25vl.feed_forward_length"] = q.IntermediateSize + kv["qwen25vl.attention.head_count"] = q.NumAttentionHeads + kv["qwen25vl.attention.head_count_kv"] = q.NumKeyValueHeads + kv["qwen25vl.rope.freq_base"] = q.RopeTheta + kv["qwen25vl.attention.layer_norm_rms_epsilon"] = q.RMSNormEPS + + kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize + + return kv +} + +func (q *qwen25vlModel) Tensors(ts []Tensor) []ggml.Tensor { + var out []ggml.Tensor + + for _, t := range ts { + if strings.HasSuffix(t.Name(), "patch_embed.proj.weight") { + // var buf bytes.Buffer + // if _, err := t.WriteTo(&buf); err != nil { + // panic(err) + // } + // newTensors := splitPatchEmbed(buf, t.Kind(), t.Shape()) + // out = append(out, newTensors...) + // } else if strings.HasPrefix(t.Name(), "v.blk.") { + // skip + } else { + out = append(out, ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: t.Shape(), + WriterTo: t, + }) + } + } + + return out +} + +func (p *qwen25vlModel) Replacements() []string { + return []string{ + "lm_head", "output", + "model.embed_tokens", "token_embd", + "model.layers", "blk", + "visual.blocks", "v.blk", + "input_layernorm", "attn_norm", + "self_attn.k_proj", "attn_k", + "self_attn.v_proj", "attn_v", + "self_attn.q_proj", "attn_q", + "self_attn.o_proj", "attn_output", + "mlp.down_proj", "ffn_down", + "mlp.gate_proj", "ffn_gate", + "mlp.up_proj", "ffn_up", + "post_attention_layernorm", "ffn_norm", + "model.norm", "output_norm", + } +} + +func splitPatchEmbed(buf bytes.Buffer, kind uint32, shape []uint64) []ggml.Tensor { + slog.Debug("patch stuff", "kind", kind, "shape", shape) + + if kind != tensorKindF16 { + panic("tensor is of wrong type") + } + + if len(shape) != 5 || (len(shape) == 5 && shape[2] != 2) { + panic("wrong sized tensor") + } + + // determine the size of the tensor based on its shape + shapeToSize := func(s []int) int { + r := 1 + for _, n := range s { + r *= int(n) + } + return r + } + + // tensor.WithShape() wants []int + intShape := make([]int, len(shape)) + for i, v := range shape { + intShape[i] = int(v) + } + + u16s := make([]uint16, shapeToSize(intShape)) + if err := binary.Read(&buf, binary.LittleEndian, u16s); err != nil { + panic("bad read") + } + + f32s := make([]float32, len(u16s)) + for i := range u16s { + f32s[i] = float16.Frombits(u16s[i]).Float32() + } + + newTensors := []ggml.Tensor{} + + getDataFromSlice := func(f32s []float32, shape []int, s []tensor.Slice) patchEmbed { + slog.Debug("getDataFromSlice", "num f32s", len(f32s), "shape", shape) + n := tensor.New(tensor.WithShape(shape...), tensor.WithBacking(f32s)) + t, err := n.Slice(s...) 
+ if err != nil { + panic(err) + } + + ts, err := native.SelectF32(t.Materialize().(*tensor.Dense), 0) + if err != nil { + panic(err) + } + + slog.Debug("first vals", "val 1", ts[0][0], "val 2", ts[0][1], "val 3", ts[0][2]) + + var f16s patchEmbed + for _, row := range ts { + for _, col := range row { + f16s = append(f16s, float16.Fromfloat32(col).Bits()) + } + } + + return f16s + } + + p := getDataFromSlice(f32s, intShape, []tensor.Slice{nil, nil, tensor.S(0, 1, 1), nil, nil}) + newTensors = append(newTensors, ggml.Tensor{ + Name: "v.patch_embed.0.weight", + Kind: kind, + Shape: append(shape[:2], shape[3:]...), + WriterTo: p, + }) + + p = getDataFromSlice(f32s, intShape, []tensor.Slice{nil, nil, tensor.S(1, 2, 1), nil, nil}) + newTensors = append(newTensors, ggml.Tensor{ + Name: "v.patch_embed.1.weight", + Kind: kind, + Shape: append(shape[:2], shape[3:]...), + WriterTo: p, + }) + + return newTensors +} + +type patchEmbed []uint16 + +func (t patchEmbed) WriteTo(w io.Writer) (int64, error) { + err := binary.Write(w, binary.LittleEndian, t) + return 0, err +} diff --git a/fs/ggml/gguf.go b/fs/ggml/gguf.go index 8e75625e0..028785148 100644 --- a/fs/ggml/gguf.go +++ b/fs/ggml/gguf.go @@ -650,5 +650,9 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t *Tensor) error { } func ggufPadding(offset, align int64) int64 { + // if we already fit perfectly onto a 16 byte boundary, don't bother padding + if ((align-offset%align)%align)%16 == 0 { + return 0 + } return (align - offset%align) % align } diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index dc4b81ecc..566724fe3 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -617,10 +617,16 @@ func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, co func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { panic("not implemented") } +func (t *testTensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, sections [4]int, config ml.RoPEConfig) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { + panic("not implemented") +} func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") } func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") } -func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { panic("not implemented") } func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") } func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { panic("not implemented") } diff --git a/ml/backend.go b/ml/backend.go index 8af79f069..e4c90c5e7 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -193,6 +193,7 @@ type Tensor interface { IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor RoPE(ctx Context, positionIDs, ropeFactors Tensor, config RoPEConfig) Tensor + RoPEMulti(ctx Context, positionIDs, ropeFactors Tensor, sections [4]int, config RoPEConfig) Tensor Sin(ctx Context) Tensor Cos(ctx Context) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 5376d20ae..fa7bee7e7 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -1064,15 +1064,6 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { } } -// GGML RoPE types -// These are the types used in the C implementation of RoPE -const ( - ropeTypeNorm C.int = 0 - ropeTypeNeox C.int = 2 - ropeTypeMrope C.int = 8 - ropeTypeVision C.int = 24 -) - // RoPE applies Rotary Position Embeddings to the tensor func (t *Tensor) RoPE(ctx 
ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
 	if ropeFactors == nil {
@@ -1088,21 +1079,6 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config
 		config.YarnConfig = ml.DefaultYarnConfig(131072) // 131072 is the default for LLaMA, so it is common at the time of writing
 	}
 
-	// Map Go RopeType to C implementation constants
-	var ropeTypeC C.int
-	switch config.Type {
-	case ml.RopeTypeNormal:
-		ropeTypeC = ropeTypeNorm
-	case ml.RopeTypeNeox:
-		ropeTypeC = ropeTypeNeox
-	case ml.RopeTypeMRoPE:
-		ropeTypeC = ropeTypeMrope
-	case ml.RopeTypeVision:
-		ropeTypeC = ropeTypeVision
-	default:
-		ropeTypeC = ropeTypeNorm
-	}
-
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_rope_ext(
@@ -1111,7 +1087,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config
 			positionIDs.(*Tensor).t,
 			ropeFactors.(*Tensor).t,
 			C.int(config.Dim),
-			ropeTypeC,
+			ropeTypeToC(config.Type),
 			C.int(config.YarnCtxTrain),
 			C.float(config.Base),
 			C.float(config.Scale),
@@ -1129,6 +1105,60 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
 		t: C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32),
 	}
 }
+func (t *Tensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, sections [4]int, config ml.RoPEConfig) ml.Tensor {
+	if ropeFactors == nil {
+		ropeFactors = &Tensor{b: t.b}
+	}
+
+	dequant := t.t
+	if C.ggml_is_quantized(t.t._type) {
+		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
+	}
+
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_rope_multi(
+			ctx.(*Context).ctx,
+			dequant,
+			positionIDs.(*Tensor).t,
+			ropeFactors.(*Tensor).t,
+			C.int(config.Dim),
+			(*C.int)(unsafe.Pointer(&sections[0])),
+			ropeTypeToC(config.Type),
+			C.int(config.YarnCtxTrain),
+			C.float(config.Base),
+			C.float(config.Scale),
+			C.float(config.YarnExtFactor),
+			C.float(config.YarnAttnFactor),
+			C.float(config.YarnBetaFast),
+			C.float(config.YarnBetaSlow),
+		),
+	}
+}
+
+// GGML RoPE types
+// These are the types used in the C implementation of RoPE
+const (
+	ropeTypeNorm   C.int = 0
+	ropeTypeNeox   C.int = 2
+	ropeTypeMrope  C.int = 8
+	ropeTypeVision C.int = 24
+)
+
+func ropeTypeToC(ropeType ml.RopeType) C.int {
+	switch ropeType {
+	case ml.RopeTypeNormal:
+		return ropeTypeNorm
+	case ml.RopeTypeNeox:
+		return ropeTypeNeox
+	case ml.RopeTypeMRoPE:
+		return ropeTypeMrope
+	case ml.RopeTypeVision:
+		return ropeTypeVision
+	default:
+		return ropeTypeNorm
+	}
+}
 
 func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
 	return &Tensor{
diff --git a/model/model_external_test.go b/model/model_external_test.go
new file mode 100644
index 000000000..950c522bc
--- /dev/null
+++ b/model/model_external_test.go
@@ -0,0 +1,51 @@
+package model
+
+import (
+	"context"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/ollama/ollama/ml"
+)
+
+func setup(t *testing.T) ml.Backend {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	models := filepath.Join(home, ".ollama", "models")
+
+	b, err := New(context.TODO(), filepath.Join(models, "blobs", "sha256-667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"), ml.BackendParams{NumGPULayers: 99})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	return b
+}
+
+func TestUnfoldConv(t *testing.T) {
+	b := setup(t)
+	ctx := b.NewContext().Input()
+	t.Cleanup(func() { ctx.Close() })
+
+	tiles, channels, height, width := 5, 3, 336, 336
+	patchSize := 14
+
+	tt := ctx.Arange(0, 
float32(tiles*channels*height*width), 1, ml.DTypeF32).Reshape(ctx, width, height, channels, tiles) + t.Log("tt", tt.Shape()) + t.Log(ml.Dump(ctx, tt)) + + kernel := ctx.Empty(ml.DTypeF32, patchSize, patchSize, channels) + t.Log("kernel", kernel.Shape()) + t.Log(ml.Dump(ctx, kernel)) + + tt = kernel.IM2Col(ctx, tt, patchSize, patchSize, 0, 0, 1, 1) + t.Log("tt", tt.Shape()) + t.Log(ml.Dump(ctx, tt)) + + tt = tt.Reshape(ctx, tt.Dim(0), tt.Dim(1)*tt.Dim(2), tt.Dim(3)) + t.Log("tt", tt.Shape()) + t.Log(ml.Dump(ctx, tt)) +} diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index b05acbcad..218d84c02 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -57,7 +57,7 @@ func newTextModel(c fs.Config) *TextModel { }, ), Layers: make([]TextLayer, numBlocks), - TextOptions: &TextOptions{ + TextConfig: &TextConfig{ hiddenSize: int(c.Uint("embedding_length")), numHeads: int(c.Uint("attention.head_count")), numKVHeads: int(c.Uint("attention.head_count_kv")), diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index 1bf72acd8..250f13eee 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -17,6 +17,7 @@ type TextOptions struct { hiddenSize, numHeads, numKVHeads, headDim int eps, ropeBase, ropeScale float32 ropeDim uint32 + ropeConfig ml.RoPEConfig } type TextModel struct { @@ -40,7 +41,6 @@ type SelfAttention struct { func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { batchSize := hiddenState.Dim(1) - ropeType := uint32(0) headDim := opts.headDim if headDim == 0 { headDim = opts.hiddenSize / opts.numHeads @@ -48,11 +48,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -63,7 +63,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil + return key.RoPE(ctx, shift, nil, m.TextOptions.ropeConfig), nil } type MLP struct { @@ -167,9 +167,13 @@ func NewTextModel(c fs.Config) (*TextModel, error) { numKVHeads: int(c.Uint("attention.head_count_kv")), headDim: int(c.Uint("attention.key_length")), eps: c.Float("attention.layer_norm_rms_epsilon"), - ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), - ropeDim: c.Uint("rope.dimension_count"), + ropeConfig: ml.RoPEConfig{ + Base: c.Float("rope.freq_base", 10000.0), + Scale: c.Float("rope.freq_scale", 1.0), + Dim: c.Uint("rope.dimension_count"), + Type: ml.RopeTypeNormal, + YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))), + }, }, } diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index e974e2a46..853d86dc9 100644 --- a/model/models/qwen25vl/model.go 
+++ b/model/models/qwen25vl/model.go @@ -1,10 +1,11 @@ package qwen25vl import ( + "bytes" "fmt" - "math" - "strings" + "image" + "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" @@ -12,147 +13,151 @@ import ( "github.com/ollama/ollama/model/input" ) -type Options struct { - ctxLen, hiddenSize, numHeads, numKVHeads int - eps float32 - ropeConfig ml.RoPEConfig -} - type Model struct { model.Base - model.BytePairEncoding + *TextModel + *VisionModel `gguf:"v,vision"` + *PatchMerger `gguf:"mm"` - TokenEmbedding *nn.Embedding `gguf:"token_embd"` - Layers []Layer `gguf:"blk"` - OutputNorm *nn.RMSNorm `gguf:"output_norm"` - Output *nn.Linear `gguf:"output,alt:token_embd"` - - *Options + ImageProcessor } -func New(c ml.Config) (model.Model, error) { - if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") { - return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model")) +// Implement MultimodalProcessor interface +var _ model.MultimodalProcessor = (*Model)(nil) + +type PatchMerger struct { + MLPLayer1 *nn.Linear `gguf:"0"` + MLPLayer2 *nn.Linear `gguf:"2"` +} + +// Forward computes patch merging for the vision model +func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor { + // Get dimensions + hiddenSize := visionOutputs.Dim(0) + numPositions := visionOutputs.Dim(1) + batchSize := visionOutputs.Dim(2) + + reshaped := visionOutputs.Reshape(ctx, hiddenSize*4, numPositions/4, batchSize) + + // Apply first linear layer (mm_0_w, mm_0_b) + hidden := pm.MLPLayer1.Forward(ctx, reshaped) + + activated := hidden.GELU(ctx) + + // Apply second linear layer (mm_1_w, mm_1_b) + output := pm.MLPLayer2.Forward(ctx, activated) + + return output +} + +func New(c fs.Config) (model.Model, error) { + m := &Model{ + TextModel: NewTextModel(c), + VisionModel: newVisionModel(c), + ImageProcessor: newImageProcessor(c), } - m := Model{ - BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), - &model.Vocabulary{ - Values: c.Strings("tokenizer.ggml.tokens"), - Types: c.Uints("tokenizer.ggml.token_type"), - Merges: c.Strings("tokenizer.ggml.merges"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), - AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), - AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - }, - ), - Layers: make([]Layer, c.Uint("block_count")), - Options: &Options{ - ctxLen: int(c.Uint("context_length")), - hiddenSize: int(c.Uint("embedding_length")), - numHeads: int(c.Uint("attention.head_count")), - numKVHeads: int(c.Uint("attention.head_count_kv")), - eps: c.Float("attention.layer_norm_rms_epsilon"), - ropeConfig: ml.RoPEConfig{ - Base: c.Float("rope.freq_base"), - Scale: c.Float("rope.freq_scale", 1), - Dim: c.Uint("rope.dimension_count", 128), - Type: ml.RopeTypeNeox, - YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 32768))), - }, - }, + m.Cache = kvcache.NewCausalCache(m.TextModel.Shift) + + return m, nil +} + +type imageFeatures struct { + Tensor ml.Tensor + GridT int + GridH int + GridW int +} + +func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { + if len(m.VisionModel.Layers) == 0 { + return nil, model.ErrNoVisionModel } - m.Cache = kvcache.NewCausalCache(m.Shift) - - return &m, nil -} - 
-// SelfAttention implements the multi-head self-attention mechanism -// with separate projections for query, key, value and output transformations -type SelfAttention struct { - Query *nn.Linear `gguf:"attn_q"` - Key *nn.Linear `gguf:"attn_k"` - Value *nn.Linear `gguf:"attn_v"` - Output *nn.Linear `gguf:"attn_output"` - RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` -} - -func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { - batchSize := hiddenState.Dim(1) - headDim := opts.hiddenSize / opts.numHeads - - q := sa.Query.Forward(ctx, hiddenState) - q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig) - - k := sa.Key.Forward(ctx, hiddenState) - k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig) - - v := sa.Value.Forward(ctx, hiddenState) - v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - - scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) - kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize) - - return sa.Output.Forward(ctx, kqv) -} - -// Shift applies rotary position embeddings to the key tensor for causal attention caching -func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil -} - -// MLP implements the feed-forward network component with SwiGLU activation -type MLP struct { - Up *nn.Linear `gguf:"ffn_up"` - Down *nn.Linear `gguf:"ffn_down"` - Gate *nn.Linear `gguf:"ffn_gate"` -} - -func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { - // Apply SwiGLU activation gating - hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) - // Project back to hidden dimension - return mlp.Down.Forward(ctx, hiddenState) -} - -// Layer represents a single transformer layer combining self-attention and feed-forward components -type Layer struct { - AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` - SelfAttention *SelfAttention - MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` - MLP *MLP -} - -func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { - // Self-attention branch with residual connection - residual := hiddenState - - hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) - hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts) - - // In the final layer (outputs != nil), optimize by pruning to just the token positions - // we need logits for. 
- if outputs != nil { - hiddenState = hiddenState.Rows(ctx, outputs) - residual = residual.Rows(ctx, outputs) + image, _, err := image.Decode(bytes.NewReader(multimodalData)) + if err != nil { + return nil, err } - hiddenState = hiddenState.Add(ctx, residual) - // Feed-forward branch with residual connection - residual = hiddenState - hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps) - hiddenState = l.MLP.Forward(ctx, hiddenState, opts) - return hiddenState.Add(ctx, residual) + f32s, gridT, gridH, gridW, err := m.ImageProcessor.ProcessImage(image) + if err != nil { + return nil, err + } + + // Calculate tensor dimensions + patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize * + m.ImageProcessor.patchSize * m.ImageProcessor.patchSize + numPatches := gridT * gridH * gridW + + pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches) + if err != nil { + return nil, fmt.Errorf("failed to create tensor from image: %w", err) + } + + visionOutputs := m.VisionModel.Forward(ctx, pixelValues) + visionOutputs = m.PatchMerger.Forward(ctx, visionOutputs, m.VisionModel.eps) + + return &imageFeatures{ + Tensor: visionOutputs, + GridT: gridT, + GridH: gridH, + GridW: gridW, + }, nil +} + +// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass +func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { + var result []input.Input + + // Get image token IDs from config + imageToken := 151655 + visionStartToken := 151652 + visionEndToken := 151653 + + // Get merge size from config + mergeSize := m.ImageProcessor.mergeSize + + for _, inp := range inputs { + if inp.Multimodal == nil { + // If not a multimodal input, add it to the result unchanged + result = append(result, inp) + } else { + // This is an image token with multimodal data + features := inp.Multimodal.(*imageFeatures) + + // Get grid dimensions from the features + gridT := features.GridT + gridH := features.GridH + gridW := features.GridW + + // Calculate tokens per grid based on grid dimensions + mergeLength := mergeSize * mergeSize + gridProduct := gridT * gridH * gridW + tokensPerGrid := gridProduct / mergeLength + + // First add the vision start token + result = append(result, input.Input{Token: int32(visionStartToken)}) + + // Add the image token with the multimodal tensor data at the first position + result = append(result, input.Input{ + Token: int32(imageToken), + Multimodal: features.Tensor, + MultimodalHash: inp.MultimodalHash, + }) + + // Add the placeholder tokens for the remaining positions (tokensPerGrid-1) + for range tokensPerGrid - 1 { + result = append(result, input.Input{Token: int32(imageToken)}) + } + + result = append(result, input.Input{Token: int32(visionEndToken)}) + } + } + + return result, nil } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - // Convert input tokens and positions to tensors positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) if err != nil { return nil, err @@ -163,25 +168,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { return nil, err } - // Initial token embedding - hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) - - // Process through transformer layers - for i, layer := range m.Layers { - m.Cache.SetLayer(i) - - var lastLayerOutputs ml.Tensor - if i == len(m.Layers)-1 { - lastLayerOutputs = outputs - } - - hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options) - } - - 
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) - return m.Output.Forward(ctx, hiddenState), nil + return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache) } func init() { + model.Register("qwen25vl", New) model.Register("qwen2vl", New) } diff --git a/model/models/qwen25vl/model_test.go b/model/models/qwen25vl/model_test.go new file mode 100644 index 000000000..b9e590a90 --- /dev/null +++ b/model/models/qwen25vl/model_test.go @@ -0,0 +1,59 @@ +package qwen25vl + +import ( + "testing" + + "github.com/ollama/ollama/ml/backend/ggml" + "github.com/ollama/ollama/model/input" +) + +func TestPostTokenize(t *testing.T) { + // Set up test inputs + model := &Model{} + mockHash := uint64(12345678) + + inputs := []input.Input{ + {Token: 123}, // Regular token + {Token: 456}, // Regular token + {Token: 151655, Multimodal: &ggml.Tensor{}, MultimodalHash: mockHash}, // Image token + {Token: 789}, // Regular token + } + + // Run the function being tested + result, err := model.PostTokenize(inputs) + if err != nil { + t.Fatalf("PostTokenize returned error: %v", err) + } + + // Verify the actual length first + expectedLength := 21 + if len(result) != expectedLength { + t.Fatalf("Result has wrong length: got %d, expected %d", len(result), expectedLength) + } + + // Check key positions only + checkPositions := map[int]int32{ + 0: 123, // First regular token + 1: 456, // Second regular token + 2: 151652, // Vision start token + 4: 151655, // First placeholder token + 19: 151653, // Vision end token + 20: 789, // Final regular token + } + + for pos, expectedToken := range checkPositions { + if pos >= len(result) { + t.Errorf("Position %d is out of bounds (result length: %d)", pos, len(result)) + continue + } + if result[pos].Token != expectedToken { + t.Errorf("Position %d: expected token %d, got %d", pos, expectedToken, result[pos].Token) + } + } + + // Check multimodal data is preserved + if result[3].MultimodalHash != mockHash { + t.Errorf("Multimodal hash not preserved: got %d, expected %d", + result[3].MultimodalHash, mockHash) + } +} diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go new file mode 100644 index 000000000..3a10e3424 --- /dev/null +++ b/model/models/qwen25vl/model_text.go @@ -0,0 +1,165 @@ +package qwen25vl + +import ( + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type TextOptions struct { + ctxLen, hiddenSize, numHeads, numKVHeads int + eps float32 + ropeConfig ml.RoPEConfig +} + +type TextModel struct { + model.Base + model.BytePairEncoding + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []Layer `gguf:"blk"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + *TextOptions +} + +func NewTextModel(c fs.Config) *TextModel { + m := TextModel{ + BytePairEncoding: model.NewBytePairEncoding( + c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Uints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false), + EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + AddEOS: 
c.Bool("tokenizer.ggml.add_eos_token", false), + }, + ), + Layers: make([]Layer, c.Uint("block_count")), + TextOptions: &TextOptions{ + ctxLen: int(c.Uint("context_length")), + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeConfig: ml.RoPEConfig{ + Base: c.Float("rope.freq_base"), + Scale: c.Float("rope.freq_scale", 1), + Dim: c.Uint("rope.dimension_count", 128), + Type: ml.RopeTypeNeox, + YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 128000))), + }, + }, + } + + return &m +} + +// SelfAttention implements the multi-head self-attention mechanism +// with separate projections for query, key, value and output transformations +type SelfAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` + RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` +} + +func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { + batchSize := hiddenState.Dim(1) + headDim := opts.hiddenSize / opts.numHeads + + q := sa.Query.Forward(ctx, hiddenState) + q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) + q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig) + + k := sa.Key.Forward(ctx, hiddenState) + k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig) + + v := sa.Value.Forward(ctx, hiddenState) + v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + + scaleFactor := 1.0 / math.Sqrt(float64(headDim)) + kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) + kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize) + + return sa.Output.Forward(ctx, kqv) +} + +// Shift applies rotary position embeddings to the key tensor for causal attention caching +func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil +} + +// MLP implements the feed-forward network component with SwiGLU activation +type MLP struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` + Gate *nn.Linear `gguf:"ffn_gate"` +} + +func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { + // Apply SwiGLU activation gating + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + // Project back to hidden dimension + return mlp.Down.Forward(ctx, hiddenState) +} + +// Layer represents a single transformer layer combining self-attention and feed-forward components +type Layer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + SelfAttention *SelfAttention + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP *MLP +} + +func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { + // Self-attention branch with residual connection + residual := hiddenState + + hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts) + + // In the final layer (outputs != nil), optimize by pruning to just the token positions + // we need logits for. 
+ if outputs != nil { + hiddenState = hiddenState.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenState = hiddenState.Add(ctx, residual) + // Feed-forward branch with residual connection + residual = hiddenState + hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.MLP.Forward(ctx, hiddenState, opts) + return hiddenState.Add(ctx, residual) +} + +func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) { + // Initial token embedding + hiddenState := m.TokenEmbedding.Forward(ctx, inputs) + + // Process through transformer layers + for i, layer := range m.Layers { + cache.SetLayer(i) + + var lastLayerOutputs ml.Tensor + if i == len(m.Layers)-1 { + lastLayerOutputs = outputs + } + + hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions) + } + + hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) + return m.Output.Forward(ctx, hiddenState), nil +} diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go new file mode 100644 index 000000000..e8375c0b5 --- /dev/null +++ b/model/models/qwen25vl/model_vision.go @@ -0,0 +1,260 @@ +package qwen25vl + +import ( + "fmt" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" +) + +var batchSize int = 1 + +// VisionSelfAttention implements self-attention for the Qwen vision model +type VisionSelfAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_out"` +} + +// Forward computes self-attention for the vision model +func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor { + query := sa.Query.Forward(ctx, hiddenStates) + key := sa.Key.Forward(ctx, hiddenStates) + value := sa.Value.Forward(ctx, hiddenStates) + + query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize) + key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize) + value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize) + + config := ml.RoPEConfig{ + Dim: uint32(opts.headDim / 2), + Type: ml.RopeTypeMRoPE, + Base: opts.ropeTheta, + Scale: 1.0, + YarnConfig: ml.DefaultYarnConfig(128000), + } + + query = query.RoPEMulti( + ctx, + positionIDs, + nil, + [4]int{opts.headDim / 4, opts.headDim / 4, opts.headDim / 4, opts.headDim / 4}, + config, + ) + key = key.RoPEMulti( + ctx, + positionIDs, + nil, + [4]int{opts.headDim / 4, opts.headDim / 4, opts.headDim / 4, opts.headDim / 4}, + config, + ) + + // Scale factor for scaled dot-product attention + scale := 1.0 / math.Sqrt(float64(opts.headDim)) + + attention := nn.Attention(ctx, query, key, value, scale, nil) + attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) + + return sa.Output.Forward(ctx, attention) +} + +// VisionMLP implements the MLP for the Qwen vision model +type VisionMLP struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +// Forward computes the MLP for the vision model +func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor { + // Using GEGLU activation: (Gate * Up) * GELU(Gate) + gateOutput := mlp.Gate.Forward(ctx, hiddenStates) + upOutput := mlp.Up.Forward(ctx, hiddenStates) 
+ hiddenStates = gateOutput.GELU(ctx).Mul(ctx, upOutput) + + return mlp.Down.Forward(ctx, hiddenStates) +} + +// VisionEncoderLayer implements an encoder layer for the Qwen vision model +type VisionEncoderLayer struct { + Norm1 *nn.RMSNorm `gguf:"ln1"` + Norm2 *nn.RMSNorm `gguf:"ln2"` + SelfAttention *VisionSelfAttention + MLP *VisionMLP +} + +// Forward computes an encoder layer for the vision model +func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor { + residual := hiddenStates + hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, positionIDs, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + + residual = hiddenStates + hiddenStates = e.Norm2.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts) + return hiddenStates.Add(ctx, residual) +} + +// VisionModelOptions contains configuration options for the Qwen vision model +type VisionModelOptions struct { + hiddenSize int + numHeads int + headDim int + intermediateSize int + imageSize int + patchSize int + numChannels int + eps float32 + ropeTheta float32 + outHiddenSize int +} + +type PatchEmbedding struct { + PatchConv0 *nn.Conv2D `gguf:"patch_embd_0"` + PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"` +} + +func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, patchSize int) ml.Tensor { + shape := pixelValues.Shape() + numChannels := 3 + temporalPatchSize := 2 + embedDim := 1280 + numPatches := shape[1] / temporalPatchSize + + // Split the input tensor into two temporal slices and process each separately + // First temporal slice (frame 0) + slice0 := pixelValues.View(ctx, 0, patchSize*patchSize*numChannels, 0, numPatches, 0, 1).Contiguous(ctx) + reshaped0 := slice0.Reshape(ctx, + patchSize, // height + patchSize, // width + numChannels, // channels + numPatches) // batch + + // Second temporal slice (frame 1) + slice1 := pixelValues.View(ctx, 0, patchSize*patchSize*numChannels, 0, numPatches, 1, 1).Contiguous(ctx) + reshaped1 := slice1.Reshape(ctx, + patchSize, // height + patchSize, // width + numChannels, // channels + numPatches) // batch + + // Apply the appropriate convolution to each temporal slice + // PatchConv0 corresponds to weights for temporal frame 0 + // PatchConv1 corresponds to weights for temporal frame 1 + s0, s1 := patchSize, patchSize // Use full stride as in original + p0, p1 := 0, 0 // padding + d0, d1 := 1, 1 // dilation + + output0 := pe.PatchConv0.Forward(ctx, reshaped0, s0, s1, p0, p1, d0, d1) + output1 := pe.PatchConv1.Forward(ctx, reshaped1, s0, s1, p0, p1, d0, d1) + + // Add the outputs from the two temporal convolutions + combined := output0.Add(ctx, output1) + + // Reshape to required output dimensions + result := combined.Reshape(ctx, embedDim, numPatches) + + fmt.Println(ml.Dump(ctx, result)) + + return result +} + +// VisionPatchMerger implements patch merging for the Qwen vision model +type VisionPatchMerger struct { + LNQ *nn.RMSNorm `gguf:"ln_q"` + MLP *nn.Linear `gguf:"mlp"` +} + +// Forward computes patch merging for the vision model +func (pm *VisionPatchMerger) Forward(ctx ml.Context, x ml.Tensor, outDim, contextDim, spatialMergeSize int) ml.Tensor { + hiddenSize := contextDim * (spatialMergeSize * spatialMergeSize) + + // Normalize and reshape + x = pm.LNQ.Forward(ctx, x, 1e-6) + x = x.Reshape(ctx, -1, hiddenSize) + + // Apply MLP for merging + x = pm.MLP.Forward(ctx, x) + + return x +} 
+ +// VisionModel implements the Qwen vision model +type VisionModel struct { + PatchEmbedding *PatchEmbedding + Layers []VisionEncoderLayer `gguf:"blk"` + PostLayerNorm *nn.LayerNorm `gguf:"post_ln"` + PatchMerger *VisionPatchMerger `gguf:"patch_merger"` + + *VisionModelOptions +} + +// Forward computes the vision model for an input tensor +func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { + // Calculate position IDs for 2D RoPE + numPatchesH := pixelValues.Dim(0) / m.patchSize + numPatchesW := pixelValues.Dim(1) / m.patchSize + numPatches := numPatchesH * numPatchesW + + // Extract patch embeddings + hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize) + + // Create position IDs - for Qwen2VL mRoPE we need 4 values per position + // The format needed is specified in the C++ code as "mrope expecting 4 position ids per token" + positions := make([]int32, numPatches*4) + + for h := 0; h < numPatchesH; h++ { + for w := 0; w < numPatchesW; w++ { + idx := h*numPatchesW + w + // For each position, store both h and w coordinates twice + // This matches the pattern seen in the C++ implementation + positions[idx*4] = int32(h) // y coordinate + positions[idx*4+1] = int32(w) // x coordinate + positions[idx*4+2] = int32(h) // y coordinate (repeated) + positions[idx*4+3] = int32(w) // x coordinate (repeated) + } + } + + // Create the position IDs tensor with correct dimensions + positionIDs, err := ctx.Input().FromIntSlice(positions, numPatches*4) + if err != nil { + panic(err) + } + + // Apply encoder layers + for _, layer := range m.Layers { + hiddenStates = layer.Forward(ctx, hiddenStates, positionIDs, m.VisionModelOptions) + } + + hiddenStates = m.PostLayerNorm.Forward(ctx, hiddenStates, m.eps) + return hiddenStates +} + +// newVisionModel creates a new instance of the Qwen vision model +func newVisionModel(c fs.Config) *VisionModel { + patchSize := int(c.Uint("vision.patch_size", 14)) + hiddenSize := int(c.Uint("vision.embedding_length", 1280)) + ropeTheta := c.Float("vision.rope_theta", 10000.0) // not set + outHiddenSize := int(c.Uint("vision.out_embedding_length", 0)) // not set + numHeads := int(c.Uint("vision.attention.head_count", 16)) + + return &VisionModel{ + Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)), + VisionModelOptions: &VisionModelOptions{ + hiddenSize: hiddenSize, + numHeads: numHeads, + headDim: hiddenSize / numHeads, + intermediateSize: int(c.Uint("vision.feed_forward_length", 0)), + imageSize: int(c.Uint("vision.image_size", 560)), + patchSize: patchSize, + numChannels: int(c.Uint("vision.num_channels", 3)), // not set + eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6), + ropeTheta: ropeTheta, + outHiddenSize: outHiddenSize, + }, + } +} diff --git a/model/models/qwen25vl/process_image.go b/model/models/qwen25vl/process_image.go new file mode 100644 index 000000000..093c28a77 --- /dev/null +++ b/model/models/qwen25vl/process_image.go @@ -0,0 +1,196 @@ +package qwen25vl + +import ( + "fmt" + "image" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/model/imageproc" +) + +// ImageProcessor contains configuration for the Qwen 2.5 VL image processing +type ImageProcessor struct { + imageSize int + numChannels int + patchSize int + temporalPatchSize int + mergeSize int + minPixels int + maxPixels int + factor int + rescaleFactor float32 + imageMean []float32 + imageStd []float32 +} + +// newImageProcessor creates a new image processor with default values +func newImageProcessor(c 
fs.Config) ImageProcessor { + + patchSize := int(c.Uint("vision.patch_size", 14)) + mergeSize := int(c.Uint("vision.spatial_merge_size", 2)) + + return ImageProcessor{ + imageSize: int(c.Uint("vision.image_size", 560)), + numChannels: 3, + patchSize: patchSize, + temporalPatchSize: 2, + mergeSize: mergeSize, + minPixels: 56 * 56, + maxPixels: 28 * 28 * 4 * 1280, + factor: patchSize * mergeSize, + rescaleFactor: 1.0 / 255.0, + imageMean: []float32{0.48145466, 0.4578275, 0.40821073}, + imageStd: []float32{0.26862954, 0.26130258, 0.27577711}, + } +} + +// SmartResize implements the smart resize algorithm +func (p *ImageProcessor) SmartResize(height, width int) (int, int) { + factor := p.factor + + if height < factor || width < factor { + panic(fmt.Sprintf("height:%d or width:%d must be larger than factor:%d", height, width, factor)) + } else if float64(max(height, width))/float64(min(height, width)) > 200 { + aspectRatio := float64(max(height, width)) / float64(min(height, width)) + panic(fmt.Sprintf("absolute aspect ratio must be smaller than 200, got %f", aspectRatio)) + } + + round := func(x float64) int { + return int(math.Round(x)) + } + hBar := round(float64(height)/float64(factor)) * factor + wBar := round(float64(width)/float64(factor)) * factor + + if hBar*wBar > p.maxPixels { + beta := math.Sqrt(float64(height*width) / float64(p.maxPixels)) + + hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor + wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor + } else if hBar*wBar < p.minPixels { + beta := math.Sqrt(float64(p.minPixels) / float64(height*width)) + + hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor + wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor + } + + return hBar, wBar +} + +func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, int, error) { + origWidth := img.Bounds().Dx() + origHeight := img.Bounds().Dy() + + // Calculate smart resize dimensions + resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth) + + // Resize image using existing functions + resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear) + + normalizedPixels := imageproc.Normalize( + resizedImg, + [3]float32{p.imageMean[0], p.imageMean[1], p.imageMean[2]}, + [3]float32{p.imageStd[0], p.imageStd[1], p.imageStd[2]}, + true, // rescale + true, // channelFirst + ) + + // Calculate grid dimensions + gridH := resizedHeight / p.patchSize + gridW := resizedWidth / p.patchSize + gridT := 1 // For single images, temporal dimension is 1 + + patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, gridH, gridW, gridT) + if err != nil { + return nil, 0, 0, 0, fmt.Errorf("failed to create patches: %v", err) + } + + // Return patches and grid dimensions + return patches, gridT, gridH, gridW, nil +} + +func (p *ImageProcessor) createPatches(pixels []float32, height, width, gridH, gridW, gridT int) ([]float32, error) { + channels := p.numChannels + patchSize := p.patchSize + mergeSize := p.mergeSize + temporalPatchSize := p.temporalPatchSize + + // Calculate output dimensions + numPatches := gridT * gridH * gridW + patchDim := channels * temporalPatchSize * patchSize * patchSize + + // Create output tensor + result := make([]float32, numPatches*patchDim) + + // Instead of the complex 9D reshape+transpose, directly extract patches + // in the format expected by the forward pass + patchIndex := 0 + + for t := 0; t < gridT; t++ { + // For each 
patch in the grid + for h := 0; h < gridH; h += mergeSize { + for w := 0; w < gridW; w += mergeSize { + // Handle the 2x2 merged patches + for mh := 0; mh < mergeSize; mh++ { + for mw := 0; mw < mergeSize; mw++ { + // For each pixel in the patch + for py := 0; py < patchSize; py++ { + for px := 0; px < patchSize; px++ { + // Calculate source coordinates + y := (h+mh)*patchSize + py + x := (w+mw)*patchSize + px + + // For each channel + for c := 0; c < channels; c++ { + // Channel-first format (CHW) + srcIdx := c*height*width + y*width + x + + // Calculate destination index based on the expected layout + // This is the key part that matches what the model expects + dstIdx := patchIndex*patchDim + + (c * temporalPatchSize * patchSize * patchSize) + + (0 * patchSize * patchSize) + // temporal dim + (py * patchSize) + + px + + if srcIdx < len(pixels) && dstIdx < len(result) { + result[dstIdx] = pixels[srcIdx] + } + } + } + } + + // Handle temporal dimension padding (if needed) + for tp := 1; tp < temporalPatchSize; tp++ { + for py := 0; py < patchSize; py++ { + for px := 0; px < patchSize; px++ { + for c := 0; c < channels; c++ { + srcIdx := patchIndex*patchDim + + (c * temporalPatchSize * patchSize * patchSize) + + (0 * patchSize * patchSize) + // first temporal frame + (py * patchSize) + + px + + dstIdx := patchIndex*patchDim + + (c * temporalPatchSize * patchSize * patchSize) + + (tp * patchSize * patchSize) + // current temporal frame + (py * patchSize) + + px + + if srcIdx < len(result) && dstIdx < len(result) { + result[dstIdx] = result[srcIdx] // Copy from first frame + } + } + } + } + } + + patchIndex++ + } + } + } + } + } + + return result, nil +} diff --git a/model/models/qwen25vl/process_image_test.go b/model/models/qwen25vl/process_image_test.go new file mode 100644 index 000000000..67468d7d6 --- /dev/null +++ b/model/models/qwen25vl/process_image_test.go @@ -0,0 +1,47 @@ +package qwen25vl + +import ( + "image" + _ "image/jpeg" // Register JPEG decoder + "testing" +) + +func TestSmartResize(t *testing.T) { + type smartResizeCase struct { + TestImage image.Image + Expected image.Point + } + + // Create an image processor with default values + processor := ImageProcessor{ + imageSize: 560, // Example value + numChannels: 3, + factor: 28, + minPixels: 56 * 56, + maxPixels: 14 * 14 * 4 * 1280, + } + + cases := []smartResizeCase{ + { + TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 1024)), + Expected: image.Point{980, 980}, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 768)), + Expected: image.Point{1036, 756}, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 2000, 2000)), + Expected: image.Point{980, 980}, + }, + } + + for _, c := range cases { + b := c.TestImage.Bounds().Max + x, y := processor.SmartResize(b.X, b.Y) + actual := image.Point{x, y} + if actual != c.Expected { + t.Errorf("expected: %v, actual: %v", c.Expected, actual) + } + } +} diff --git a/model/models/qwen2vl/imageproc.go b/model/models/qwen2vl/imageproc.go index 964b39072..82abf7321 100644 --- a/model/models/qwen2vl/imageproc.go +++ b/model/models/qwen2vl/imageproc.go @@ -14,7 +14,7 @@ import ( const ( DefaultFactor = 28 DefaultMinPixels = 56 * 56 - DefaultMaxPixels = 14 * 14 * 4 * 1280 + DefaultMaxPixels = 14 * 14 * 4 * 1280 // TODO: might need to change ) // smartResize calculates the size of the image to resize to based on the diff --git a/server/create.go b/server/create.go index 41c8731cc..810322a79 100644 --- a/server/create.go +++ b/server/create.go @@ -514,35 +514,18 @@ func ggufLayers(digest 
string, fn func(resp api.ProgressResponse)) ([]*layerGGML } else if err != nil { return nil, err } - - mediatype := "application/vnd.ollama.image.model" - if f.KV().Kind() == "adapter" { - mediatype = "application/vnd.ollama.image.adapter" - } else if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok || f.KV().Kind() == "projector" { - mediatype = "application/vnd.ollama.image.projector" - } - - var layer Layer - if digest != "" && n == stat.Size() && offset == 0 { - layer, err = NewLayerFromLayer(digest, mediatype, blob.Name()) - if err != nil { - slog.Debug("could not create new layer from layer", "error", err) - return nil, err - } - } - - // Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size()) - if layer.Digest == "" { - layer, err = NewLayer(io.NewSectionReader(blob, offset, n), mediatype) - if err != nil { - return nil, err - } - } - - layers = append(layers, &layerGGML{layer, f}) - offset = n } + // Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size()) + if layer.Digest == "" { + layer, err = NewLayer(io.NewSectionReader(blob, 0, n), mediatype) + if err != nil { + return nil, err + } + } + + layers = append(layers, &layerGGML{layer, f}) + return detectChatTemplate(layers) }