use 2d pooling

Michael Yang 2025-03-11 09:00:10 -07:00
parent ab39e08eb9
commit 63a394068c
4 changed files with 36 additions and 25 deletions

View File

@@ -26,15 +26,16 @@ type gemma3Model struct {
 		NumChannels uint32 `json:"num_channels"` // num_channels 3
 		PatchSize   uint32 `json:"patch_size"`   // patch_size 14
 	} `json:"vision_config"`
 	MaxPositionEmbeddings    uint32  `json:"max_position_embeddings"`
 	NumAttentionHeads        uint32  `json:"num_attention_heads"`
 	NumKeyValueHeads         uint32  `json:"num_key_value_heads"`
 	RMSNormEPS               float32 `json:"rms_norm_eps"`
 	HeadDim                  uint32  `json:"head_dim"`
 	FinalLogitSoftcap        float32 `json:"final_logit_softcapping"`
 	RopeLocalTheta           float32 `json:"rope_local_base_freq"`
 	RopeGlobalTheta          float32 `json:"rope_global_base_freq"`
 	SlidingWindow            uint32  `json:"sliding_window"`
+	MultiModalTokensPerImage uint32  `json:"mm_tokens_per_image"`
 }

 const (
@@ -102,6 +103,10 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
 		kv["gemma3.attention.value_length"] = cmp.Or(p.TextModel.HeadDim, 256)
 	}

+	if p.MultiModalTokensPerImage > 0 {
+		kv["gemma3.mm.tokens_per_image"] = p.MultiModalTokensPerImage
+	}
+
 	return kv
 }
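
The guard above means the converter only emits gemma3.mm.tokens_per_image when the HF config actually sets mm_tokens_per_image; the runtime (last file in this commit) falls back to 256 when the key is missing. A minimal sketch of that round trip, with a hypothetical value standing in for the parsed config:

package main

import "fmt"

func main() {
	kv := map[string]uint32{}

	// Hypothetical value parsed from the HF config's mm_tokens_per_image
	// field; zero would mean the field was absent from config.json.
	var multiModalTokensPerImage uint32 = 256

	// Converter side: only emit the KV when the config sets it.
	if multiModalTokensPerImage > 0 {
		kv["gemma3.mm.tokens_per_image"] = multiModalTokensPerImage
	}

	// Runtime side: read it back with the same default model.New uses.
	tokens, ok := kv["gemma3.mm.tokens_per_image"]
	if !ok {
		tokens = 256
	}
	fmt.Println(tokens, "tokens per image")
}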

View File

@@ -135,7 +135,7 @@ type Tensor interface {
 	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
 	Scale(ctx Context, s float64) Tensor

-	AvgPool1D(ctx Context, k, s, p int) Tensor
+	AvgPool2D(ctx Context, k, s int, p float32) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

 	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor

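AvgPool2D replaces AvgPool1D with a single kernel edge k and stride s that the ggml backend applies to both spatial dims, so only square pooling is expressible through this interface, and padding becomes a float32 because ggml's 2-D pooling takes fractional padding. For intuition, a sketch of the usual pooling size arithmetic along one axis (an assumption about ggml's sizing, not code from this commit):

package main

import "fmt"

// pooledExtent mirrors the standard pooling size formula,
// (in + 2p - k)/s + 1, along a single axis.
func pooledExtent(in, k, s int, p float32) int {
	return (in+int(2*p)-k)/s + 1
}

func main() {
	// With k == s and p == 0 the kernel tiles the axis exactly:
	// a 64-wide axis pooled with k = s = 4 yields 16.
	fmt.Println(pooledExtent(64, 4, 4, 0))
}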
View File

@@ -247,7 +247,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 			createTensor(tensor{source: t}, output.bts)
 		case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
 			// TODO: assign vision tensors to the gpu if possible
-			createTensor(tensor{source: t}, input.bts)
+			createTensor(tensor{source: t}, output.bts)
 		case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
 			// these tensors should be repeated per layer
 			for i, layer := range layers {
@@ -952,10 +952,10 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
 		}
 	}
 }

-func (t *Tensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
+func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
 	return &Tensor{
 		b: t.b,
-		t: C.ggml_pool_1d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(s), C.int(p)),
+		t: C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)),
 	}
 }
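
Note that k, s, and p are each passed twice to C.ggml_pool_2d, once per spatial axis, so the op always pools square windows here. As a plain-Go reference for what GGML_OP_POOL_AVG computes in the k == s, p == 0 case this commit relies on (a sketch for intuition, not the backend's actual kernel):

package main

import "fmt"

// avgPool2D averages non-overlapping k x k blocks of a 2-D grid,
// i.e. square average pooling with stride equal to the kernel edge
// and no padding.
func avgPool2D(in [][]float32, k int) [][]float32 {
	out := make([][]float32, len(in)/k)
	for y := range out {
		out[y] = make([]float32, len(in[0])/k)
		for x := range out[y] {
			var sum float32
			for dy := 0; dy < k; dy++ {
				for dx := 0; dx < k; dx++ {
					sum += in[y*k+dy][x*k+dx]
				}
			}
			out[y][x] = sum / float32(k*k)
		}
	}
	return out
}

func main() {
	in := [][]float32{
		{1, 2, 3, 4},
		{5, 6, 7, 8},
		{9, 10, 11, 12},
		{13, 14, 15, 16},
	}
	fmt.Println(avgPool2D(in, 2)) // [[3.5 5.5] [11.5 13.5]]
}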

View File

@@ -5,6 +5,7 @@ import (
 	"encoding/binary"
 	"hash/fnv"
 	"image"
+	"math"

 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
@@ -30,9 +31,21 @@ var _ model.MultimodalProcessor = (*Model)(nil)
 type MultiModalProjector struct {
 	SoftEmbNorm     *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
 	InputProjection *nn.Linear  `gguf:"mm_input_projection"`
+
+	tokensPerImage int
 }

-func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
+func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
+	l := visionOutputs.Dim(0)
+
+	visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+	patchesPerImage := imageSize / patchSize
+	visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
+
+	kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
+	visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
+	visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
+	visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+
 	visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)

 	// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
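
The rewritten Forward makes the pooling geometry explicit: the flat patch sequence is reshaped into a square patchesPerImage x patchesPerImage grid, average-pooled so that sqrt(tokensPerImage) patches remain per side, then flattened back. A worked shape check; patch_size 14 and the 256-token default come from this commit, while the 896-pixel image size is an assumption about Gemma 3's vision tower:

package main

import (
	"fmt"
	"math"
)

func main() {
	imageSize := 896      // assumed Gemma 3 vision input edge, not from this diff
	patchSize := 14       // patch_size from the converter hunk above
	tokensPerImage := 256 // default used by model.New below

	patchesPerImage := imageSize / patchSize                 // 64: vision output is a 64x64 grid
	tokensPerSide := int(math.Sqrt(float64(tokensPerImage))) // 16
	kernelSize := patchesPerImage / tokensPerSide            // 4

	side := patchesPerImage / kernelSize               // AvgPool2D with k == s, p == 0
	fmt.Println(side*side, "pooled tokens per image")  // 256
}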
@@ -59,6 +72,9 @@ func New(c ml.Config) (model.Model, error) {
 		ImageProcessor: newImageProcessor(c),
 		VisionModel:    newVisionModel(c),
 		TextModel:      newTextModel(c),
+		MultiModalProjector: &MultiModalProjector{
+			tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
+		},
 	}

 	slidingWindowLen := int32(c.Uint("attention.sliding_window"))
@@ -88,17 +104,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
 	}

 	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
-	visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-
-	patchesPerImage := m.ImageProcessor.imageSize / m.ImageProcessor.patchSize
-
-	// TODO (jmorganca): read this from the model config
-	// it should instead be math.Sqrt(tokens per image)
-	tokensPerSide := 8
-	kernelSize := patchesPerImage / tokensPerSide
-	visionOutputs = visionOutputs.AvgPool1D(ctx, kernelSize, kernelSize, 0)
-	visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps)
+	visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)

 	return visionOutputs, nil
 }
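
With the projector owning the pooling, EncodeMultimodal reduces to a single Forward call, and the hard-coded tokensPerSide := 8 (and its TODO) disappears in favor of the model's mm_tokens_per_image. One implicit invariant: the pooled count only lands exactly on the configured value when sqrt(tokensPerImage) is an integer that divides the patch grid. A hypothetical check, not code from the commit:

package main

import (
	"fmt"
	"math"
)

// validate mirrors the invariant the new Forward relies on: pooling a
// patchesPerImage x patchesPerImage grid down to exactly tokensPerImage
// tokens requires a perfect-square token count whose root divides the grid.
func validate(patchesPerImage, tokensPerImage int) error {
	side := int(math.Sqrt(float64(tokensPerImage)))
	if side*side != tokensPerImage {
		return fmt.Errorf("tokens per image %d is not a perfect square", tokensPerImage)
	}
	if patchesPerImage%side != 0 {
		return fmt.Errorf("patch grid %d is not divisible by %d", patchesPerImage, side)
	}
	return nil
}

func main() {
	fmt.Println(validate(64, 256)) // <nil>: kernel 64/16 = 4 tiles cleanly
}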