restructure

image processing

Update model.go

Update model.go

Update model.go

no projector

no projector

vision model scaffold

...

...

wip

...

rebase

fix patch merger

tidy

...

Update model_vision.go

server: do not attempt to parse offset file as gguf

This logic was causing issues for me when importing a GGUF that had some padding at the end of the file. The valid GGUF would be read, but then the code would try to parse the remaining bytes at the offset as a separate GGUF file. This does not seem right.
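For illustration only (none of the names below are the real server code; decodeGGUF and loadAll are hypothetical stand-ins): the kind of loop the message describes keeps decoding at the next offset until end of file, so trailing alignment padding gets handed to the GGUF parser as if it were another file, whereas decoding once and ignoring a padding-only remainder avoids the error.

package main

import (
	"errors"
	"io"
)

// decodeGGUF is a hypothetical stand-in for a GGUF decoder: it parses one GGUF
// starting at offset and reports how many bytes it consumed.
func decodeGGUF(blob io.ReaderAt, offset int64) (int64, error) {
	return 0, errors.New("not implemented")
}

// loadAll sketches the problematic pattern: looping until EOF means a file that
// ends with padding is re-parsed at the padding offset and fails, even though a
// complete, valid GGUF was already read.
func loadAll(blob io.ReaderAt, size int64) error {
	var offset int64
	for offset < size {
		n, err := decodeGGUF(blob, offset)
		if err != nil {
			return err // triggered by the trailing padding
		}
		offset += n
	}
	return nil
}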

Update process_image_test.go

apply norm

prompt processing

prompt processing

fix post tokenize

fix gguf padding + populate the split patch embeddings

...

...

another shot at patch embeddings

...

patch embedding

Update model_vision.go

split pixels
This commit is contained in:
Bruce MacDonald 2025-04-02 10:41:51 -07:00
parent 198b1e6db9
commit c1f9bcb4dd
17 changed files with 1194 additions and 208 deletions

View File

@ -189,6 +189,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &phi3Model{}
case "Qwen2ForCausalLM":
conv = &qwen2Model{}
+ case "Qwen2_5_VLForConditionalGeneration":
+ conv = &qwen25vlModel{}
case "BertModel":
conv = &bertModel{}
case "CohereForCausalLM":

188
convert/convert_qwen25vl.go Normal file
View File

@ -0,0 +1,188 @@
package convert
import (
"bytes"
"encoding/binary"
"io"
"log/slog"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/x448/float16"
)
type qwen25vlModel struct {
ModelParameters
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
HiddenLayers uint32 `json:"num_hidden_layers"`
RopeTheta float32 `json:"rope_theta"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
VisionModel struct {
PatchSize uint32 `json:"patch_size"`
//HeadDim uint32 `json:"num_heads"`
//RopeTheta float32 `json:"rope_theta"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
WindowSize uint32 `json:"window_size"`
} `json:"vision_config"`
}
var _ ModelConverter = (*qwen25vlModel)(nil)
func (q *qwen25vlModel) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"
kv["qwen25vl.block_count"] = q.HiddenLayers
kv["qwen25vl.context_length"] = q.MaxPositionEmbeddings
kv["qwen25vl.embedding_length"] = q.HiddenSize
kv["qwen25vl.feed_forward_length"] = q.IntermediateSize
kv["qwen25vl.attention.head_count"] = q.NumAttentionHeads
kv["qwen25vl.attention.head_count_kv"] = q.NumKeyValueHeads
kv["qwen25vl.rope.freq_base"] = q.RopeTheta
kv["qwen25vl.attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize
return kv
}
func (q *qwen25vlModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
if strings.HasSuffix(t.Name(), "patch_embed.proj.weight") {
// var buf bytes.Buffer
// if _, err := t.WriteTo(&buf); err != nil {
// panic(err)
// }
// newTensors := splitPatchEmbed(buf, t.Kind(), t.Shape())
// out = append(out, newTensors...)
// } else if strings.HasPrefix(t.Name(), "v.blk.") {
// skip
} else {
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return out
}
func (p *qwen25vlModel) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.layers", "blk",
"visual.blocks", "v.blk",
"input_layernorm", "attn_norm",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.q_proj", "attn_q",
"self_attn.o_proj", "attn_output",
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"post_attention_layernorm", "ffn_norm",
"model.norm", "output_norm",
}
}
func splitPatchEmbed(buf bytes.Buffer, kind uint32, shape []uint64) []ggml.Tensor {
slog.Debug("patch stuff", "kind", kind, "shape", shape)
if kind != tensorKindF16 {
panic("tensor is of wrong type")
}
if len(shape) != 5 || (len(shape) == 5 && shape[2] != 2) {
panic("wrong sized tensor")
}
// determine the size of the tensor based on its shape
shapeToSize := func(s []int) int {
r := 1
for _, n := range s {
r *= int(n)
}
return r
}
// tensor.WithShape() wants []int
intShape := make([]int, len(shape))
for i, v := range shape {
intShape[i] = int(v)
}
u16s := make([]uint16, shapeToSize(intShape))
if err := binary.Read(&buf, binary.LittleEndian, u16s); err != nil {
panic("bad read")
}
f32s := make([]float32, len(u16s))
for i := range u16s {
f32s[i] = float16.Frombits(u16s[i]).Float32()
}
newTensors := []ggml.Tensor{}
getDataFromSlice := func(f32s []float32, shape []int, s []tensor.Slice) patchEmbed {
slog.Debug("getDataFromSlice", "num f32s", len(f32s), "shape", shape)
n := tensor.New(tensor.WithShape(shape...), tensor.WithBacking(f32s))
t, err := n.Slice(s...)
if err != nil {
panic(err)
}
ts, err := native.SelectF32(t.Materialize().(*tensor.Dense), 0)
if err != nil {
panic(err)
}
slog.Debug("first vals", "val 1", ts[0][0], "val 2", ts[0][1], "val 3", ts[0][2])
var f16s patchEmbed
for _, row := range ts {
for _, col := range row {
f16s = append(f16s, float16.Fromfloat32(col).Bits())
}
}
return f16s
}
p := getDataFromSlice(f32s, intShape, []tensor.Slice{nil, nil, tensor.S(0, 1, 1), nil, nil})
newTensors = append(newTensors, ggml.Tensor{
Name: "v.patch_embed.0.weight",
Kind: kind,
Shape: append(shape[:2], shape[3:]...),
WriterTo: p,
})
p = getDataFromSlice(f32s, intShape, []tensor.Slice{nil, nil, tensor.S(1, 2, 1), nil, nil})
newTensors = append(newTensors, ggml.Tensor{
Name: "v.patch_embed.1.weight",
Kind: kind,
Shape: append(shape[:2], shape[3:]...),
WriterTo: p,
})
return newTensors
}
type patchEmbed []uint16
func (t patchEmbed) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, t)
return 0, err
}
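As a side note, the Replacements pairs above are plain from/to substrings applied to the original tensor names. A minimal sketch of the renaming they produce (the input names are illustrative HF-style names, not taken from this diff):

package main

import (
	"fmt"
	"strings"
)

func main() {
	// A subset of the pairs returned by (*qwen25vlModel).Replacements.
	r := strings.NewReplacer(
		"model.embed_tokens", "token_embd",
		"model.layers", "blk",
		"self_attn.q_proj", "attn_q",
		"mlp.down_proj", "ffn_down",
	)
	fmt.Println(r.Replace("model.embed_tokens.weight"))              // token_embd.weight
	fmt.Println(r.Replace("model.layers.0.self_attn.q_proj.weight")) // blk.0.attn_q.weight
	fmt.Println(r.Replace("model.layers.0.mlp.down_proj.weight"))    // blk.0.ffn_down.weight
}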

View File

@ -650,5 +650,9 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t *Tensor) error {
}
func ggufPadding(offset, align int64) int64 {
+ // if we already fit perfectly onto a 16 byte boundary, don't bother padding
+ if ((align-offset%align)%align)%16 == 0 {
+ return 0
+ }
return (align - offset%align) % align
}
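A quick sketch of how the new ggufPadding behaves, with made-up offsets and a 32-byte alignment (assumed to sit in the same package as ggufPadding, so the unexported function is reachable):

// import "fmt" and place alongside ggufPadding; values are illustrative only.
func ExampleGGUFPadding() {
	fmt.Println(ggufPadding(50, 32)) // 14: pad up to the next 32-byte boundary
	fmt.Println(ggufPadding(48, 32)) // 0: the computed padding (16) is a multiple of 16, so it is skipped
	fmt.Println(ggufPadding(64, 32)) // 0: already aligned
}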

View File

@ -617,10 +617,16 @@ func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, co
func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
panic("not implemented")
}
+ func (t *testTensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, sections [4]int, config ml.RoPEConfig) ml.Tensor {
+ panic("not implemented")
+ }
+ func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
+ panic("not implemented")
+ }
func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") }
- func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { panic("not implemented") }

View File

@ -193,6 +193,7 @@ type Tensor interface {
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, config RoPEConfig) Tensor
+ RoPEMulti(ctx Context, positionIDs, ropeFactors Tensor, sections [4]int, config RoPEConfig) Tensor
Sin(ctx Context) Tensor
Cos(ctx Context) Tensor

View File

@ -1064,15 +1064,6 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
}
- // GGML RoPE types
- // These are the types used in the C implementation of RoPE
- const (
- ropeTypeNorm C.int = 0
- ropeTypeNeox C.int = 2
- ropeTypeMrope C.int = 8
- ropeTypeVision C.int = 24
- )
// RoPE applies Rotary Position Embeddings to the tensor
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
if ropeFactors == nil {
@ -1088,21 +1079,6 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config
config.YarnConfig = ml.DefaultYarnConfig(131072) // 131072 is the default for LLaMA, so it is common at the time of writing
}
- // Map Go RopeType to C implementation constants
- var ropeTypeC C.int
- switch config.Type {
- case ml.RopeTypeNormal:
- ropeTypeC = ropeTypeNorm
- case ml.RopeTypeNeox:
- ropeTypeC = ropeTypeNeox
- case ml.RopeTypeMRoPE:
- ropeTypeC = ropeTypeMrope
- case ml.RopeTypeVision:
- ropeTypeC = ropeTypeVision
- default:
- ropeTypeC = ropeTypeNorm
- }
return &Tensor{
b: t.b,
t: C.ggml_rope_ext(
@ -1111,7 +1087,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config
positionIDs.(*Tensor).t,
ropeFactors.(*Tensor).t,
C.int(config.Dim),
- ropeTypeC,
+ ropeTypeToC(config.Type),
C.int(config.YarnCtxTrain),
C.float(config.Base),
C.float(config.Scale),
@ -1129,6 +1105,60 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
t: C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32),
}
}
+ func (t *Tensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, sections [4]int, config ml.RoPEConfig) ml.Tensor {
+ if ropeFactors == nil {
+ ropeFactors = &Tensor{b: t.b}
+ }
+ dequant := t.t
+ if C.ggml_is_quantized(t.t._type) {
+ dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
+ }
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_rope_multi(
+ ctx.(*Context).ctx,
+ dequant,
+ positionIDs.(*Tensor).t,
+ ropeFactors.(*Tensor).t,
+ C.int(config.Dim),
+ (*C.int)(unsafe.Pointer(&sections[0])),
+ ropeTypeToC(config.Type),
+ C.int(config.YarnCtxTrain),
+ C.float(config.Base),
+ C.float(config.Scale),
+ C.float(config.YarnExtFactor),
+ C.float(config.YarnAttnFactor),
+ C.float(config.YarnBetaFast),
+ C.float(config.YarnBetaSlow),
+ ),
+ }
+ }
+ // GGML RoPE types
+ // These are the types used in the C implementation of RoPE
+ const (
+ ropeTypeNorm C.int = 0
+ ropeTypeNeox C.int = 2
+ ropeTypeMrope C.int = 8
+ ropeTypeVision C.int = 24
+ )
+ func ropeTypeToC(ropeType ml.RopeType) C.int {
+ switch ropeType {
+ case ml.RopeTypeNormal:
+ return ropeTypeNorm
+ case ml.RopeTypeNeox:
+ return ropeTypeNeox
+ case ml.RopeTypeMRoPE:
+ return ropeTypeMrope
+ case ml.RopeTypeVision:
+ return ropeTypeVision
+ default:
+ return ropeTypeNorm
+ }
+ }
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
return &Tensor{

View File

@ -0,0 +1,51 @@
package model
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/ollama/ollama/ml"
)
func setup(t *testing.T) ml.Backend {
home, err := os.UserHomeDir()
if err != nil {
t.Fatal(err)
}
models := filepath.Join(home, ".ollama", "models")
b, err := New(context.TODO(), filepath.Join(models, "blobs", "sha256-667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"), ml.BackendParams{NumGPULayers: 99})
if err != nil {
t.Fatal(err)
}
return b
}
func TestUnfoldConv(t *testing.T) {
b := setup(t)
ctx := b.NewContext().Input()
t.Cleanup(func() { ctx.Close() })
tiles, channels, height, width := 5, 3, 336, 336
patchSize := 14
tt := ctx.Arange(0, float32(tiles*channels*height*width), 1, ml.DTypeF32).Reshape(ctx, width, height, channels, tiles)
t.Log("tt", tt.Shape())
t.Log(ml.Dump(ctx, tt))
kernel := ctx.Empty(ml.DTypeF32, patchSize, patchSize, channels)
t.Log("kernel", kernel.Shape())
t.Log(ml.Dump(ctx, kernel))
tt = kernel.IM2Col(ctx, tt, patchSize, patchSize, 0, 0, 1, 1)
t.Log("tt", tt.Shape())
t.Log(ml.Dump(ctx, tt))
tt = tt.Reshape(ctx, tt.Dim(0), tt.Dim(1)*tt.Dim(2), tt.Dim(3))
t.Log("tt", tt.Shape())
t.Log(ml.Dump(ctx, tt))
}
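For reference, the shape bookkeeping this test exercises, derived from its own constants (pure arithmetic, no ml calls):

package main

import "fmt"

func main() {
	patchesPerSide := 336 / 14                        // 24 patches along each edge
	patchesPerTile := patchesPerSide * patchesPerSide // 576 patches per 336x336 tile
	valuesPerPatch := 14 * 14 * 3                     // 588 floats per unfolded patch
	fmt.Println(patchesPerTile, valuesPerPatch, patchesPerTile*valuesPerPatch) // 576 588 338688
}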

View File

@ -57,7 +57,7 @@ func newTextModel(c fs.Config) *TextModel {
},
),
Layers: make([]TextLayer, numBlocks),
- TextOptions: &TextOptions{
+ TextConfig: &TextConfig{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),

View File

@ -17,6 +17,7 @@ type TextOptions struct {
hiddenSize, numHeads, numKVHeads, headDim int
eps, ropeBase, ropeScale float32
ropeDim uint32
+ ropeConfig ml.RoPEConfig
}
type TextModel struct {
@ -40,7 +41,6 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
- ropeType := uint32(0)
headDim := opts.headDim
if headDim == 0 {
headDim = opts.hiddenSize / opts.numHeads
@ -48,11 +48,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
- q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+ q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
- k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+ k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -63,7 +63,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
- return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
+ return key.RoPE(ctx, shift, nil, m.TextOptions.ropeConfig), nil
}
type MLP struct {
@ -167,9 +167,13 @@ func NewTextModel(c fs.Config) (*TextModel, error) {
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
- ropeBase: c.Float("rope.freq_base"),
- ropeScale: c.Float("rope.freq_scale", 1),
- ropeDim: c.Uint("rope.dimension_count"),
+ ropeConfig: ml.RoPEConfig{
+ Base: c.Float("rope.freq_base", 10000.0),
+ Scale: c.Float("rope.freq_scale", 1.0),
+ Dim: c.Uint("rope.dimension_count"),
+ Type: ml.RopeTypeNormal,
+ YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+ },
},
}

View File

@ -1,10 +1,11 @@
package qwen25vl
import (
+ "bytes"
"fmt"
- "math"
- "strings"
+ "image"
+ "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@ -12,147 +13,151 @@ import (
"github.com/ollama/ollama/model/input"
)
- type Options struct {
- ctxLen, hiddenSize, numHeads, numKVHeads int
- eps float32
- ropeConfig ml.RoPEConfig
- }
type Model struct {
model.Base
- model.BytePairEncoding
- TokenEmbedding *nn.Embedding `gguf:"token_embd"`
- Layers []Layer `gguf:"blk"`
- OutputNorm *nn.RMSNorm `gguf:"output_norm"`
- Output *nn.Linear `gguf:"output,alt:token_embd"`
- *Options
+ *TextModel
+ *VisionModel `gguf:"v,vision"`
+ *PatchMerger `gguf:"mm"`
+ ImageProcessor
}
- func New(c ml.Config) (model.Model, error) {
- if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
- return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
- }
- m := Model{
- BytePairEncoding: model.NewBytePairEncoding(
- c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
- &model.Vocabulary{
- Values: c.Strings("tokenizer.ggml.tokens"),
- Types: c.Uints("tokenizer.ggml.token_type"),
- Merges: c.Strings("tokenizer.ggml.merges"),
- BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
- AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
- EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
- AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
- },
- ),
- Layers: make([]Layer, c.Uint("block_count")),
- Options: &Options{
- ctxLen: int(c.Uint("context_length")),
- hiddenSize: int(c.Uint("embedding_length")),
- numHeads: int(c.Uint("attention.head_count")),
- numKVHeads: int(c.Uint("attention.head_count_kv")),
- eps: c.Float("attention.layer_norm_rms_epsilon"),
- ropeConfig: ml.RoPEConfig{
- Base: c.Float("rope.freq_base"),
- Scale: c.Float("rope.freq_scale", 1),
- Dim: c.Uint("rope.dimension_count", 128),
- Type: ml.RopeTypeNeox,
- YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 32768))),
- },
- },
- }
- m.Cache = kvcache.NewCausalCache(m.Shift)
- return &m, nil
- }
- // SelfAttention implements the multi-head self-attention mechanism
- // with separate projections for query, key, value and output transformations
- type SelfAttention struct {
- Query *nn.Linear `gguf:"attn_q"`
- Key *nn.Linear `gguf:"attn_k"`
- Value *nn.Linear `gguf:"attn_v"`
- Output *nn.Linear `gguf:"attn_output"`
- RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
- }
- func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
- batchSize := hiddenState.Dim(1)
- headDim := opts.hiddenSize / opts.numHeads
- q := sa.Query.Forward(ctx, hiddenState)
- q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
- q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
- k := sa.Key.Forward(ctx, hiddenState)
- k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
- k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
- v := sa.Value.Forward(ctx, hiddenState)
- v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
- scaleFactor := 1.0 / math.Sqrt(float64(headDim))
- kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
- kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
- return sa.Output.Forward(ctx, kqv)
- }
- // Shift applies rotary position embeddings to the key tensor for causal attention caching
- func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
- return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
- }
- // MLP implements the feed-forward network component with SwiGLU activation
- type MLP struct {
- Up *nn.Linear `gguf:"ffn_up"`
- Down *nn.Linear `gguf:"ffn_down"`
- Gate *nn.Linear `gguf:"ffn_gate"`
- }
- func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
- // Apply SwiGLU activation gating
- hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
- // Project back to hidden dimension
- return mlp.Down.Forward(ctx, hiddenState)
- }
- // Layer represents a single transformer layer combining self-attention and feed-forward components
- type Layer struct {
- AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
- SelfAttention *SelfAttention
- MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
- MLP *MLP
- }
- func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
- // Self-attention branch with residual connection
- residual := hiddenState
- hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
- hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
- // In the final layer (outputs != nil), optimize by pruning to just the token positions
- // we need logits for.
- if outputs != nil {
- hiddenState = hiddenState.Rows(ctx, outputs)
- residual = residual.Rows(ctx, outputs)
- }
- hiddenState = hiddenState.Add(ctx, residual)
- // Feed-forward branch with residual connection
- residual = hiddenState
- hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
- hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
- return hiddenState.Add(ctx, residual)
- }
+ // Implement MultimodalProcessor interface
+ var _ model.MultimodalProcessor = (*Model)(nil)
+ type PatchMerger struct {
+ MLPLayer1 *nn.Linear `gguf:"0"`
+ MLPLayer2 *nn.Linear `gguf:"2"`
+ }
+ // Forward computes patch merging for the vision model
+ func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
+ // Get dimensions
+ hiddenSize := visionOutputs.Dim(0)
+ numPositions := visionOutputs.Dim(1)
+ batchSize := visionOutputs.Dim(2)
+ reshaped := visionOutputs.Reshape(ctx, hiddenSize*4, numPositions/4, batchSize)
+ // Apply first linear layer (mm_0_w, mm_0_b)
+ hidden := pm.MLPLayer1.Forward(ctx, reshaped)
+ activated := hidden.GELU(ctx)
+ // Apply second linear layer (mm_1_w, mm_1_b)
+ output := pm.MLPLayer2.Forward(ctx, activated)
+ return output
+ }
+ func New(c fs.Config) (model.Model, error) {
+ m := &Model{
+ TextModel: NewTextModel(c),
+ VisionModel: newVisionModel(c),
+ ImageProcessor: newImageProcessor(c),
+ }
+ m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
+ return m, nil
+ }
+ type imageFeatures struct {
+ Tensor ml.Tensor
+ GridT int
+ GridH int
+ GridW int
+ }
+ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
+ if len(m.VisionModel.Layers) == 0 {
+ return nil, model.ErrNoVisionModel
+ }
+ image, _, err := image.Decode(bytes.NewReader(multimodalData))
+ if err != nil {
+ return nil, err
+ }
+ f32s, gridT, gridH, gridW, err := m.ImageProcessor.ProcessImage(image)
+ if err != nil {
+ return nil, err
+ }
+ // Calculate tensor dimensions
+ patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
+ m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
+ numPatches := gridT * gridH * gridW
+ pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create tensor from image: %w", err)
+ }
+ visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
+ visionOutputs = m.PatchMerger.Forward(ctx, visionOutputs, m.VisionModel.eps)
+ return &imageFeatures{
+ Tensor: visionOutputs,
+ GridT: gridT,
+ GridH: gridH,
+ GridW: gridW,
+ }, nil
+ }
+ // PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
+ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
+ var result []input.Input
+ // Get image token IDs from config
+ imageToken := 151655
+ visionStartToken := 151652
+ visionEndToken := 151653
+ // Get merge size from config
+ mergeSize := m.ImageProcessor.mergeSize
+ for _, inp := range inputs {
+ if inp.Multimodal == nil {
+ // If not a multimodal input, add it to the result unchanged
+ result = append(result, inp)
+ } else {
+ // This is an image token with multimodal data
+ features := inp.Multimodal.(*imageFeatures)
+ // Get grid dimensions from the features
+ gridT := features.GridT
+ gridH := features.GridH
+ gridW := features.GridW
+ // Calculate tokens per grid based on grid dimensions
+ mergeLength := mergeSize * mergeSize
+ gridProduct := gridT * gridH * gridW
+ tokensPerGrid := gridProduct / mergeLength
+ // First add the vision start token
+ result = append(result, input.Input{Token: int32(visionStartToken)})
+ // Add the image token with the multimodal tensor data at the first position
+ result = append(result, input.Input{
+ Token: int32(imageToken),
+ Multimodal: features.Tensor,
+ MultimodalHash: inp.MultimodalHash,
+ })
+ // Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
+ for range tokensPerGrid - 1 {
+ result = append(result, input.Input{Token: int32(imageToken)})
+ }
+ result = append(result, input.Input{Token: int32(visionEndToken)})
+ }
+ }
+ return result, nil
+ }
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
- // Convert input tokens and positions to tensors
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
@ -163,25 +168,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
return nil, err
}
- // Initial token embedding
- hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
- // Process through transformer layers
- for i, layer := range m.Layers {
- m.Cache.SetLayer(i)
- var lastLayerOutputs ml.Tensor
- if i == len(m.Layers)-1 {
- lastLayerOutputs = outputs
- }
- hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
- }
- hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
- return m.Output.Forward(ctx, hiddenState), nil
+ return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
}
func init() {
+ model.Register("qwen25vl", New)
model.Register("qwen2vl", New)
}
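To make the arithmetic in PostTokenize concrete, a worked example with an assumed grid (the token IDs are the ones hard-coded above):

// Assume one image with gridT=1, gridH=8, gridW=8 and mergeSize=2:
//   mergeLength   = 2 * 2     = 4
//   gridProduct   = 1 * 8 * 8 = 64
//   tokensPerGrid = 64 / 4    = 16
// The image therefore expands to 1 + 16 + 1 = 18 inputs:
//   151652 (vision start), 16 x 151655 (the first carries the tensor and hash), 151653 (vision end).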

View File

@ -0,0 +1,59 @@
package qwen25vl
import (
"testing"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/model/input"
)
func TestPostTokenize(t *testing.T) {
// Set up test inputs
model := &Model{}
mockHash := uint64(12345678)
inputs := []input.Input{
{Token: 123}, // Regular token
{Token: 456}, // Regular token
{Token: 151655, Multimodal: &ggml.Tensor{}, MultimodalHash: mockHash}, // Image token
{Token: 789}, // Regular token
}
// Run the function being tested
result, err := model.PostTokenize(inputs)
if err != nil {
t.Fatalf("PostTokenize returned error: %v", err)
}
// Verify the actual length first
expectedLength := 21
if len(result) != expectedLength {
t.Fatalf("Result has wrong length: got %d, expected %d", len(result), expectedLength)
}
// Check key positions only
checkPositions := map[int]int32{
0: 123, // First regular token
1: 456, // Second regular token
2: 151652, // Vision start token
4: 151655, // First placeholder token
19: 151653, // Vision end token
20: 789, // Final regular token
}
for pos, expectedToken := range checkPositions {
if pos >= len(result) {
t.Errorf("Position %d is out of bounds (result length: %d)", pos, len(result))
continue
}
if result[pos].Token != expectedToken {
t.Errorf("Position %d: expected token %d, got %d", pos, expectedToken, result[pos].Token)
}
}
// Check multimodal data is preserved
if result[3].MultimodalHash != mockHash {
t.Errorf("Multimodal hash not preserved: got %d, expected %d",
result[3].MultimodalHash, mockHash)
}
}

View File

@ -0,0 +1,165 @@
package qwen25vl
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type TextOptions struct {
ctxLen, hiddenSize, numHeads, numKVHeads int
eps float32
ropeConfig ml.RoPEConfig
}
type TextModel struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
}
func NewTextModel(c fs.Config) *TextModel {
m := TextModel{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
Layers: make([]Layer, c.Uint("block_count")),
TextOptions: &TextOptions{
ctxLen: int(c.Uint("context_length")),
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeConfig: ml.RoPEConfig{
Base: c.Float("rope.freq_base"),
Scale: c.Float("rope.freq_scale", 1),
Dim: c.Uint("rope.dimension_count", 128),
Type: ml.RopeTypeNeox,
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 128000))),
},
},
}
return &m
}
// SelfAttention implements the multi-head self-attention mechanism
// with separate projections for query, key, value and output transformations
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, kqv)
}
// Shift applies rotary position embeddings to the key tensor for causal attention caching
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
}
// MLP implements the feed-forward network component with SwiGLU activation
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
// Apply SwiGLU activation gating
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
// Project back to hidden dimension
return mlp.Down.Forward(ctx, hiddenState)
}
// Layer represents a single transformer layer combining self-attention and feed-forward components
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
// Self-attention branch with residual connection
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
// Feed-forward branch with residual connection
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
// Initial token embedding
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
// Process through transformer layers
for i, layer := range m.Layers {
cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState), nil
}

View File

@ -0,0 +1,260 @@
package qwen25vl
import (
"fmt"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
var batchSize int = 1
// VisionSelfAttention implements self-attention for the Qwen vision model
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_out"`
}
// Forward computes self-attention for the vision model
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
config := ml.RoPEConfig{
Dim: uint32(opts.headDim / 2),
Type: ml.RopeTypeMRoPE,
Base: opts.ropeTheta,
Scale: 1.0,
YarnConfig: ml.DefaultYarnConfig(128000),
}
query = query.RoPEMulti(
ctx,
positionIDs,
nil,
[4]int{opts.headDim / 4, opts.headDim / 4, opts.headDim / 4, opts.headDim / 4},
config,
)
key = key.RoPEMulti(
ctx,
positionIDs,
nil,
[4]int{opts.headDim / 4, opts.headDim / 4, opts.headDim / 4, opts.headDim / 4},
config,
)
// Scale factor for scaled dot-product attention
scale := 1.0 / math.Sqrt(float64(opts.headDim))
attention := nn.Attention(ctx, query, key, value, scale, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}
// VisionMLP implements the MLP for the Qwen vision model
type VisionMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
// Forward computes the MLP for the vision model
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Using GEGLU activation: (Gate * Up) * GELU(Gate)
gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
upOutput := mlp.Up.Forward(ctx, hiddenStates)
hiddenStates = gateOutput.GELU(ctx).Mul(ctx, upOutput)
return mlp.Down.Forward(ctx, hiddenStates)
}
// VisionEncoderLayer implements an encoder layer for the Qwen vision model
type VisionEncoderLayer struct {
Norm1 *nn.RMSNorm `gguf:"ln1"`
Norm2 *nn.RMSNorm `gguf:"ln2"`
SelfAttention *VisionSelfAttention
MLP *VisionMLP
}
// Forward computes an encoder layer for the vision model
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, positionIDs, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = e.Norm2.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
// VisionModelOptions contains configuration options for the Qwen vision model
type VisionModelOptions struct {
hiddenSize int
numHeads int
headDim int
intermediateSize int
imageSize int
patchSize int
numChannels int
eps float32
ropeTheta float32
outHiddenSize int
}
type PatchEmbedding struct {
PatchConv0 *nn.Conv2D `gguf:"patch_embd_0"`
PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
}
func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, patchSize int) ml.Tensor {
shape := pixelValues.Shape()
numChannels := 3
temporalPatchSize := 2
embedDim := 1280
numPatches := shape[1] / temporalPatchSize
// Split the input tensor into two temporal slices and process each separately
// First temporal slice (frame 0)
slice0 := pixelValues.View(ctx, 0, patchSize*patchSize*numChannels, 0, numPatches, 0, 1).Contiguous(ctx)
reshaped0 := slice0.Reshape(ctx,
patchSize, // height
patchSize, // width
numChannels, // channels
numPatches) // batch
// Second temporal slice (frame 1)
slice1 := pixelValues.View(ctx, 0, patchSize*patchSize*numChannels, 0, numPatches, 1, 1).Contiguous(ctx)
reshaped1 := slice1.Reshape(ctx,
patchSize, // height
patchSize, // width
numChannels, // channels
numPatches) // batch
// Apply the appropriate convolution to each temporal slice
// PatchConv0 corresponds to weights for temporal frame 0
// PatchConv1 corresponds to weights for temporal frame 1
s0, s1 := patchSize, patchSize // Use full stride as in original
p0, p1 := 0, 0 // padding
d0, d1 := 1, 1 // dilation
output0 := pe.PatchConv0.Forward(ctx, reshaped0, s0, s1, p0, p1, d0, d1)
output1 := pe.PatchConv1.Forward(ctx, reshaped1, s0, s1, p0, p1, d0, d1)
// Add the outputs from the two temporal convolutions
combined := output0.Add(ctx, output1)
// Reshape to required output dimensions
result := combined.Reshape(ctx, embedDim, numPatches)
fmt.Println(ml.Dump(ctx, result))
return result
}
// VisionPatchMerger implements patch merging for the Qwen vision model
type VisionPatchMerger struct {
LNQ *nn.RMSNorm `gguf:"ln_q"`
MLP *nn.Linear `gguf:"mlp"`
}
// Forward computes patch merging for the vision model
func (pm *VisionPatchMerger) Forward(ctx ml.Context, x ml.Tensor, outDim, contextDim, spatialMergeSize int) ml.Tensor {
hiddenSize := contextDim * (spatialMergeSize * spatialMergeSize)
// Normalize and reshape
x = pm.LNQ.Forward(ctx, x, 1e-6)
x = x.Reshape(ctx, -1, hiddenSize)
// Apply MLP for merging
x = pm.MLP.Forward(ctx, x)
return x
}
// VisionModel implements the Qwen vision model
type VisionModel struct {
PatchEmbedding *PatchEmbedding
Layers []VisionEncoderLayer `gguf:"blk"`
PostLayerNorm *nn.LayerNorm `gguf:"post_ln"`
PatchMerger *VisionPatchMerger `gguf:"patch_merger"`
*VisionModelOptions
}
// Forward computes the vision model for an input tensor
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
// Calculate position IDs for 2D RoPE
numPatchesH := pixelValues.Dim(0) / m.patchSize
numPatchesW := pixelValues.Dim(1) / m.patchSize
numPatches := numPatchesH * numPatchesW
// Extract patch embeddings
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize)
// Create position IDs - for Qwen2VL mRoPE we need 4 values per position
// The format needed is specified in the C++ code as "mrope expecting 4 position ids per token"
positions := make([]int32, numPatches*4)
for h := 0; h < numPatchesH; h++ {
for w := 0; w < numPatchesW; w++ {
idx := h*numPatchesW + w
// For each position, store both h and w coordinates twice
// This matches the pattern seen in the C++ implementation
positions[idx*4] = int32(h) // y coordinate
positions[idx*4+1] = int32(w) // x coordinate
positions[idx*4+2] = int32(h) // y coordinate (repeated)
positions[idx*4+3] = int32(w) // x coordinate (repeated)
}
}
// Create the position IDs tensor with correct dimensions
positionIDs, err := ctx.Input().FromIntSlice(positions, numPatches*4)
if err != nil {
panic(err)
}
// Apply encoder layers
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, positionIDs, m.VisionModelOptions)
}
hiddenStates = m.PostLayerNorm.Forward(ctx, hiddenStates, m.eps)
return hiddenStates
}
// newVisionModel creates a new instance of the Qwen vision model
func newVisionModel(c fs.Config) *VisionModel {
patchSize := int(c.Uint("vision.patch_size", 14))
hiddenSize := int(c.Uint("vision.embedding_length", 1280))
ropeTheta := c.Float("vision.rope_theta", 10000.0) // not set
outHiddenSize := int(c.Uint("vision.out_embedding_length", 0)) // not set
numHeads := int(c.Uint("vision.attention.head_count", 16))
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)),
VisionModelOptions: &VisionModelOptions{
hiddenSize: hiddenSize,
numHeads: numHeads,
headDim: hiddenSize / numHeads,
intermediateSize: int(c.Uint("vision.feed_forward_length", 0)),
imageSize: int(c.Uint("vision.image_size", 560)),
patchSize: patchSize,
numChannels: int(c.Uint("vision.num_channels", 3)), // not set
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
ropeTheta: ropeTheta,
outHiddenSize: outHiddenSize,
},
}
}
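To illustrate the mRoPE position layout that VisionModel.Forward builds above, here is what that loop produces for a tiny assumed 2x2 patch grid:

// numPatchesH = numPatchesW = 2 => 4 patches, 4 position ids per patch:
//   patch (h=0, w=0): 0, 0, 0, 0
//   patch (h=0, w=1): 0, 1, 0, 1
//   patch (h=1, w=0): 1, 0, 1, 0
//   patch (h=1, w=1): 1, 1, 1, 1
// positions = [0 0 0 0  0 1 0 1  1 0 1 0  1 1 1 1], i.e. (y, x, y, x) per patch, row-major.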

View File

@ -0,0 +1,196 @@
package qwen25vl
import (
"fmt"
"image"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
// ImageProcessor contains configuration for the Qwen 2.5 VL image processing
type ImageProcessor struct {
imageSize int
numChannels int
patchSize int
temporalPatchSize int
mergeSize int
minPixels int
maxPixels int
factor int
rescaleFactor float32
imageMean []float32
imageStd []float32
}
// newImageProcessor creates a new image processor with default values
func newImageProcessor(c fs.Config) ImageProcessor {
patchSize := int(c.Uint("vision.patch_size", 14))
mergeSize := int(c.Uint("vision.spatial_merge_size", 2))
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size", 560)),
numChannels: 3,
patchSize: patchSize,
temporalPatchSize: 2,
mergeSize: mergeSize,
minPixels: 56 * 56,
maxPixels: 28 * 28 * 4 * 1280,
factor: patchSize * mergeSize,
rescaleFactor: 1.0 / 255.0,
imageMean: []float32{0.48145466, 0.4578275, 0.40821073},
imageStd: []float32{0.26862954, 0.26130258, 0.27577711},
}
}
// SmartResize implements the smart resize algorithm
func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
factor := p.factor
if height < factor || width < factor {
panic(fmt.Sprintf("height:%d or width:%d must be larger than factor:%d", height, width, factor))
} else if float64(max(height, width))/float64(min(height, width)) > 200 {
aspectRatio := float64(max(height, width)) / float64(min(height, width))
panic(fmt.Sprintf("absolute aspect ratio must be smaller than 200, got %f", aspectRatio))
}
round := func(x float64) int {
return int(math.Round(x))
}
hBar := round(float64(height)/float64(factor)) * factor
wBar := round(float64(width)/float64(factor)) * factor
if hBar*wBar > p.maxPixels {
beta := math.Sqrt(float64(height*width) / float64(p.maxPixels))
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
} else if hBar*wBar < p.minPixels {
beta := math.Sqrt(float64(p.minPixels) / float64(height*width))
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
}
return hBar, wBar
}
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, int, error) {
origWidth := img.Bounds().Dx()
origHeight := img.Bounds().Dy()
// Calculate smart resize dimensions
resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth)
// Resize image using existing functions
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
normalizedPixels := imageproc.Normalize(
resizedImg,
[3]float32{p.imageMean[0], p.imageMean[1], p.imageMean[2]},
[3]float32{p.imageStd[0], p.imageStd[1], p.imageStd[2]},
true, // rescale
true, // channelFirst
)
// Calculate grid dimensions
gridH := resizedHeight / p.patchSize
gridW := resizedWidth / p.patchSize
gridT := 1 // For single images, temporal dimension is 1
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, gridH, gridW, gridT)
if err != nil {
return nil, 0, 0, 0, fmt.Errorf("failed to create patches: %v", err)
}
// Return patches and grid dimensions
return patches, gridT, gridH, gridW, nil
}
func (p *ImageProcessor) createPatches(pixels []float32, height, width, gridH, gridW, gridT int) ([]float32, error) {
channels := p.numChannels
patchSize := p.patchSize
mergeSize := p.mergeSize
temporalPatchSize := p.temporalPatchSize
// Calculate output dimensions
numPatches := gridT * gridH * gridW
patchDim := channels * temporalPatchSize * patchSize * patchSize
// Create output tensor
result := make([]float32, numPatches*patchDim)
// Instead of the complex 9D reshape+transpose, directly extract patches
// in the format expected by the forward pass
patchIndex := 0
for t := 0; t < gridT; t++ {
// For each patch in the grid
for h := 0; h < gridH; h += mergeSize {
for w := 0; w < gridW; w += mergeSize {
// Handle the 2x2 merged patches
for mh := 0; mh < mergeSize; mh++ {
for mw := 0; mw < mergeSize; mw++ {
// For each pixel in the patch
for py := 0; py < patchSize; py++ {
for px := 0; px < patchSize; px++ {
// Calculate source coordinates
y := (h+mh)*patchSize + py
x := (w+mw)*patchSize + px
// For each channel
for c := 0; c < channels; c++ {
// Channel-first format (CHW)
srcIdx := c*height*width + y*width + x
// Calculate destination index based on the expected layout
// This is the key part that matches what the model expects
dstIdx := patchIndex*patchDim +
(c * temporalPatchSize * patchSize * patchSize) +
(0 * patchSize * patchSize) + // temporal dim
(py * patchSize) +
px
if srcIdx < len(pixels) && dstIdx < len(result) {
result[dstIdx] = pixels[srcIdx]
}
}
}
}
// Handle temporal dimension padding (if needed)
for tp := 1; tp < temporalPatchSize; tp++ {
for py := 0; py < patchSize; py++ {
for px := 0; px < patchSize; px++ {
for c := 0; c < channels; c++ {
srcIdx := patchIndex*patchDim +
(c * temporalPatchSize * patchSize * patchSize) +
(0 * patchSize * patchSize) + // first temporal frame
(py * patchSize) +
px
dstIdx := patchIndex*patchDim +
(c * temporalPatchSize * patchSize * patchSize) +
(tp * patchSize * patchSize) + // current temporal frame
(py * patchSize) +
px
if srcIdx < len(result) && dstIdx < len(result) {
result[dstIdx] = result[srcIdx] // Copy from first frame
}
}
}
}
}
patchIndex++
}
}
}
}
}
return result, nil
}
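A worked pass through SmartResize using the defaults from newImageProcessor (factor = 14*2 = 28, minPixels = 56*56, maxPixels = 28*28*4*1280); the 1024x768 input is just an example:

// SmartResize(1024, 768):
//   hBar = round(1024/28) * 28 = 37 * 28 = 1036
//   wBar = round(768/28)  * 28 = 27 * 28 = 756
//   1036 * 756 = 783,216, which lies between minPixels (3,136) and maxPixels (4,014,080),
//   so neither rescale branch fires and the result is (1036, 756).
// Inputs whose rounded area exceeds maxPixels are shrunk by beta = sqrt(h*w/maxPixels)
// and floored back onto the 28-pixel grid; tiny inputs are grown with ceil instead.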

View File

@ -0,0 +1,47 @@
package qwen25vl
import (
"image"
_ "image/jpeg" // Register JPEG decoder
"testing"
)
func TestSmartResize(t *testing.T) {
type smartResizeCase struct {
TestImage image.Image
Expected image.Point
}
// Create an image processor with default values
processor := ImageProcessor{
imageSize: 560, // Example value
numChannels: 3,
factor: 28,
minPixels: 56 * 56,
maxPixels: 14 * 14 * 4 * 1280,
}
cases := []smartResizeCase{
{
TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 1024)),
Expected: image.Point{980, 980},
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 768)),
Expected: image.Point{1036, 756},
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 2000, 2000)),
Expected: image.Point{980, 980},
},
}
for _, c := range cases {
b := c.TestImage.Bounds().Max
x, y := processor.SmartResize(b.X, b.Y)
actual := image.Point{x, y}
if actual != c.Expected {
t.Errorf("expected: %v, actual: %v", c.Expected, actual)
}
}
}

View File

@ -14,7 +14,7 @@ import (
const (
DefaultFactor = 28
DefaultMinPixels = 56 * 56
- DefaultMaxPixels = 14 * 14 * 4 * 1280
+ DefaultMaxPixels = 14 * 14 * 4 * 1280 // TODO: might need to change
)
// smartResize calculates the size of the image to resize to based on the

View File

@ -514,35 +514,18 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
} else if err != nil {
return nil, err
}
mediatype := "application/vnd.ollama.image.model"
if f.KV().Kind() == "adapter" {
mediatype = "application/vnd.ollama.image.adapter"
} else if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok || f.KV().Kind() == "projector" {
mediatype = "application/vnd.ollama.image.projector"
}
var layer Layer
if digest != "" && n == stat.Size() && offset == 0 {
layer, err = NewLayerFromLayer(digest, mediatype, blob.Name())
if err != nil {
slog.Debug("could not create new layer from layer", "error", err)
return nil, err
}
}
// Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
if layer.Digest == "" {
layer, err = NewLayer(io.NewSectionReader(blob, offset, n), mediatype)
if err != nil {
return nil, err
}
}
layers = append(layers, &layerGGML{layer, f})
offset = n
}
// Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
if layer.Digest == "" {
layer, err = NewLayer(io.NewSectionReader(blob, 0, n), mediatype)
if err != nil {
return nil, err
}
}
layers = append(layers, &layerGGML{layer, f})
return detectChatTemplate(layers)
}