jmorganca 2025-03-22 22:33:39 -07:00
parent caddb1e4cf
commit 8dd2a81f8c
8 changed files with 195 additions and 124 deletions

View File

@@ -116,13 +116,16 @@ func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor {
 func (p *mistral3Model) Replacements() []string {
 	return []string{
-		// Text model replacements
-		"model.layers", "blk",
+		"language_model.model.norm", "output_norm",
+		"language_model.model.", "",
+		"language_model.", "",
+		"layers", "blk",
+		"transformer.layers", "blk",
+		"vision_tower", "v",
+		"ln_pre", "encoder_norm",
 		"input_layernorm", "attn_norm",
 		"post_attention_layernorm", "ffn_norm",
-		"lm_head", "output",
-		"model.embed_tokens.weight", "token_embd.weight",
-		"model.norm.weight", "output_norm.weight",
+		"embed_tokens", "token_embd",
 		"self_attn.q_proj", "attn_q",
 		"self_attn.k_proj", "attn_k",
 		"self_attn.v_proj", "attn_v",
@@ -130,50 +133,18 @@ func (p *mistral3Model) Replacements() []string {
 		"mlp.down_proj", "ffn_down",
 		"mlp.gate_proj", "ffn_gate",
 		"mlp.up_proj", "ffn_up",
-		// Language model replacements
-		"language_model.model.embed_tokens", "token_embd",
-		"language_model.model.layers", "blk",
-		"language_model.model.layers.*.input_layernorm", "attn_norm",
-		"language_model.model.layers.*.self_attn.q_proj", "attn_q",
-		"language_model.model.layers.*.self_attn.k_proj", "attn_k",
-		"language_model.model.layers.*.self_attn.v_proj", "attn_v",
-		"language_model.model.layers.*.self_attn.o_proj", "attn_output",
-		"language_model.model.layers.*.mlp.gate_proj", "ffn_gate",
-		"language_model.model.layers.*.mlp.down_proj", "ffn_down",
-		"language_model.model.layers.*.mlp.up_proj", "ffn_up",
-		"language_model.model.layers.*.post_attention_layernorm", "ffn_norm",
-		"language_model.lm_head", "output",
-		"language_model.model.norm", "output_norm",
-		// Vision model replacements - map to shorter prefixes
-		"vision_tower", "v",
+		"attention.q_proj", "attn_q",
+		"attention.k_proj", "attn_k",
+		"attention.v_proj", "attn_v",
+		"attention.o_proj", "attn_output",
+		"attention_norm", "attn_norm",
+		"feed_forward", "mlp",
+		"feed_forward.gate_proj", "ffn_gate",
+		"feed_forward.down_proj", "ffn_down",
+		"feed_forward.up_proj", "ffn_up",
 		"multi_modal_projector", "mm",
-		// Vision transformer blocks - these should be updated accordingly
-		"vision_tower.transformer.layers", "v.blk",
-		"vision_tower.transformer.layers.*.attention_norm", "v.attn_norm",
-		"vision_tower.transformer.layers.*.attention.q_proj", "v.attn_q",
-		"vision_tower.transformer.layers.*.attention.k_proj", "v.attn_k",
-		"vision_tower.transformer.layers.*.attention.v_proj", "v.attn_v",
-		"vision_tower.transformer.layers.*.attention.o_proj", "v.attn_output",
-		"vision_tower.transformer.layers.*.feed_forward.gate_proj", "v.ffn_gate",
-		"vision_tower.transformer.layers.*.feed_forward.down_proj", "v.ffn_down",
-		"vision_tower.transformer.layers.*.feed_forward.up_proj", "v.ffn_up",
-		"vision_tower.transformer.layers.*.ffn_norm", "v.ffn_norm",
-		"vision_tower.ln_pre", "v.encoder_norm",
-		"vision_tower.patch_conv", "v.patch_conv",
-		"vision_tower.embeddings", "v.embeddings",
-		// Alternative vision model paths
-		"vision_model.vision_model.embeddings", "v.embeddings",
-		"vision_model.vision_model", "v",
-		"vision_model.layers", "v.blk",
-		// Multimodal projector components
-		"multi_modal_projector.patch_merger", "mm.patch_merger",
-		"multi_modal_projector.norm", "mm.norm",
-		"multi_modal_projector.linear", "mm.projection",
+		"ffn_norm", "ffn_norm",
+		"lm_head", "output",
 	}
 }
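For context (not part of the commit): these (old, new) pairs rewrite source tensor names into GGUF names, presumably via strings.NewReplacer as the convert package does elsewhere, so ordering matters — at each position the earliest-listed matching pattern wins and the string is scanned in a single left-to-right pass. A minimal sketch using a subset of the pairs above:

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Subset of the replacement pairs; strings.Replacer tries patterns
	// in argument order and never re-scans replaced output.
	r := strings.NewReplacer(
		"language_model.model.norm", "output_norm",
		"language_model.model.", "",
		"layers", "blk",
		"self_attn.q_proj", "attn_q",
	)
	fmt.Println(r.Replace("language_model.model.layers.0.self_attn.q_proj.weight"))
	// Output: blk.0.attn_q.weight
}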

View File

@@ -144,6 +144,9 @@ type Tensor interface {
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

 	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
+	RoPEMulti(ctx Context, positionIDs, ropeFactors Tensor, ropeDim uint32, sections [4]int, ropeType uint32, base, scale float32) Tensor
+	IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

 	Tanh(ctx Context) Tensor
 	GELU(ctx Context) Tensor

View File

@@ -958,6 +958,41 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
 	}
 }

+func (t *Tensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, sections [4]int, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
+	if ropeFactors == nil {
+		ropeFactors = &Tensor{b: t.b}
+	}
+
+	dequant := t.t
+	if C.ggml_is_quantized(t.t._type) {
+		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
+	}
+
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_rope_multi(
+			ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
+			C.int(ropeDim),
+			(*C.int)(unsafe.Pointer(&sections[0])),
+			C.int(ropeType),
+			131072, // YaRN n_ctx_train
+			C.float(ropeBase),
+			C.float(ropeScale),
+			0.,  // YaRN ext_factor
+			1.,  // YaRN attn_factor
+			32., // YaRN beta_fast
+			1.,  // YaRN beta_slow
+		),
+	}
+}
+
+func (t *Tensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_im2col(ctx.(*Context).ctx, t.t, weight.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32),
+	}
+}
+
 func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
 	return &Tensor{
 		b: t.b,
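For context (not part of the commit): im2col unfolds sliding windows of the input into columns so a convolution becomes a single matrix multiply; the patch merger added later in this commit uses it with a 2×2 stride-2 window to gather neighbouring patch embeddings. A simplified, single-channel sketch of the idea — channels, padding, dilation and ggml's memory layout are ignored:

package main

import "fmt"

// im2col for one H×W channel: each output row holds one kh×kw window.
func im2col(img [][]float32, kh, kw, sh, sw int) [][]float32 {
	h, w := len(img), len(img[0])
	var cols [][]float32
	for y := 0; y+kh <= h; y += sh {
		for x := 0; x+kw <= w; x += sw {
			row := make([]float32, 0, kh*kw)
			for dy := 0; dy < kh; dy++ {
				for dx := 0; dx < kw; dx++ {
					row = append(row, img[y+dy][x+dx])
				}
			}
			cols = append(cols, row)
		}
	}
	return cols
}

func main() {
	img := [][]float32{
		{1, 2, 3, 4},
		{5, 6, 7, 8},
		{9, 10, 11, 12},
		{13, 14, 15, 16},
	}
	// 2×2 windows with stride 2: [1 2 5 6] [3 4 7 8] [9 10 13 14] [11 12 15 16]
	for _, row := range im2col(img, 2, 2, 2, 2) {
		fmt.Println(row)
	}
}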

View File

@@ -2186,6 +2186,10 @@ static void ggml_metal_encode_node(
             } break;
         case GGML_OP_MUL_MAT:
             {
+                if (ne00 != ne10) {
+                    printf("mul_mat, ne00: %d, ne01: %d, ne02: %d, ne03: %d, ne10: %d, ne11: %d, ne12: %d, ne13: %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13);
+                }
+
                 GGML_ASSERT(ne00 == ne10);
                 GGML_ASSERT(ne12 % ne02 == 0);

View File

@@ -21,8 +21,7 @@ func getNumImageTokens(imageSize, patchSize image.Point) image.Point {
 func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.Point) image.Point {
 	b := img.Bounds()
-	le := float64(longestEdge)
-	ratio := math.Max(float64(b.Max.Y)/le, float64(b.Max.X)/le)
+	ratio := math.Max(float64(b.Max.Y)/float64(longestEdge), float64(b.Max.X)/float64(longestEdge))

 	newSize := img.Bounds().Max
@@ -80,17 +79,14 @@ func newImageProcessor(c ml.Config) ImageProcessor {
 		imageSize:   int(c.Uint("vision.image_size", 1540)),
 		patchSize:   int(c.Uint("vision.patch_size", 14)),
 		numChannels: int(c.Uint("vision.num_channels", 3)),
-		longestEdge: int(c.Uint("vision.longest_edge", 1024)),
+		longestEdge: int(c.Uint("vision.longest_edge", 1540)),
 	}
 }

 func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
 	outputSize := getResizeOutputImageSize(img, p.longestEdge, image.Point{p.patchSize, p.patchSize})

 	newImage := imageproc.Composite(img)
 	newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)

 	data := imageproc.Normalize(newImage, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)

 	return data, nil
 }
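For context (not part of the commit): with the new default, images are scaled so their longest edge is at most 1540 px and are then consumed in 14-px patches. The sketch below shows that scaling; the snap-up-to-the-patch-grid rounding is an assumption for illustration and may not match getResizeOutputImageSize exactly:

package main

import (
	"fmt"
	"math"
)

// Illustrative only: longest-edge scaling plus assumed rounding to whole patches.
func resizeForPatches(w, h, longestEdge, patchSize int) (int, int) {
	ratio := math.Max(float64(h)/float64(longestEdge), float64(w)/float64(longestEdge))
	if ratio > 1 {
		w = int(math.Ceil(float64(w) / ratio))
		h = int(math.Ceil(float64(h) / ratio))
	}
	snap := func(x int) int { return ((x + patchSize - 1) / patchSize) * patchSize }
	return snap(w), snap(h)
}

func main() {
	w, h := resizeForPatches(3080, 2310, 1540, 14)
	fmt.Println(w, h, "->", w/14, "x", h/14, "patches") // 1540 1162 -> 110 x 83 patches
}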

View File

@@ -2,6 +2,7 @@ package mistral3

 import (
 	"bytes"
+	"fmt"
 	"image"
 	"slices"
@@ -59,19 +60,28 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	// Create tensor from image data
 	pixelValues, err := ctx.Input().FromFloatSlice(f32s,
 		m.ImageProcessor.imageSize,
-		m.ImageProcessor.imageSize,
+		// TODO (jmorganca): this should be returned from the
+		// image processor instead of hardcoded
+		1036,
 		m.ImageProcessor.numChannels,
 	)
 	if err != nil {
 		return nil, err
 	}

+	fmt.Println("pixelValues", "shape", pixelValues.Shape(), "data", ml.Dump(ctx, pixelValues))
+
 	// Forward pass through vision model
 	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
+	// fmt.Println("visionOutputs", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))

 	// Project to text embedding space
 	visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps)
+	// fmt.Println("visionOutputs after projector", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))

 	return visionOutputs, nil
 }
@@ -85,15 +95,14 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 			inputMultimodal := inp.Multimodal.(ml.Tensor)

 			// Add special image tokens - using the imageTokenIndex from config
-			result = append(result,
-				input.Input{Token: int32(m.MultiModalProjector.imageTokenIndex)}, // Image token
-				input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // Image data
-			)
-
-			// Add image token placeholders
-			result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
+			result = append(result, input.Input{Token: 10}) // [IMG]
+			result = append(result, input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}) // image data
+			result = append(result, slices.Repeat([]input.Input{{Token: 10}}, inputMultimodal.Dim(1)-1)...) // [IMG] placeholders
+			result = append(result, input.Input{Token: 13}) // [IMG_END]
 		}
 	}

+	fmt.Println("post tokenize", "result", result)
+
 	return result, nil
 }
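For context (not part of the commit): with the hard-coded IDs above (10 = [IMG], 13 = [IMG_END]) and an image tensor with n embedding columns, the new PostTokenize emits one [IMG] token, one entry carrying the image embeddings, n-1 further [IMG] placeholders, and a trailing [IMG_END]. A standalone sketch with stand-in types (input.Input is simplified to a local struct):

package main

import "fmt"

// entry is a hypothetical stand-in for ollama's input.Input,
// only to show the sequence built for one image.
type entry struct {
	Token int32
	Image bool // true for the element that carries the image embeddings
}

func layoutForImage(n int) []entry {
	out := []entry{{Token: 10}}           // [IMG]
	out = append(out, entry{Image: true}) // image embedding data
	for i := 1; i < n; i++ {
		out = append(out, entry{Token: 10}) // [IMG] placeholders
	}
	return append(out, entry{Token: 13}) // [IMG_END]
}

func main() {
	fmt.Printf("%+v\n", layoutForImage(4))
	// [{Token:10 Image:false} {Token:0 Image:true} {Token:10 Image:false}
	//  {Token:10 Image:false} {Token:10 Image:false} {Token:13 Image:false}]
}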

View File

@@ -1,6 +1,7 @@
 package mistral3

 import (
+	"fmt"
 	"math"

 	"github.com/ollama/ollama/ml"
@@ -9,31 +10,109 @@ import (
 var batchSize int = 1

+type PatchMerger struct {
+	MergingLayer *nn.Linear `gguf:"merging_layer"`
+}
+
+func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
+	// TODO: pass these in
+	w := 110
+	h := 74
+	// tokensPerImage := w * h
+	d := visionOutputs.Dim(0)
+
+	// TODO: handle multiple images, this currently assumes one
+	fmt.Println("patchmerger visionOutputs", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
+
+	// Reshape to [h, w, hidden_size]
+	imageGrid := visionOutputs.Reshape(ctx, h, w, d)
+	fmt.Println("imageGrid", "shape", imageGrid.Shape(), "data", ml.Dump(ctx, imageGrid))
+
+	// TODO: load from ml.Config
+	spatialMergeSize := 2
+	kernel := ctx.Output().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d, 1)
+	fmt.Println("kernel", "shape", kernel.Shape(), "data", ml.Dump(ctx, kernel))
+
+	patches := kernel.IM2Col(ctx, imageGrid, spatialMergeSize, spatialMergeSize, 0, 0, 1, 1)
+	fmt.Println("patches", "shape", patches.Shape(), "data", ml.Dump(ctx, patches))
+
+	fmt.Println("creating reshaped", d*spatialMergeSize*spatialMergeSize, "x", patches.Dim(1)*patches.Dim(2))
+	reshaped := patches.Reshape(ctx, d*spatialMergeSize*spatialMergeSize, patches.Dim(1)*patches.Dim(2))
+	fmt.Println("reshaped", "shape", reshaped.Shape(), "data", ml.Dump(ctx, reshaped))
+
+	return pm.MergingLayer.Forward(ctx, reshaped)
+}
+
+type MultiModalProjector struct {
+	Norm        *nn.RMSNorm  `gguf:"norm"`
+	Linear1     *nn.Linear   `gguf:"linear_1"`
+	Linear2     *nn.Linear   `gguf:"linear_2"`
+	PatchMerger *PatchMerger `gguf:"patch_merger"`
+
+	spatialMergeSize int
+	imageTokenIndex  int
+	hasBias          bool
+}
+
+func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
+	visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps)
+	fmt.Println("visionOutputs after norm", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
+
+	visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs)
+	fmt.Println("visionOutputs after patch merger", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
+
+	visionOutputs = p.Linear1.Forward(ctx, visionOutputs).GELU(ctx)
+	fmt.Println("visionOutputs after linear1 and gelu", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
+
+	return p.Linear2.Forward(ctx, visionOutputs)
+}
+
+func newMultiModalProjector(c ml.Config) *MultiModalProjector {
+	return &MultiModalProjector{
+		spatialMergeSize: int(c.Uint("spatial_merge_size", 2)),
+		imageTokenIndex:  int(c.Uint("image_token_index", 10)),
+		hasBias:          c.Bool("mm.projector_bias", false),
+	}
+}
+
 type VisionSelfAttention struct {
 	Query  *nn.Linear `gguf:"attn_q"`
 	Key    *nn.Linear `gguf:"attn_k"`
 	Value  *nn.Linear `gguf:"attn_v"`
 	Output *nn.Linear `gguf:"attn_output"`
-	RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
 }

 func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	headDim := opts.headDim

+	// fmt.Println("sa.Query", "shape", sa.Query.Weight.Shape(), "data", ml.Dump(ctx, sa.Query.Weight))
+
 	query := sa.Query.Forward(ctx, hiddenState)
 	key := sa.Key.Forward(ctx, hiddenState)
 	value := sa.Value.Forward(ctx, hiddenState)

-	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	key = key.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	value = value.Reshape(ctx, headDim, opts.numHeads, batchSize)
-
-	ropeType := uint32(0)
-	query = query.RoPE(ctx, positionIDs, sa.RopeFactors, uint32(headDim), ropeType, opts.ropeBase, opts.ropeScale)
-	key = key.RoPE(ctx, positionIDs, sa.RopeFactors, uint32(headDim), ropeType, opts.ropeBase, opts.ropeScale)
+	// fmt.Println("query", "shape", query.Shape(), "data", ml.Dump(ctx, query))
+	// fmt.Println("key", "shape", key.Shape(), "data", ml.Dump(ctx, key))
+	// fmt.Println("value", "shape", value.Shape(), "data", ml.Dump(ctx, value))
+
+	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
+	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
+	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
+
+	// fmt.Println("query permute", "shape", query.Shape(), "data", ml.Dump(ctx, query))
+	// fmt.Println("key permute", "shape", key.Shape(), "data", ml.Dump(ctx, key))
+	// fmt.Println("value permute", "shape", value.Shape(), "data", ml.Dump(ctx, value))
+
+	// fmt.Println("positionIDs", "shape", positionIDs.Shape(), "data", ml.Dump(ctx, positionIDs))
+
+	// Multimodal rope
+	ropeType := uint32(24)
+	query = query.RoPEMulti(ctx, positionIDs, nil, uint32(headDim/2), [4]int{0, headDim / 2, headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
+	key = key.RoPEMulti(ctx, positionIDs, nil, uint32(headDim/2), [4]int{0, headDim / 2, headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
+
+	// fmt.Println("query rope", "shape", query.Shape(), "data", ml.Dump(ctx, query))
+	// fmt.Println("key rope", "shape", key.Shape(), "data", ml.Dump(ctx, key))

 	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
+	// fmt.Println("attention", "shape", attention.Shape(), "data", ml.Dump(ctx, attention))

 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
+	// fmt.Println("attention reshape", "shape", attention.Shape(), "data", ml.Dump(ctx, attention))

 	return sa.Output.Forward(ctx, attention)
 }
@@ -54,7 +133,7 @@ type VisionEncoderLayer struct {
 	SelfAttention *VisionSelfAttention

 	FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
-	MLP     *VisionMLP  `gguf:"mlp"`
+	MLP     *VisionMLP
 }

 func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
@@ -62,6 +141,7 @@ func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml
 	// self attention
 	hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
+	// fmt.Println("after attention norm", "eps", opts.eps, "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
 	hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, opts)
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
@@ -87,25 +167,36 @@ type VisionModelOptions struct {
 type VisionModel struct {
 	PatchEmbedding *nn.Conv2D           `gguf:"patch_conv"`
-	EncoderNorm    *nn.LayerNorm        `gguf:"encoder_norm"`
+	EncoderNorm    *nn.RMSNorm          `gguf:"encoder_norm"`
 	Layers         []VisionEncoderLayer `gguf:"blk"`

 	*VisionModelOptions
 }

 func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
-	numPatchesH := m.imageSize / m.patchSize
-	numPatchesW := m.imageSize / m.patchSize
+	numPatchesH := pixelValues.Dim(1) / m.patchSize
+	numPatchesW := pixelValues.Dim(0) / m.patchSize
 	numPatches := numPatchesH * numPatchesW

 	hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
+	// fmt.Println("after patch embedding", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
 	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
+	// fmt.Println("after reshape", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
 	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+	// fmt.Println("after permute", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))

-	// Create position IDs
-	positions := make([]int32, numPatches)
-	for i := range positions {
-		positions[i] = int32(i)
+	// TODO: this seems to have incorrect output?
+	hiddenState = m.EncoderNorm.Forward(ctx, hiddenState, m.VisionModelOptions.eps)
+	// fmt.Println("after norm", "eps", m.VisionModelOptions.eps, "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
+
+	// Generate 4D position IDs (time, height, width, extra) for MROPE
+	var positions []int32
+	for h := 0; h < numPatchesH; h++ {
+		for w := 0; w < numPatchesW; w++ {
+			positions = append(positions, 0)        // unused
+			positions = append(positions, int32(h)) // height
+			positions = append(positions, int32(w)) // width
+			positions = append(positions, 0)        // unused
+		}
 	}

 	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
@@ -113,14 +204,14 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 		panic(err)
 	}

-	// Apply encoder normalization
-	hiddenState = m.EncoderNorm.Forward(ctx, hiddenState, m.eps)
-
-	// Process through transformer layers
+	// fmt.Println("positionIDs", "shape", positionIDs.Shape(), "data", ml.Dump(ctx, positionIDs))
+
 	for _, layer := range m.Layers {
 		hiddenState = layer.Forward(ctx, hiddenState, positionIDs, m.VisionModelOptions)
 	}

+	// fmt.Println("after layers", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
+
 	return hiddenState
 }
@@ -135,7 +226,7 @@ func newVisionModel(c ml.Config) *VisionModel {
 			imageSize:   int(c.Uint("vision.image_size", 1540)),
 			patchSize:   int(c.Uint("vision.patch_size", 14)),
 			numChannels: int(c.Uint("vision.num_channels", 3)),
-			eps:         c.Float("vision.attention.layer_norm_epsilon", 1e-05),
+			eps:         c.Float("vision.attention.layer_norm_epsilon", 1e-5),
 			ropeBase:    c.Float("vision.rope.freq_base", 10000.0),
 			ropeScale:   c.Float("vision.rope.freq_scale", 1.0),
 		},
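For orientation (not part of the commit): the hard-coded values in this file are consistent with each other. A 1540×1036-px input (image_size 1540, the 1036 hard-coded in EncodeMultimodal) split into 14-px patches gives the 110×74 grid hard-coded in PatchMerger, and the 2×2 spatial merge reduces those 8140 patches to 55×37 = 2035 merged tokens, each concatenating 4 patch embeddings. A quick check of the arithmetic:

package main

import "fmt"

func main() {
	// Values taken from the commit: resized image 1540x1036 px,
	// 14-px patches, spatial merge size 2.
	const width, height, patch, merge = 1540, 1036, 14, 2

	patchesW, patchesH := width/patch, height/patch
	fmt.Println("patch grid:", patchesW, "x", patchesH, "=", patchesW*patchesH) // 110 x 74 = 8140

	mergedW, mergedH := patchesW/merge, patchesH/merge
	fmt.Println("merged tokens:", mergedW*mergedH, "each of", merge*merge, "patches") // 2035 each of 4 patches
}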

View File

@@ -1,38 +0,0 @@
-package mistral3
-
-import (
-	"github.com/ollama/ollama/ml"
-	"github.com/ollama/ollama/ml/nn"
-)
-
-type MultiModalProjector struct {
-	Norm       *nn.RMSNorm `gguf:"norm"`
-	Projection *nn.Linear  `gguf:"projection"`
-
-	spatialMergeSize int
-	imageTokenIndex  int
-	hasBias          bool
-}
-
-func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
-	// Apply normalization
-	visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps)
-
-	// If the spatial merge size is > 1, average pool the patches
-	if p.spatialMergeSize > 1 {
-		// Implementation depends on how the model handles spatial merging
-		// For simplicity, we'll use a spatial pooling approach
-		visionOutputs = visionOutputs.AvgPool2D(ctx, p.spatialMergeSize, p.spatialMergeSize, 0)
-	}
-
-	// Project to text embedding dimension
-	return p.Projection.Forward(ctx, visionOutputs)
-}
-
-func newMultiModalProjector(c ml.Config) *MultiModalProjector {
-	return &MultiModalProjector{
-		spatialMergeSize: int(c.Uint("spatial_merge_size", 2)),
-		imageTokenIndex:  int(c.Uint("image_token_index", 10)),
-		hasBias:          c.Bool("mm.projector_bias", false),
-	}
-}