From 9de1410542ce3367f6d6a8cec38fa858948b562c Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Fri, 21 Mar 2025 13:17:13 -0700 Subject: [PATCH] wip --- convert/convert_mistral.go | 1 + model/models/gemma3/process_image.go | 2 +- model/models/mistral3/imageproc.go | 32 ++++- model/models/mistral3/model.go | 159 +++++++++++------------ model/models/mistral3/model_vision.go | 143 ++++++++++++++++++++ model/models/mistral3/multimodal_proj.go | 38 ++++++ 6 files changed, 290 insertions(+), 85 deletions(-) create mode 100644 model/models/mistral3/model_vision.go create mode 100644 model/models/mistral3/multimodal_proj.go diff --git a/convert/convert_mistral.go b/convert/convert_mistral.go index 99032b51c..57e3d4ba4 100644 --- a/convert/convert_mistral.go +++ b/convert/convert_mistral.go @@ -73,6 +73,7 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV { kv["mistral3.vision.image_size"] = p.VisionModel.ImageSize kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels + // kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta // Multimodal configuration diff --git a/model/models/gemma3/process_image.go b/model/models/gemma3/process_image.go index fe8269a3b..1dc7259f9 100644 --- a/model/models/gemma3/process_image.go +++ b/model/models/gemma3/process_image.go @@ -51,7 +51,7 @@ func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 { func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) { outputSize := image.Point{p.imageSize, p.imageSize} newImage := imageproc.Composite(img) - newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear) + newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBicubic) data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD) return data, nil diff --git a/model/models/mistral3/imageproc.go b/model/models/mistral3/imageproc.go index 78c1ddf7c..2caa54091 100644 --- a/model/models/mistral3/imageproc.go +++ b/model/models/mistral3/imageproc.go @@ -8,6 +8,7 @@ import ( "io" "math" + "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model/imageproc" ) @@ -27,8 +28,8 @@ func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image. 
if ratio > 1.0 { newSize = image.Point{ - int(math.Ceil(float64(b.Max.X) / ratio)), - int(math.Ceil(float64(b.Max.Y) / ratio)), + int(math.Floor(float64(b.Max.X) / ratio)), + int(math.Floor(float64(b.Max.Y) / ratio)), } } @@ -66,3 +67,30 @@ func Preprocess(imageData io.Reader) ([]float32, map[string]any, error) { opts := map[string]any{} return data, opts, nil } + +type ImageProcessor struct { + imageSize int + patchSize int + numChannels int + longestEdge int +} + +func newImageProcessor(c ml.Config) ImageProcessor { + return ImageProcessor{ + imageSize: int(c.Uint("vision.image_size", 1540)), + patchSize: int(c.Uint("vision.patch_size", 14)), + numChannels: int(c.Uint("vision.num_channels", 3)), + longestEdge: int(c.Uint("vision.longest_edge", 1024)), + } +} + +func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, error) { + outputSize := getResizeOutputImageSize(img, p.longestEdge, image.Point{p.patchSize, p.patchSize}) + + newImage := imageproc.Composite(img) + newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear) + + data := imageproc.Normalize(newImage, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true) + + return data, nil +} diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index e22c2c95d..c5e484e66 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -1,65 +1,27 @@ package mistral3 import ( + "bytes" "image" - _ "image/jpeg" - _ "image/png" + "slices" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" - "github.com/ollama/ollama/model/imageproc" "github.com/ollama/ollama/model/input" ) type Model struct { model.Base *TextModel + *VisionModel `gguf:"v,vision"` + *MultiModalProjector `gguf:"mm"` ImageProcessor - - // TODO: Add VisionModel field - // *VisionModel `gguf:"v,vision"` - - // TODO: Add MultiModalProjector field for combining vision and text features - // *MultiModalProjector `gguf:"mm"` } -// Adding ImageProcessor struct -type ImageProcessor struct { - imageSize int - patchSize int - numChannels int - longestEdge int -} - -// Function to create a new ImageProcessor -func newImageProcessor(c ml.Config) ImageProcessor { - return ImageProcessor{ - imageSize: int(c.Uint("vision.image_size", 1024)), - patchSize: int(c.Uint("vision.patch_size", 16)), - numChannels: int(c.Uint("vision.num_channels", 3)), - longestEdge: int(c.Uint("vision.longest_edge", 1024)), - } -} - -// Method to process images for the model -func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, error) { - // Get output size based on longest edge and patch size - outputSize := getResizeOutputImageSize(img, p.longestEdge, image.Point{p.patchSize, p.patchSize}) - - // Resize the image - newImage := imageproc.Composite(img) - newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear) - - // Normalize image data - data := imageproc.Normalize(newImage, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true) - - return data, nil -} - -// TODO: Implement MultimodalProcessor interface -// var _ model.MultimodalProcessor = (*Model)(nil) +// Implement MultimodalProcessor interface +var _ model.MultimodalProcessor = (*Model)(nil) func New(c ml.Config) (model.Model, error) { textModel, err := NewTextModel(c) @@ -68,15 +30,10 @@ func New(c ml.Config) (model.Model, error) { } m := &Model{ - TextModel: textModel, - // Initialize the ImageProcessor - ImageProcessor: newImageProcessor(c), - - // TODO: Initialize VisionModel if present 
- // VisionModel: newVisionModel(c), - - // TODO: Initialize MultiModalProjector - // MultiModalProjector: &MultiModalProjector{...}, + TextModel: textModel, + VisionModel: newVisionModel(c), + ImageProcessor: newImageProcessor(c), + MultiModalProjector: newMultiModalProjector(c), } m.Cache = kvcache.NewCausalCache(m.TextModel.Shift) @@ -84,37 +41,63 @@ func New(c ml.Config) (model.Model, error) { return m, nil } -// Implement EncodeMultimodal method for processing images func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { - // Check if vision model exists - return error for now - return nil, model.ErrNoVisionModel + if len(m.VisionModel.Layers) == 0 { + return nil, model.ErrNoVisionModel + } - // This will be implemented when adding the vision model: - /* - image, _, err := image.Decode(bytes.NewReader(multimodalData)) - if err != nil { - return nil, err + // Decode image + image, _, err := image.Decode(bytes.NewReader(multimodalData)) + if err != nil { + return nil, err + } + + // Process image + f32s, err := m.ImageProcessor.ProcessImage(image) + if err != nil { + return nil, err + } + + // Create tensor from image data + pixelValues, err := ctx.Input().FromFloatSlice(f32s, + m.ImageProcessor.imageSize, + m.ImageProcessor.imageSize, + m.ImageProcessor.numChannels, + ) + if err != nil { + return nil, err + } + + // Forward pass through vision model + visionOutputs := m.VisionModel.Forward(ctx, pixelValues) + + // Project to text embedding space + visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps) + + return visionOutputs, nil +} + +func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { + var result []input.Input + + for _, inp := range inputs { + if inp.Multimodal == nil { + result = append(result, inp) + } else { + inputMultimodal := inp.Multimodal.(ml.Tensor) + + // Add special image tokens - using the imageTokenIndex from config + result = append(result, + input.Input{Token: int32(m.MultiModalProjector.imageTokenIndex)}, // Image token + input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // Image data + ) + + // Add image token placeholders + result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...) 
} + } - f32s, err := m.ImageProcessor.ProcessImage(image) - if err != nil { - return nil, err - } - - pixelValues, err := ctx.Input().FromFloatSlice(f32s, - m.ImageProcessor.imageSize, - m.ImageProcessor.imageSize, - m.ImageProcessor.numChannels, - ) - if err != nil { - return nil, err - } - - // Will need VisionModel to process this - // visionOutputs := m.VisionModel.Forward(ctx, pixelValues) - // visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs) - // return visionOutputs, nil - */ + return result, nil } func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { @@ -133,8 +116,20 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { return nil, err } - // TODO: Add handling of multimodal inputs when vision model is added - // Set image embeddings into hidden state if present in opts.Multimodal + // Handle multimodal inputs + // var except []int + // hiddenState := m.TextModel.TokenEmbedding.Forward(ctx, inputs) + + // for _, image := range opts.Multimodal { + // visionOutputs := image.Multimodal.(ml.Tensor) + + // // Copy vision outputs into the hidden state + // ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1)))) + + // for i := range visionOutputs.Dim(1) { + // except = append(except, image.Index+i) + // } + // } return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil } diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go new file mode 100644 index 000000000..2826efe81 --- /dev/null +++ b/model/models/mistral3/model_vision.go @@ -0,0 +1,143 @@ +package mistral3 + +import ( + "math" + + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" +) + +var batchSize int = 1 + +type VisionSelfAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` + RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` +} + +func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor { + headDim := opts.headDim + + query := sa.Query.Forward(ctx, hiddenState) + key := sa.Key.Forward(ctx, hiddenState) + value := sa.Value.Forward(ctx, hiddenState) + + query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) + key = key.Reshape(ctx, headDim, opts.numHeads, batchSize) + value = value.Reshape(ctx, headDim, opts.numHeads, batchSize) + + ropeType := uint32(0) + query = query.RoPE(ctx, positionIDs, sa.RopeFactors, uint32(headDim), ropeType, opts.ropeBase, opts.ropeScale) + key = key.RoPE(ctx, positionIDs, sa.RopeFactors, uint32(headDim), ropeType, opts.ropeBase, opts.ropeScale) + + attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil) + attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) + + return sa.Output.Forward(ctx, attention) +} + +type VisionMLP struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor { + hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + return mlp.Down.Forward(ctx, hiddenState) +} + +type VisionEncoderLayer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + SelfAttention *VisionSelfAttention + + FFNNorm *nn.RMSNorm 
`gguf:"ffn_norm"` + MLP *VisionMLP `gguf:"mlp"` +} + +func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor { + residual := hiddenState + + // self attention + hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, opts) + hiddenState = hiddenState.Add(ctx, residual) + residual = hiddenState + + // feed forward + hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = e.MLP.Forward(ctx, hiddenState, opts) + return hiddenState.Add(ctx, residual) +} + +type VisionModelOptions struct { + hiddenSize int + numHeads int + headDim int + intermediateSize int + imageSize int + patchSize int + numChannels int + eps float32 + ropeBase float32 + ropeScale float32 +} + +type VisionModel struct { + PatchEmbedding *nn.Conv2D `gguf:"patch_conv"` + EncoderNorm *nn.LayerNorm `gguf:"encoder_norm"` + Layers []VisionEncoderLayer `gguf:"blk"` + + *VisionModelOptions +} + +func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { + numPatchesH := m.imageSize / m.patchSize + numPatchesW := m.imageSize / m.patchSize + numPatches := numPatchesH * numPatchesW + + hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1) + hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize) + hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + + // Create position IDs + positions := make([]int32, numPatches) + for i := range positions { + positions[i] = int32(i) + } + + positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions)) + if err != nil { + panic(err) + } + + // Apply encoder normalization + hiddenState = m.EncoderNorm.Forward(ctx, hiddenState, m.eps) + + // Process through transformer layers + for _, layer := range m.Layers { + hiddenState = layer.Forward(ctx, hiddenState, positionIDs, m.VisionModelOptions) + } + + return hiddenState +} + +func newVisionModel(c ml.Config) *VisionModel { + return &VisionModel{ + Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)), + VisionModelOptions: &VisionModelOptions{ + hiddenSize: int(c.Uint("vision.embedding_length", 1024)), + numHeads: int(c.Uint("vision.attention.head_count", 16)), + headDim: int(c.Uint("vision.attention.key_length", 64)), + intermediateSize: int(c.Uint("vision.feed_forward_length", 4096)), + imageSize: int(c.Uint("vision.image_size", 1540)), + patchSize: int(c.Uint("vision.patch_size", 14)), + numChannels: int(c.Uint("vision.num_channels", 3)), + eps: c.Float("vision.attention.layer_norm_epsilon", 1e-05), + ropeBase: c.Float("vision.rope.freq_base", 10000.0), + ropeScale: c.Float("vision.rope.freq_scale", 1.0), + }, + } +} diff --git a/model/models/mistral3/multimodal_proj.go b/model/models/mistral3/multimodal_proj.go new file mode 100644 index 000000000..7de40abd7 --- /dev/null +++ b/model/models/mistral3/multimodal_proj.go @@ -0,0 +1,38 @@ +package mistral3 + +import ( + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" +) + +type MultiModalProjector struct { + Norm *nn.RMSNorm `gguf:"norm"` + Projection *nn.Linear `gguf:"projection"` + + spatialMergeSize int + imageTokenIndex int + hasBias bool +} + +func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor { + // Apply normalization + visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps) + + // If the spatial merge size is > 1, average pool the patches + if 
p.spatialMergeSize > 1 { + // Implementation depends on how the model handles spatial merging + // For simplicity, we'll use a spatial pooling approach + visionOutputs = visionOutputs.AvgPool2D(ctx, p.spatialMergeSize, p.spatialMergeSize, 0) + } + + // Project to text embedding dimension + return p.Projection.Forward(ctx, visionOutputs) +} + +func newMultiModalProjector(c ml.Config) *MultiModalProjector { + return &MultiModalProjector{ + spatialMergeSize: int(c.Uint("spatial_merge_size", 2)), + imageTokenIndex: int(c.Uint("image_token_index", 10)), + hasBias: c.Bool("mm.projector_bias", false), + } +}
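-- 

Note on the Ceil -> Floor change in getResizeOutputImageSize: when an image is
larger than longest_edge, the scaled-down dimensions are now rounded down
instead of up. A minimal standalone sketch of the logic visible in that hunk,
assuming the ratio is the longest dimension divided by longestEdge (the rest of
the function body is outside the hunk):

package main

import (
	"fmt"
	"image"
	"math"
)

// scaledSize mirrors the visible part of getResizeOutputImageSize: when the
// image exceeds longestEdge, both dimensions are divided by the same ratio
// and floored (previously ceiled).
func scaledSize(bounds image.Point, longestEdge int) image.Point {
	// Assumption: ratio is the longest dimension over longestEdge.
	ratio := math.Max(float64(bounds.X), float64(bounds.Y)) / float64(longestEdge)
	if ratio <= 1.0 {
		return bounds
	}
	return image.Point{
		X: int(math.Floor(float64(bounds.X) / ratio)),
		Y: int(math.Floor(float64(bounds.Y) / ratio)),
	}
}

func main() {
	// A 3024x4032 input with the longest_edge default of 1024 from newImageProcessor.
	fmt.Println(scaledSize(image.Point{X: 3024, Y: 4032}, 1024)) // (768,1024)
}

With those defaults the longest side lands exactly on 1024 and the shorter side
is floored to 768; with the previous Ceil, a non-exact division would have
rounded the shorter side up by one pixel instead.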
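
The exported Preprocess helper in imageproc.go keeps its existing signature and
currently returns an empty options map. A hedged usage sketch (the file name is
a placeholder, and the decoder blank-imports are an assumption in case the
package does not register them itself):

package main

import (
	"fmt"
	_ "image/jpeg" // assumption: register decoders for image.Decode
	_ "image/png"
	"os"

	"github.com/ollama/ollama/model/models/mistral3"
)

func main() {
	f, err := os.Open("photo.jpg") // placeholder input
	if err != nil {
		panic(err)
	}
	defer f.Close()

	// Returns normalized float32 pixel data plus a (currently empty) options map.
	data, opts, err := mistral3.Preprocess(f)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%d floats, %d options\n", len(data), len(opts))
}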