From cfeca2713307e1670820ad5781817739e04fc767 Mon Sep 17 00:00:00 2001
From: jmorganca
Date: Sun, 23 Mar 2025 01:01:23 -0700
Subject: [PATCH] wip

---
 model/models/mistral3/model.go        |  5 +-
 model/models/mistral3/model_vision.go | 76 ++++++++-------------
 2 files changed, 23 insertions(+), 58 deletions(-)

diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go
index 713a1bcbb..2592e1516 100644
--- a/model/models/mistral3/model.go
+++ b/model/models/mistral3/model.go
@@ -59,10 +59,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	// Create tensor from image data
 	pixelValues, err := ctx.Input().FromFloatSlice(f32s,
 		m.ImageProcessor.imageSize,
-
-		// TODO (jmorganca): this should be returned from the
-		// image processor instead of hardcoded
-		1036,
+		1036, // TODO (jmorganca): this should be returned from ProcessImage
 		m.ImageProcessor.numChannels,
 	)
 	if err != nil {
diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go
index 29a96e5fd..8561ec358 100644
--- a/model/models/mistral3/model_vision.go
+++ b/model/models/mistral3/model_vision.go
@@ -1,6 +1,7 @@
 package mistral3
 
 import (
+	"fmt"
 	"math"
 
 	"github.com/ollama/ollama/ml"
@@ -55,11 +56,9 @@ type MultiModalProjector struct {
 
 func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
 	visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps)
-	// fmt.Println("visionOutputs after norm", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
 	visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs)
-	// fmt.Println("visionOutputs after patch merger", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
-	visionOutputs = p.Linear1.Forward(ctx, visionOutputs).GELU(ctx)
-	// fmt.Println("visionOutputs after linear1 and gelu", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
+	visionOutputs = p.Linear1.Forward(ctx, visionOutputs)
+	visionOutputs = visionOutputs.GELU(ctx)
 
 	return p.Linear2.Forward(ctx, visionOutputs)
 }
@@ -79,40 +78,20 @@ type VisionSelfAttention struct {
 }
 
 func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
-	headDim := opts.headDim
+	q := sa.Query.Forward(ctx, hiddenState)
+	k := sa.Key.Forward(ctx, hiddenState)
+	v := sa.Value.Forward(ctx, hiddenState)
 
-	// fmt.Println("sa.Query", "shape", sa.Query.Weight.Shape(), "data", ml.Dump(ctx, sa.Query.Weight))
+	q = q.Reshape(ctx, opts.headDim, opts.numHeads, q.Dim(1), batchSize)
+	k = k.Reshape(ctx, opts.headDim, opts.numHeads, k.Dim(1), batchSize)
+	v = v.Reshape(ctx, opts.headDim, opts.numHeads, v.Dim(1), batchSize)
 
-	query := sa.Query.Forward(ctx, hiddenState)
-	key := sa.Key.Forward(ctx, hiddenState)
-	value := sa.Value.Forward(ctx, hiddenState)
+	ropeType := uint32(24) // 2d vision rope
+	q = q.RoPEMulti(ctx, positionIDs, nil, uint32(opts.headDim/2), [4]int{0, opts.headDim / 2, opts.headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
+	k = k.RoPEMulti(ctx, positionIDs, nil, uint32(opts.headDim/2), [4]int{0, opts.headDim / 2, opts.headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
 
-	// fmt.Println("query", "shape", query.Shape(), "data", ml.Dump(ctx, query))
-	// fmt.Println("key", "shape", key.Shape(), "data", ml.Dump(ctx, key))
-	// fmt.Println("value", "shape", value.Shape(), "data", ml.Dump(ctx, value))
-
-	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
-	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
-	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
-
-	// fmt.Println("query permute", "shape", query.Shape(), "data", ml.Dump(ctx, query))
-	// fmt.Println("key permute", "shape", key.Shape(), "data", ml.Dump(ctx, key))
-	// fmt.Println("value permute", "shape", value.Shape(), "data", ml.Dump(ctx, value))
-	// fmt.Println("positionIDs", "shape", positionIDs.Shape(), "data", ml.Dump(ctx, positionIDs))
-
-	// Multimodal rope
-	ropeType := uint32(24)
-	query = query.RoPEMulti(ctx, positionIDs, nil, uint32(headDim/2), [4]int{0, headDim / 2, headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
-	key = key.RoPEMulti(ctx, positionIDs, nil, uint32(headDim/2), [4]int{0, headDim / 2, headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
-
-	// fmt.Println("query rope", "shape", query.Shape(), "data", ml.Dump(ctx, query))
-	// fmt.Println("key rope", "shape", key.Shape(), "data", ml.Dump(ctx, key))
-
-	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
-	// fmt.Println("attention", "shape", attention.Shape(), "data", ml.Dump(ctx, attention))
+	attention := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(opts.headDim)), nil)
 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
-	// fmt.Println("attention reshape", "shape", attention.Shape(), "data", ml.Dump(ctx, attention))
-
 	return sa.Output.Forward(ctx, attention)
 }
 
@@ -130,22 +109,19 @@ func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Visio
 type VisionEncoderLayer struct {
 	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
 	SelfAttention *VisionSelfAttention
-
-	FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
-	MLP     *VisionMLP
+	FFNNorm       *nn.RMSNorm `gguf:"ffn_norm"`
+	MLP           *VisionMLP
 }
 
 func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	residual := hiddenState
 
-	// self attention
 	hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
-	// fmt.Println("after attention norm", "eps", opts.eps, "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
+	fmt.Println("after attention norm", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
 	hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, opts)
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
-	// feed forward
 	hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
 	return hiddenState.Add(ctx, residual)
@@ -177,24 +153,18 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 	numPatchesW := pixelValues.Dim(0) / m.patchSize
 	numPatches := numPatchesH * numPatchesW
 	hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
-	// fmt.Println("after patch embedding", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
 	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
-	// fmt.Println("after reshape", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
 	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	// fmt.Println("after permute", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
-
-	// TODO: this seems to have incorrect output?
 	hiddenState = m.EncoderNorm.Forward(ctx, hiddenState, m.VisionModelOptions.eps)
-	// fmt.Println("after norm", "eps", m.VisionModelOptions.eps, "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
 
-	// Generate 4D position IDs (time, height, width, extra) for MROPE
-	var positions []int32
+	totalPositions := numPatchesH * numPatchesW
+	positions := make([]int32, totalPositions*4)
+
 	for h := 0; h < numPatchesH; h++ {
 		for w := 0; w < numPatchesW; w++ {
-			positions = append(positions, 0)        // unused
-			positions = append(positions, int32(h)) // height
-			positions = append(positions, int32(w)) // width
-			positions = append(positions, 0)        // unused
+			index := h*numPatchesW + w
+			positions[totalPositions+index] = int32(h)
+			positions[totalPositions*2+index] = int32(w)
 		}
 	}
 
@@ -203,8 +173,6 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 		panic(err)
 	}
 
-	// fmt.Println("positionIDs", "shape", positionIDs.Shape(), "data", ml.Dump(ctx, positionIDs))
-
 	for _, layer := range m.Layers {
 		hiddenState = layer.Forward(ctx, hiddenState, positionIDs, m.VisionModelOptions)
 	}
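
Note on the position-ID change above: the old loop appended four interleaved values per patch (time, height, width, extra), while the new one fills four contiguous sections of length numPatchesH*numPatchesW — [unused | row | column | unused] — which is the planar layout the rewritten RoPEMulti calls appear to index with their [4]int sections. A minimal standalone sketch of just that layout, using a hypothetical 2x3 patch grid (values for illustration only, not taken from the patch):

	package main

	import "fmt"

	func main() {
		numPatchesH, numPatchesW := 2, 3 // hypothetical grid; the model derives these from the image
		totalPositions := numPatchesH * numPatchesW
		positions := make([]int32, totalPositions*4) // sections 0 and 3 stay zero (unused)

		for h := 0; h < numPatchesH; h++ {
			for w := 0; w < numPatchesW; w++ {
				index := h*numPatchesW + w
				positions[totalPositions+index] = int32(h)   // section 1: patch row
				positions[totalPositions*2+index] = int32(w) // section 2: patch column
			}
		}

		// Prints [0 0 0 0 0 0 0 0 0 1 1 1 0 1 2 0 1 2 0 0 0 0 0 0]
		fmt.Println(positions)
	}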