diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 7418bb12f..9aaa974ab 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -90,7 +90,11 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er visionOutputs := m.VisionModel.Forward(ctx, pixelValues) visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) patchesPerImage := m.ImageProcessor.imageSize / m.ImageProcessor.patchSize - kernelSize := patchesPerImage * patchesPerImage / 256 + + // TODO (jmorganca): read this from the model config + // it should instead be math.Sqrt(tokens per image) + tokensPerSide := 8 + kernelSize := patchesPerImage / tokensPerSide visionOutputs = visionOutputs.AvgPool1D(ctx, kernelSize, kernelSize, 0) visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)