This commit is contained in:
jmorganca 2025-03-22 23:20:39 -07:00
parent 8dd2a81f8c
commit 4530661799
3 changed files with 13 additions and 17 deletions

View File

@ -138,10 +138,10 @@ func (p *mistral3Model) Replacements() []string {
"attention.v_proj", "attn_v", "attention.v_proj", "attn_v",
"attention.o_proj", "attn_output", "attention.o_proj", "attn_output",
"attention_norm", "attn_norm", "attention_norm", "attn_norm",
"feed_forward", "mlp",
"feed_forward.gate_proj", "ffn_gate", "feed_forward.gate_proj", "ffn_gate",
"feed_forward.down_proj", "ffn_down", "feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up", "feed_forward.up_proj", "ffn_up",
"patch_merger.merging_layer", "merger",
"multi_modal_projector", "mm", "multi_modal_projector", "mm",
"ffn_norm", "ffn_norm", "ffn_norm", "ffn_norm",
"lm_head", "output", "lm_head", "output",

View File

@ -2,7 +2,6 @@ package mistral3
import ( import (
"bytes" "bytes"
"fmt"
"image" "image"
"slices" "slices"
@ -70,7 +69,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return nil, err return nil, err
} }
fmt.Println("pixelValues", "shape", pixelValues.Shape(), "data", ml.Dump(ctx, pixelValues)) // fmt.Println("pixelValues", "shape", pixelValues.Shape(), "data", ml.Dump(ctx, pixelValues))
// Forward pass through vision model // Forward pass through vision model
visionOutputs := m.VisionModel.Forward(ctx, pixelValues) visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
@ -102,8 +101,6 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
} }
} }
fmt.Println("post tokenize", "result", result)
return result, nil return result, nil
} }

View File

@ -1,7 +1,6 @@
package mistral3 package mistral3
import ( import (
"fmt"
"math" "math"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
@ -22,23 +21,23 @@ func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tenso
d := visionOutputs.Dim(0) d := visionOutputs.Dim(0)
// TODO: handle multiple images, this currently assumes one // TODO: handle multiple images, this currently assumes one
fmt.Println("patchmerger visionOutputs", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs)) // fmt.Println("patchmerger visionOutputs", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
// Reshape to [h, w, hidden_size] // Reshape to [h, w, hidden_size]
imageGrid := visionOutputs.Reshape(ctx, h, w, d) imageGrid := visionOutputs.Reshape(ctx, h, w, d)
fmt.Println("imageGrid", "shape", imageGrid.Shape(), "data", ml.Dump(ctx, imageGrid)) // fmt.Println("imageGrid", "shape", imageGrid.Shape(), "data", ml.Dump(ctx, imageGrid))
// TODO: load from ml.Config // TODO: load from config
spatialMergeSize := 2 spatialMergeSize := 2
kernel := ctx.Output().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d, 1) kernel := ctx.Input().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d, 1)
fmt.Println("kernel", "shape", kernel.Shape(), "data", ml.Dump(ctx, kernel)) // fmt.Println("kernel", "shape", kernel.Shape(), "data", ml.Dump(ctx, kernel))
patches := kernel.IM2Col(ctx, imageGrid, spatialMergeSize, spatialMergeSize, 0, 0, 1, 1) patches := kernel.IM2Col(ctx, imageGrid, spatialMergeSize, spatialMergeSize, 0, 0, 1, 1)
fmt.Println("patches", "shape", patches.Shape(), "data", ml.Dump(ctx, patches)) // fmt.Println("patches", "shape", patches.Shape(), "data", ml.Dump(ctx, patches))
fmt.Println("creating reshaped", d*spatialMergeSize*spatialMergeSize, "x", patches.Dim(1)*patches.Dim(2)) // fmt.Println("creating reshaped", d*spatialMergeSize*spatialMergeSize, "x", patches.Dim(1)*patches.Dim(2))
reshaped := patches.Reshape(ctx, d*spatialMergeSize*spatialMergeSize, patches.Dim(1)*patches.Dim(2)) reshaped := patches.Reshape(ctx, d*spatialMergeSize*spatialMergeSize, patches.Dim(1)*patches.Dim(2))
fmt.Println("reshaped", "shape", reshaped.Shape(), "data", ml.Dump(ctx, reshaped)) // fmt.Println("reshaped", "shape", reshaped.Shape(), "data", ml.Dump(ctx, reshaped))
return pm.MergingLayer.Forward(ctx, reshaped) return pm.MergingLayer.Forward(ctx, reshaped)
} }
@ -56,11 +55,11 @@ type MultiModalProjector struct {
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor { func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps) visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps)
fmt.Println("visionOutputs after norm", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs)) // fmt.Println("visionOutputs after norm", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs) visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs)
fmt.Println("visionOutputs after patch merger", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs)) // fmt.Println("visionOutputs after patch merger", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
visionOutputs = p.Linear1.Forward(ctx, visionOutputs).GELU(ctx) visionOutputs = p.Linear1.Forward(ctx, visionOutputs).GELU(ctx)
fmt.Println("visionOutputs after linear1 and gelu", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs)) // fmt.Println("visionOutputs after linear1 and gelu", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
return p.Linear2.Forward(ctx, visionOutputs) return p.Linear2.Forward(ctx, visionOutputs)
} }