commit 9de1410542
parent 9a9944fc6b

    wip
@@ -73,6 +73,7 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
 	kv["mistral3.vision.image_size"] = p.VisionModel.ImageSize
 	kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
 	kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
+	// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
 	kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta

 	// Multimodal configuration
@@ -51,7 +51,7 @@ func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
 func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
 	outputSize := image.Point{p.imageSize, p.imageSize}
 	newImage := imageproc.Composite(img)
-	newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
+	newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBicubic)

 	data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
 	return data, nil
@@ -8,6 +8,7 @@ import (
 	"io"
 	"math"

+	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model/imageproc"
 )

@@ -27,8 +28,8 @@ func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.

 	if ratio > 1.0 {
 		newSize = image.Point{
-			int(math.Ceil(float64(b.Max.X) / ratio)),
-			int(math.Ceil(float64(b.Max.Y) / ratio)),
+			int(math.Floor(float64(b.Max.X) / ratio)),
+			int(math.Floor(float64(b.Max.Y) / ratio)),
 		}
 	}

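
Note (reviewer sketch, not part of this commit): the Ceil-to-Floor swap keeps the scaled dimensions from overshooting the cap. A minimal standalone check, where ratio is assumed to be longestEdge-relative (e.g. 3000/1024 for a 3000x2000 image and a 1024 longest edge):

package main

import (
	"fmt"
	"math"
)

// Hypothetical numbers; the helper's exact ratio computation is not shown in
// this hunk, so this only illustrates the rounding difference.
func main() {
	ratio := 3000.0 / 1024.0
	y := 2000.0
	fmt.Println(int(math.Ceil(y / ratio)))  // 683: rounds up past the scaled size
	fmt.Println(int(math.Floor(y / ratio))) // 682: stays within it
}
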
@@ -66,3 +67,30 @@ func Preprocess(imageData io.Reader) ([]float32, map[string]any, error) {
 	opts := map[string]any{}
 	return data, opts, nil
 }
+
+type ImageProcessor struct {
+	imageSize   int
+	patchSize   int
+	numChannels int
+	longestEdge int
+}
+
+func newImageProcessor(c ml.Config) ImageProcessor {
+	return ImageProcessor{
+		imageSize:   int(c.Uint("vision.image_size", 1540)),
+		patchSize:   int(c.Uint("vision.patch_size", 14)),
+		numChannels: int(c.Uint("vision.num_channels", 3)),
+		longestEdge: int(c.Uint("vision.longest_edge", 1024)),
+	}
+}
+
+func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
+	outputSize := getResizeOutputImageSize(img, p.longestEdge, image.Point{p.patchSize, p.patchSize})
+
+	newImage := imageproc.Composite(img)
+	newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
+
+	data := imageproc.Normalize(newImage, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
+
+	return data, nil
+}
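
Note (reviewer sketch, not part of this commit): ProcessImage returns numChannels * width * height float32 values normalized per channel. A standalone illustration of that normalization, assuming imageproc.ClipDefaultMean/STD hold the usual CLIP constants and that Normalize applies (x - mean) / std to pixels rescaled into [0, 1]:

package main

import "fmt"

// Assumed CLIP mean/std values and formula, for illustration only.
func main() {
	mean := [3]float32{0.48145466, 0.4578275, 0.40821073}
	std := [3]float32{0.26862954, 0.26130258, 0.27577711}
	px := [3]float32{0.5, 0.5, 0.5} // a mid-grey pixel already scaled to [0, 1]
	for c := range px {
		fmt.Printf("channel %d -> %.4f\n", c, (px[c]-mean[c])/std[c])
	}
}
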
@@ -1,65 +1,27 @@
 package mistral3

 import (
+	"bytes"
 	"image"
-	_ "image/jpeg"
-	_ "image/png"
+	"slices"

 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model"
-	"github.com/ollama/ollama/model/imageproc"
 	"github.com/ollama/ollama/model/input"
 )

 type Model struct {
 	model.Base
 	*TextModel
+	*VisionModel `gguf:"v,vision"`
+	*MultiModalProjector `gguf:"mm"`

 	ImageProcessor
-
-	// TODO: Add VisionModel field
-	// *VisionModel `gguf:"v,vision"`
-
-	// TODO: Add MultiModalProjector field for combining vision and text features
-	// *MultiModalProjector `gguf:"mm"`
 }

-// Adding ImageProcessor struct
-type ImageProcessor struct {
-	imageSize   int
-	patchSize   int
-	numChannels int
-	longestEdge int
-}
-
-// Function to create a new ImageProcessor
-func newImageProcessor(c ml.Config) ImageProcessor {
-	return ImageProcessor{
-		imageSize:   int(c.Uint("vision.image_size", 1024)),
-		patchSize:   int(c.Uint("vision.patch_size", 16)),
-		numChannels: int(c.Uint("vision.num_channels", 3)),
-		longestEdge: int(c.Uint("vision.longest_edge", 1024)),
-	}
-}
-
-// Method to process images for the model
-func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
-	// Get output size based on longest edge and patch size
-	outputSize := getResizeOutputImageSize(img, p.longestEdge, image.Point{p.patchSize, p.patchSize})
-
-	// Resize the image
-	newImage := imageproc.Composite(img)
-	newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
-
-	// Normalize image data
-	data := imageproc.Normalize(newImage, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
-
-	return data, nil
-}
-
-// TODO: Implement MultimodalProcessor interface
-// var _ model.MultimodalProcessor = (*Model)(nil)
+// Implement MultimodalProcessor interface
+var _ model.MultimodalProcessor = (*Model)(nil)

 func New(c ml.Config) (model.Model, error) {
 	textModel, err := NewTextModel(c)
@@ -68,15 +30,10 @@ func New(c ml.Config) (model.Model, error) {
 	}

 	m := &Model{
 		TextModel:      textModel,
-		// Initialize the ImageProcessor
+		VisionModel:         newVisionModel(c),
 		ImageProcessor: newImageProcessor(c),
-
-		// TODO: Initialize VisionModel if present
-		// VisionModel: newVisionModel(c),
-
-		// TODO: Initialize MultiModalProjector
-		// MultiModalProjector: &MultiModalProjector{...},
+		MultiModalProjector: newMultiModalProjector(c),
 	}

 	m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
@@ -84,37 +41,63 @@ func New(c ml.Config) (model.Model, error) {
 	return m, nil
 }

-// Implement EncodeMultimodal method for processing images
 func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
-	// Check if vision model exists - return error for now
-	return nil, model.ErrNoVisionModel
-
-	// This will be implemented when adding the vision model:
-	/*
-		image, _, err := image.Decode(bytes.NewReader(multimodalData))
-		if err != nil {
-			return nil, err
-		}
-
-		f32s, err := m.ImageProcessor.ProcessImage(image)
-		if err != nil {
-			return nil, err
-		}
-
-		pixelValues, err := ctx.Input().FromFloatSlice(f32s,
-			m.ImageProcessor.imageSize,
-			m.ImageProcessor.imageSize,
-			m.ImageProcessor.numChannels,
-		)
-		if err != nil {
-			return nil, err
-		}
-
-		// Will need VisionModel to process this
-		// visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
-		// visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs)
-		// return visionOutputs, nil
-	*/
+	if len(m.VisionModel.Layers) == 0 {
+		return nil, model.ErrNoVisionModel
+	}
+
+	// Decode image
+	image, _, err := image.Decode(bytes.NewReader(multimodalData))
+	if err != nil {
+		return nil, err
+	}
+
+	// Process image
+	f32s, err := m.ImageProcessor.ProcessImage(image)
+	if err != nil {
+		return nil, err
+	}
+
+	// Create tensor from image data
+	pixelValues, err := ctx.Input().FromFloatSlice(f32s,
+		m.ImageProcessor.imageSize,
+		m.ImageProcessor.imageSize,
+		m.ImageProcessor.numChannels,
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	// Forward pass through vision model
+	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
+
+	// Project to text embedding space
+	visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps)
+
+	return visionOutputs, nil
+}
+
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
+	var result []input.Input
+
+	for _, inp := range inputs {
+		if inp.Multimodal == nil {
+			result = append(result, inp)
+		} else {
+			inputMultimodal := inp.Multimodal.(ml.Tensor)
+
+			// Add special image tokens - using the imageTokenIndex from config
+			result = append(result,
+				input.Input{Token: int32(m.MultiModalProjector.imageTokenIndex)},             // Image token
+				input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // Image data
+			)
+
+			// Add image token placeholders
+			result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
+		}
+	}
+
+	return result, nil
 }

 func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
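
Note (reviewer sketch, not part of this commit): how the PostTokenize expansion above is assumed to behave. An image whose embedding tensor has n columns ends up occupying one special image token plus n input slots, where the first slot carries the image embeddings and the remaining n-1 are zero placeholders. A self-contained approximation with plain token IDs (expand is a hypothetical helper, not code from the commit):

package main

import (
	"fmt"
	"slices"
)

// Insert one image token followed by n placeholder slots at position imageAt;
// in the commit's code the first placeholder slot carries the Multimodal tensor.
func expand(tokens []int32, imageAt int, imageToken int32, n int) []int32 {
	out := append([]int32{}, tokens[:imageAt]...)
	out = append(out, imageToken)
	out = append(out, slices.Repeat([]int32{0}, n)...)
	return append(out, tokens[imageAt:]...)
}

func main() {
	// A 4-column image inserted after the first text token, image token id 10.
	fmt.Println(expand([]int32{1, 2, 3}, 1, 10, 4)) // [1 10 0 0 0 0 2 3]
}
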
@@ -133,8 +116,20 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}

-	// TODO: Add handling of multimodal inputs when vision model is added
-	// Set image embeddings into hidden state if present in opts.Multimodal
+	// Handle multimodal inputs
+	// var except []int
+	// hiddenState := m.TextModel.TokenEmbedding.Forward(ctx, inputs)
+
+	// for _, image := range opts.Multimodal {
+	// 	visionOutputs := image.Multimodal.(ml.Tensor)
+
+	// 	// Copy vision outputs into the hidden state
+	// 	ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
+
+	// 	for i := range visionOutputs.Dim(1) {
+	// 		except = append(except, image.Index+i)
+	// 	}
+	// }

 	return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
 }
model/models/mistral3/model_vision.go (new file, 143 lines)
@@ -0,0 +1,143 @@
+package mistral3
+
+import (
+	"math"
+
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+)
+
+var batchSize int = 1
+
+type VisionSelfAttention struct {
+	Query       *nn.Linear `gguf:"attn_q"`
+	Key         *nn.Linear `gguf:"attn_k"`
+	Value       *nn.Linear `gguf:"attn_v"`
+	Output      *nn.Linear `gguf:"attn_output"`
+	RopeFactors ml.Tensor  `gguf:"rope_freqs.weight"`
+}
+
+func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+	headDim := opts.headDim
+
+	query := sa.Query.Forward(ctx, hiddenState)
+	key := sa.Key.Forward(ctx, hiddenState)
+	value := sa.Value.Forward(ctx, hiddenState)
+
+	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
+	key = key.Reshape(ctx, headDim, opts.numHeads, batchSize)
+	value = value.Reshape(ctx, headDim, opts.numHeads, batchSize)
+
+	ropeType := uint32(0)
+	query = query.RoPE(ctx, positionIDs, sa.RopeFactors, uint32(headDim), ropeType, opts.ropeBase, opts.ropeScale)
+	key = key.RoPE(ctx, positionIDs, sa.RopeFactors, uint32(headDim), ropeType, opts.ropeBase, opts.ropeScale)
+
+	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
+	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
+
+	return sa.Output.Forward(ctx, attention)
+}
+
+type VisionMLP struct {
+	Gate *nn.Linear `gguf:"ffn_gate"`
+	Up   *nn.Linear `gguf:"ffn_up"`
+	Down *nn.Linear `gguf:"ffn_down"`
+}
+
+func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	return mlp.Down.Forward(ctx, hiddenState)
+}
+
+type VisionEncoderLayer struct {
+	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
+	SelfAttention *VisionSelfAttention
+
+	FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
+	MLP     *VisionMLP  `gguf:"mlp"`
+}
+
+func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+	residual := hiddenState
+
+	// self attention
+	hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
+	hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, opts)
+	hiddenState = hiddenState.Add(ctx, residual)
+	residual = hiddenState
+
+	// feed forward
+	hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps)
+	hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
+	return hiddenState.Add(ctx, residual)
+}
+
+type VisionModelOptions struct {
+	hiddenSize       int
+	numHeads         int
+	headDim          int
+	intermediateSize int
+	imageSize        int
+	patchSize        int
+	numChannels      int
+	eps              float32
+	ropeBase         float32
+	ropeScale        float32
+}
+
+type VisionModel struct {
+	PatchEmbedding *nn.Conv2D           `gguf:"patch_conv"`
+	EncoderNorm    *nn.LayerNorm        `gguf:"encoder_norm"`
+	Layers         []VisionEncoderLayer `gguf:"blk"`
+
+	*VisionModelOptions
+}
+
+func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
+	numPatchesH := m.imageSize / m.patchSize
+	numPatchesW := m.imageSize / m.patchSize
+	numPatches := numPatchesH * numPatchesW
+
+	hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
+	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
+	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+
+	// Create position IDs
+	positions := make([]int32, numPatches)
+	for i := range positions {
+		positions[i] = int32(i)
+	}
+
+	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
+	if err != nil {
+		panic(err)
+	}
+
+	// Apply encoder normalization
+	hiddenState = m.EncoderNorm.Forward(ctx, hiddenState, m.eps)
+
+	// Process through transformer layers
+	for _, layer := range m.Layers {
+		hiddenState = layer.Forward(ctx, hiddenState, positionIDs, m.VisionModelOptions)
+	}
+
+	return hiddenState
+}
+
+func newVisionModel(c ml.Config) *VisionModel {
+	return &VisionModel{
+		Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)),
+		VisionModelOptions: &VisionModelOptions{
+			hiddenSize:       int(c.Uint("vision.embedding_length", 1024)),
+			numHeads:         int(c.Uint("vision.attention.head_count", 16)),
+			headDim:          int(c.Uint("vision.attention.key_length", 64)),
+			intermediateSize: int(c.Uint("vision.feed_forward_length", 4096)),
+			imageSize:        int(c.Uint("vision.image_size", 1540)),
+			patchSize:        int(c.Uint("vision.patch_size", 14)),
+			numChannels:      int(c.Uint("vision.num_channels", 3)),
+			eps:              c.Float("vision.attention.layer_norm_epsilon", 1e-05),
+			ropeBase:         c.Float("vision.rope.freq_base", 10000.0),
+			ropeScale:        c.Float("vision.rope.freq_scale", 1.0),
+		},
+	}
+}
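
Note (reviewer sketch, not part of this commit): a quick sanity check of the default vision hyperparameters above.

package main

import "fmt"

// With the defaults in newVisionModel: 1540/14 = 110 patches per side, so a
// full square image yields 110*110 patches, and 16 heads * 64 head dim matches
// the 1024 embedding length.
func main() {
	imageSize, patchSize := 1540, 14
	perSide := imageSize / patchSize
	fmt.Println("patches per side:", perSide)      // 110
	fmt.Println("patches total:", perSide*perSide) // 12100
	fmt.Println("heads * headDim:", 16*64)         // 1024
}
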
model/models/mistral3/multimodal_proj.go (new file, 38 lines)
@@ -0,0 +1,38 @@
+package mistral3
+
+import (
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+)
+
+type MultiModalProjector struct {
+	Norm       *nn.RMSNorm `gguf:"norm"`
+	Projection *nn.Linear  `gguf:"projection"`
+
+	spatialMergeSize int
+	imageTokenIndex  int
+	hasBias          bool
+}
+
+func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
+	// Apply normalization
+	visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps)
+
+	// If the spatial merge size is > 1, average pool the patches
+	if p.spatialMergeSize > 1 {
+		// Implementation depends on how the model handles spatial merging
+		// For simplicity, we'll use a spatial pooling approach
+		visionOutputs = visionOutputs.AvgPool2D(ctx, p.spatialMergeSize, p.spatialMergeSize, 0)
+	}
+
+	// Project to text embedding dimension
+	return p.Projection.Forward(ctx, visionOutputs)
+}
+
+func newMultiModalProjector(c ml.Config) *MultiModalProjector {
+	return &MultiModalProjector{
+		spatialMergeSize: int(c.Uint("spatial_merge_size", 2)),
+		imageTokenIndex:  int(c.Uint("image_token_index", 10)),
+		hasBias:          c.Bool("mm.projector_bias", false),
+	}
+}
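
Note (reviewer sketch, not part of this commit): a rough estimate of what the spatial merge does to the sequence length, assuming AvgPool2D pools the patch grid with a window and stride of spatialMergeSize.

package main

import "fmt"

// With the defaults above (110x110 patches, spatial_merge_size 2), the
// projector would emit roughly a quarter of the vision embeddings.
func main() {
	perSide, merge := 110, 2
	merged := (perSide / merge) * (perSide / merge)
	fmt.Println(merged, "merged patches") // 3025 merged patches
}
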