Compare commits

...

57 Commits

Author SHA1 Message Date
jmorganca
5c76074f66 wip 2025-05-12 19:15:42 -07:00
Bruce MacDonald
18d52686de
Update model/models/qwen25vl/model_vision.go 2025-05-12 14:16:46 -07:00
Bruce MacDonald
2d2eb5903d use with pattern for rope 2025-05-12 14:14:03 -07:00
Bruce MacDonald
533f4c41bd add eot 2025-05-12 14:03:37 -07:00
Bruce MacDonald
31b2c06393 Update 0007-add-unpad-operator.patch 2025-05-12 13:51:46 -07:00
Bruce MacDonald
4ae23deb50 Revert "Update 0007-add-unpad-operator.patch"
This reverts commit 340359fd087dd93c99bed4b9c87ccea3e759e184.

Update 0007-add-unpad-operator.patch
2025-05-12 13:51:04 -07:00
Bruce MacDonald
5d3da85a16 remove out of date comments 2025-05-12 13:51:04 -07:00
Bruce MacDonald
8b64b456c1 Update 0007-add-unpad-operator.patch 2025-05-12 13:51:04 -07:00
Bruce MacDonald
684f0d9291 set default values for vision model in config 2025-05-12 13:51:04 -07:00
jmorganca
3308bff137 add i32 copy and argsort for cuda 2025-05-12 13:51:04 -07:00
Bruce MacDonald
bf1929a3bc Delete 0017-add-ollama-vocab-for-grammar-support.patch 2025-05-12 13:49:44 -07:00
Bruce MacDonald
1a2c413225 move mask 2025-05-12 13:49:44 -07:00
Bruce MacDonald
57279f89a2 calculate block mask once, rather than in attention 2025-05-12 13:49:44 -07:00
Bruce MacDonald
9ceee25d8b chunk vision outputs 2025-05-12 13:49:44 -07:00
Bruce MacDonald
661bf04696 add picture prefix 2025-05-12 13:49:44 -07:00
Bruce MacDonald
2521a55ae6 fixes after rebase 2025-05-12 13:49:44 -07:00
Bruce MacDonald
32948ec952 increase rope base 2025-05-12 13:49:43 -07:00
Bruce MacDonald
9876c8453a update exported functions for tests 2025-05-12 13:49:43 -07:00
Bruce MacDonald
919b3d6e21 require new engine for qwen25vl arch 2025-05-12 13:49:43 -07:00
Bruce MacDonald
16b13e0cfc Revert "ropeTheta should be 1e5"
This reverts commit cc1638b26763eae7daddd44e3975a885671ef9d3.

This reverts commit
b32385591307e2d33a8f43ce1626b529d2dac83e.
2025-05-12 13:49:43 -07:00
Bruce MacDonald
75441c56f3 add comment explaining rope theta 2025-05-12 13:49:43 -07:00
Bruce MacDonald
45f96e898d ropeTheta should be 1e5 2025-05-12 13:49:43 -07:00
Bruce MacDonald
7c555d394c simplify patch creation 2025-05-12 13:49:43 -07:00
Bruce MacDonald
39ee6d2bd0 ranges for lint 2025-05-12 13:49:43 -07:00
Bruce MacDonald
47705b5168 simplify rope changes 2025-05-12 13:49:43 -07:00
Michael Yang
698a92aa4a reverse window 2025-05-12 13:49:43 -07:00
Michael Yang
150c499cae use silu 2025-05-12 13:49:43 -07:00
Michael Yang
f1257a7de4 update vision rope theta default 2025-05-12 13:49:43 -07:00
Bruce MacDonald
b68af0370f move sdpa to model forward pass 2025-05-12 13:49:43 -07:00
Bruce MacDonald
ca981c8a49 full attn block indexes should be []int32 2025-05-12 13:49:43 -07:00
Bruce MacDonald
b3da8a319e Update model_vision.go 2025-05-12 13:49:42 -07:00
Bruce MacDonald
359e1d5b19 full attention layers 2025-05-12 13:49:42 -07:00
Michael Yang
bde6b46ce9 fix padding
padding was being added to offset but not to the running count
2025-05-12 13:49:42 -07:00
Bruce MacDonald
ff1f74534b block attention 2025-05-12 13:49:42 -07:00
Bruce MacDonald
104f802df1 remove todos 2025-05-12 13:49:42 -07:00
Bruce MacDonald
eed0ac2948 clean up vision model forward pass 2025-05-12 13:49:42 -07:00
Bruce MacDonald
fcfad744ff fix patch merger 2025-05-12 13:49:42 -07:00
Michael Yang
fb3c16f2a2 window index 2025-05-12 13:49:42 -07:00
Michael Yang
ee869f35e4 fix image processing
python built-in `round()` rounds to the nearest even number if the value
is in the middle

https://docs.python.org/3/library/functions.html#round
2025-05-12 13:49:42 -07:00
Michael Yang
ff5d1a3dc0 duplicate input embeddings 2025-05-12 13:49:42 -07:00
Michael Yang
88b231f903 use maxgridsize 2025-05-12 13:49:42 -07:00
Michael Yang
7e920c8d75 fix: patch merger and convert
convert:
- split patch embedding
- split qkv

remove duplicate PatchMerger
2025-05-12 13:49:42 -07:00
Bruce MacDonald
dd8c619fba fixes after rebase 2025-05-12 13:49:42 -07:00
Bruce MacDonald
2af76d0e7a default to 32 for vision block count 2025-05-12 13:49:42 -07:00
Bruce MacDonald
8d901825f0 reshape cos and sin 2025-05-12 13:49:41 -07:00
Bruce MacDonald
04936b719f Update model_vision.go 2025-05-12 13:49:41 -07:00
Bruce MacDonald
0f0136d419 simplify by doing operations in Go rather than with tensors
Co-Authored-By: Michael Yang <2372640+mxyng@users.noreply.github.com>
2025-05-12 13:49:41 -07:00
Bruce MacDonald
80498f76de fix build 2025-05-12 13:49:41 -07:00
Bruce MacDonald
f8b48aa784 Delete model_external_test.go 2025-05-12 13:49:41 -07:00
Bruce MacDonald
5ff0d538b0 wip: implementing rope 2025-05-12 13:49:41 -07:00
Bruce MacDonald
eedc969c35 grid refactor 2025-05-12 13:49:41 -07:00
Bruce MacDonald
963531215e update convert 2025-05-12 13:49:41 -07:00
Bruce MacDonald
3fe090f447 get patch embedding vals from config 2025-05-12 13:49:41 -07:00
Bruce MacDonald
1704072746 patch embeddings 2025-05-12 13:49:41 -07:00
Bruce MacDonald
c1f9bcb4dd restructure
image processing

Update model.go

Update model.go

Update model.go

no projector

no projector

vision model scaffold

...

...

wip

...

rebase

fix patch merger

tidy

...

Update model_vision.go

server: do not attempt to parse offset file as gguf

This logic was causing issues for me when importing a gguf that had some padding at the end of the file. The valid gguf would be read, but then it would try to read the offset as a different gguf file. This does not seem right.

Update process_image_test.go

apply norm

prompt processing

prompt processing

fix post tokenize

fix gguf padding + populate the split patch embeddings

...

...

another shot at patch embeddings

...

patch embedding

Update model_vision.go

split pixels
2025-05-12 13:49:41 -07:00
Bruce MacDonald
198b1e6db9 text model forward pass 2025-05-12 13:49:41 -07:00
Bruce MacDonald
51ad65f831 ml: structured rope config to allow specifying context len
This commit refactors the Rotary Position Embedding (RoPE) implementation across the codebase to use a structured configuration approach instead of individual parameters.

Key changes:
- Add new RoPEConfig struct with fields for dimension, type, base frequency, and scaling
- Add RopeType enum to formalize different RoPE implementation variants
- Add YarnConfig struct and related configuration for YaRN (Yet Another RoPE extensioN) context extension
- Update RoPE method signature across all tensor interfaces and implementations
- Refactor all model implementations (llama, gemma2, gemma3, mllama) to use the new configuration structure

This change improves code organization, makes the RoPE configuration more explicit, and provides better support for different RoPE variants and context extension methods.
2025-05-12 13:49:41 -07:00
17 changed files with 1714 additions and 10 deletions

View File

@ -189,6 +189,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &phi3Model{}
case "Qwen2ForCausalLM":
conv = &qwen2Model{}
case "Qwen2_5_VLForConditionalGeneration":
conv = &qwen25VLModel{}
case "BertModel":
conv = &bertModel{}
case "CohereForCausalLM":

View File

@ -15,6 +15,7 @@ type qwen2Model struct {
Type string `json:"type"`
Factor ropeFactor `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
MropeSection []int32 `json:"mrope_section"`
} `json:"rope_scaling"`
RMSNormEPS float32 `json:"rms_norm_eps"`
}
@ -39,6 +40,8 @@ func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
case "yarn":
kv["qwen2.rope.scaling.type"] = q.RopeScaling.Type
kv["qwen2.rope.scaling.factor"] = q.RopeScaling.Factor
case "mrope", "default":
kv["qwen2.rope.mrope_section"] = q.RopeScaling.MropeSection
default:
panic("unknown rope scaling type")
}

102
convert/convert_qwen25vl.go Normal file
View File

@ -0,0 +1,102 @@
package convert
import (
"cmp"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type qwen25VLModel struct {
qwen2Model
VisionModel struct {
Depth uint32 `json:"depth"`
HiddenSize uint32 `json:"hidden_size"`
NumHeads uint32 `json:"num_heads"`
InChannels uint32 `json:"in_chans"`
PatchSize uint32 `json:"patch_size"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
SpatialPatchSize uint32 `json:"spatial_patch_size"`
WindowSize uint32 `json:"window_size"`
RMSNormEps float32 `json:"layer_norm_epsilon"`
RopeTheta float32 `json:"rope_theta"`
FullAttentionBlocks []int32 `json:"fullatt_block_indexes"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
} `json:"vision_config"`
}
var _ ModelConverter = (*qwen25VLModel)(nil)
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"
for k, v := range q.qwen2Model.KV(t) {
if strings.HasPrefix(k, "qwen2.") {
kv[strings.Replace(k, "qwen2.", "qwen25vl.", 1)] = v
}
}
if q.VisionModel.FullAttentionBlocks == nil {
kv["qwen25vl.vision.fullatt_block_indexes"] = []int32{7, 15, 23, 31}
}
kv["qwen25vl.vision.block_count"] = cmp.Or(q.VisionModel.Depth, 32)
kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize
kv["qwen25vl.vision.attention.head_count"] = cmp.Or(q.VisionModel.NumHeads, 16)
kv["qwen25vl.vision.num_channels"] = q.VisionModel.InChannels
kv["qwen25vl.vision.patch_size"] = cmp.Or(q.VisionModel.PatchSize, 14)
kv["qwen25vl.vision.spatial_merge_size"] = cmp.Or(q.VisionModel.SpatialMergeSize, 2)
kv["qwen25vl.vision.spatial_patch_size"] = q.VisionModel.SpatialPatchSize
kv["qwen25vl.vision.window_size"] = cmp.Or(q.VisionModel.WindowSize, 112)
kv["qwen25vl.vision.attention.layer_norm_epsilon"] = cmp.Or(q.VisionModel.RMSNormEps, 1e-6)
kv["qwen25vl.vision.rope.freq_base"] = cmp.Or(q.VisionModel.RopeTheta, 1e4)
kv["qwen25vl.vision.fullatt_block_indexes"] = q.VisionModel.FullAttentionBlocks
kv["qwen25vl.vision.temporal_patch_size"] = cmp.Or(q.VisionModel.TemporalPatchSize, 2)
return kv
}
func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
if strings.Contains(t.Name(), "patch_embed.proj") {
for t := range splitDim(t, 2,
strings.NewReplacer("patch_embed.proj", "patch_embd_0"),
strings.NewReplacer("patch_embed.proj", "patch_embd_1"),
) {
t.Shape = slices.DeleteFunc(t.Shape, func(i uint64) bool { return i == 1 })
out = append(out, t)
}
} else if strings.Contains(t.Name(), "attn.qkv") {
out = append(out, slices.Collect(splitDim(t, 0,
strings.NewReplacer("attn.qkv", "attn_q"),
strings.NewReplacer("attn.qkv", "attn_k"),
strings.NewReplacer("attn.qkv", "attn_v"),
))...)
} else {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return out
}
func (p *qwen25VLModel) Replacements() []string {
return append(
p.qwen2Model.Replacements(),
"visual", "v",
"blocks", "blk",
"attn.proj", "attn_out",
"norm1", "ln1",
"norm2", "ln2",
)
}

56
convert/tensor.go Normal file
View File

@ -0,0 +1,56 @@
package convert
import (
"iter"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
)
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
// is split evenly based on the number of replacers provided.
func splitDim(t Tensor, dim int, replacers ...*strings.Replacer) iter.Seq[*ggml.Tensor] {
return func(yield func(*ggml.Tensor) bool) {
for i, replacer := range replacers {
shape := slices.Clone(t.Shape())
shape[dim] = shape[dim] / uint64(len(replacers))
slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
slice[dim] = tensor.S(i*int(shape[dim]), (i+1)*int(shape[dim]))
tt := t.Clone()
tt.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err := t.Slice(slice...)
if err != nil {
return nil, err
}
t = tensor.Materialize(t)
// flatten tensor so it can be written as a vector
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
})
if !yield(&ggml.Tensor{
Name: replacer.Replace(t.Name()),
Kind: t.Kind(),
Shape: shape,
WriterTo: tt,
}) {
break
}
}
}
}

View File

@ -125,6 +125,7 @@ func (kv KV) OllamaEngineRequired() bool {
"gemma3",
"mistral3",
"llama4",
"qwen25vl",
}, kv.Architecture())
}

View File

@ -0,0 +1,277 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <git@mxy.ng>
Date: Thu, 1 May 2025 13:45:12 -0700
Subject: [PATCH] add argsort and cuda copy for i32
---
ggml/src/ggml-cpu/ops.cpp | 43 ++++++++++++++
ggml/src/ggml-cuda/argsort.cu | 102 +++++++++++++++++++++++++++++++++-
ggml/src/ggml-cuda/cpy.cu | 49 ++++++++++++++++
3 files changed, 192 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index becdae07..7a44b6cf 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -6890,6 +6890,45 @@ static void ggml_compute_forward_argsort_f32(
}
}
+static void ggml_compute_forward_argsort_i32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT(nb0 == sizeof(int32_t));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t nr = ggml_nrows(src0);
+
+ ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
+
+ for (int64_t i = ith; i < nr; i += nth) {
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+ const int32_t * src_data = (int32_t *)((char *) src0->data + i*nb01);
+
+ for (int64_t j = 0; j < ne0; j++) {
+ dst_data[j] = j;
+ }
+
+ // C doesn't have a functional sort, so we do a bubble sort instead
+ for (int64_t j = 0; j < ne0; j++) {
+ for (int64_t k = j + 1; k < ne0; k++) {
+ if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
+ (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
+ int32_t tmp = dst_data[j];
+ dst_data[j] = dst_data[k];
+ dst_data[k] = tmp;
+ }
+ }
+ }
+ }
+}
+
void ggml_compute_forward_argsort(
const ggml_compute_params * params,
ggml_tensor * dst) {
@@ -6901,6 +6940,10 @@ void ggml_compute_forward_argsort(
{
ggml_compute_forward_argsort_f32(params, dst);
} break;
+ case GGML_TYPE_I32:
+ {
+ ggml_compute_forward_argsort_i32(params, dst);
+ } break;
default:
{
GGML_ABORT("fatal error");
diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu
index 607ded85..53b02634 100644
--- a/ggml/src/ggml-cuda/argsort.cu
+++ b/ggml/src/ggml-cuda/argsort.cu
@@ -85,13 +85,107 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
}
}
+
+template<ggml_sort_order order>
+static __global__ void k_argsort_i32_i32(const int32_t * x, int * dst, const int ncols, const int ncols_pad) {
+ extern __shared__ int shared_mem[];
+ int * indices = shared_mem;
+
+ const int tid = threadIdx.x;
+ const int row = blockIdx.y;
+
+ // Initialize all indices, handling the case where threads < ncols_pad
+ for (int i = tid; i < ncols_pad; i += blockDim.x) {
+ indices[i] = i < ncols ? i : 0; // Use 0 for padding indices
+ }
+ __syncthreads();
+
+ // Bitonic sort
+ for (int k = 2; k <= ncols_pad; k *= 2) {
+ for (int j = k/2; j > 0; j /= 2) {
+ for (int i = tid; i < ncols_pad; i += blockDim.x) {
+ const int ij = i ^ j;
+ if (ij > i) {
+ // Only compare values within the actual data range
+ if (i < ncols && ij < ncols) {
+ if ((i & k) == 0) {
+ if (order == GGML_SORT_ORDER_ASC) {
+ if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) {
+ int tmp = indices[i];
+ indices[i] = indices[ij];
+ indices[ij] = tmp;
+ }
+ } else {
+ if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) {
+ int tmp = indices[i];
+ indices[i] = indices[ij];
+ indices[ij] = tmp;
+ }
+ }
+ } else {
+ if (order == GGML_SORT_ORDER_ASC) {
+ if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) {
+ int tmp = indices[i];
+ indices[i] = indices[ij];
+ indices[ij] = tmp;
+ }
+ } else {
+ if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) {
+ int tmp = indices[i];
+ indices[i] = indices[ij];
+ indices[ij] = tmp;
+ }
+ }
+ }
+ }
+ }
+ }
+ __syncthreads();
+ }
+ }
+
+ // Write sorted indices to output, only threads handling valid data
+ for (int i = tid; i < ncols; i += blockDim.x) {
+ dst[row * ncols + i] = indices[i];
+ }
+}
+
+static void argsort_i32_i32_cuda(const int32_t * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+ // Bitonic sort requires ncols to be power of 2
+ const int ncols_pad = next_power_of_2(ncols);
+
+ // Ensure thread count doesn't exceed maximum (typically 1024)
+ const int max_threads = 1024; // This is the typical max for most GPUs
+ const int threads_per_block = ncols_pad > max_threads ? max_threads : ncols_pad;
+
+ const dim3 block_dims(threads_per_block, 1, 1);
+ const dim3 block_nums(1, nrows, 1);
+ const size_t shared_mem = ncols_pad * sizeof(int);
+
+ // Check if shared memory size is within limits
+ const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+ // Instead of logging an error, use GGML_ASSERT with a descriptive message
+ GGML_ASSERT(shared_mem <= max_shared_mem && "argsort: required shared memory exceeds device limit");
+
+ // Launch kernels with the updated thread configuration
+ if (order == GGML_SORT_ORDER_ASC) {
+ k_argsort_i32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+ } else if (order == GGML_SORT_ORDER_DESC) {
+ k_argsort_i32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+ } else {
+ GGML_ABORT("fatal error");
+ }
+}
+
+
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src0));
@@ -100,5 +194,9 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
- argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
+ if (src0->type == GGML_TYPE_I32) {
+ argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream);
+ } else {
+ argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
+ }
}
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 2d46176e..47383486 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -38,6 +38,13 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
*dsti = *xi;
}
+static __device__ void cpy_1_i32_i32(const char * cxi, char * cdsti) {
+ const int32_t * xi = (const int32_t *) cxi;
+ int32_t * dsti = (int32_t *) cdsti;
+
+ *dsti = *xi;
+}
+
template <cpy_kernel_t cpy_1>
static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -68,6 +75,44 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in
cpy_1(cx + x_offset, cdst + dst_offset);
}
+// First, add this template function after the other template functions
+template <cpy_kernel_t cpy_1>
+static __global__ void cpy_i32_i32(const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+ const int nb12, const int nb13) {
+ const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= ne) {
+ return;
+ }
+
+ const int64_t i03 = i/(ne00 * ne01 * ne02);
+ const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+ const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+ const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+ const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+ const int64_t i13 = i/(ne10 * ne11 * ne12);
+ const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+ const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+ const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+ const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
+
+ cpy_1(cx + x_offset, cdst + dst_offset);
+}
+
+// Then modify the ggml_cpy_i32_i32_cuda function to use the new template
+static void ggml_cpy_i32_i32_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int graph_cpynode_index) {
+
+ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+ cpy_i32_i32<cpy_1_i32_i32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q8_0 * dsti = (block_q8_0 *) cdsti;
@@ -631,6 +676,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
+ ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -686,6 +733,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
+ return (void*) cpy_i32_i32<cpy_1_i32_i32>;
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));

View File

@ -119,6 +119,25 @@ type Context interface {
Layer(int) Context
}
// RopeOpts contains optional parameters for RoPE function
type RopeOpts struct {
DefaultContextLen uint32
YarnExtFactor float32
YarnAttnFactor float32
YarnBetaFast float32
YarnBetaSlow float32
}
// RopeOption defines a function that modifies RopeOpts
type RopeOption func(*RopeOpts)
// WithContextLen sets a custom context length
func WithContextLen(len uint32) RopeOption {
return func(opts *RopeOpts) {
opts.DefaultContextLen = len
}
}
type Tensor interface {
Dim(n int) int
Stride(n int) int
@ -144,7 +163,8 @@ type Tensor interface {
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32, options ...RopeOption) Tensor
RoPEMulti(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, sections [4]int32, ropeType uint32, base, scale float32) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Sin(ctx Context) Tensor
@ -173,6 +193,7 @@ type Tensor interface {
Duplicate(ctx Context) Tensor
TopK(ctx Context, k int) Tensor
Argsort(ctx Context) Tensor
}
// ScaledDotProductAttention implements a fused attention

View File

@ -1071,7 +1071,21 @@ const (
ropeTypeVision C.int = 24
)
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32, options ...ml.RopeOption) ml.Tensor {
// Default options
opts := &ml.RopeOpts{
DefaultContextLen: 131072,
YarnExtFactor: 0.0,
YarnAttnFactor: 1.0,
YarnBetaFast: 32.0,
YarnBetaSlow: 1.0,
}
// Apply any provided options
for _, option := range options {
option(opts)
}
if ropeFactors == nil {
ropeFactors = &Tensor{b: t.b}
}
@ -1084,16 +1098,50 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
return &Tensor{
b: t.b,
t: C.ggml_rope_ext(
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
ctx.(*Context).ctx,
dequant,
positionIDs.(*Tensor).t,
ropeFactors.(*Tensor).t,
C.int(ropeDim),
C.int(ropeType),
131072, // YaRN n_ctx_train
C.int(128000),
C.float(ropeBase),
C.float(ropeScale),
0., // YaRN ext_factor
1., // YaRN attn_factor
32., // YaRN beta_fast
1., // YaRN beta_slow
C.float(opts.YarnExtFactor),
C.float(opts.YarnAttnFactor),
C.float(opts.YarnBetaFast),
C.float(opts.YarnBetaSlow),
),
}
}
func (t *Tensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, sections [4]int32, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
if ropeFactors == nil {
ropeFactors = &Tensor{b: t.b}
}
dequant := t.t
if C.ggml_is_quantized(t.t._type) {
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
}
return &Tensor{
b: t.b,
t: C.ggml_rope_multi(
ctx.(*Context).ctx,
dequant,
positionIDs.(*Tensor).t,
ropeFactors.(*Tensor).t,
C.int(ropeDim),
(*C.int)(&sections[0]),
C.int(ropeType),
C.int(128000), // Default context length
C.float(ropeBase),
C.float(ropeScale),
C.float(0.0), // ext_factor
C.float(1.0), // attn_factor
C.float(32.0), // beta_fast
C.float(1.0), // beta_slow
),
}
}
@ -1187,3 +1235,10 @@ func (t *Tensor) TopK(ctx ml.Context, k int) ml.Tensor {
t: C.ggml_top_k(ctx.(*Context).ctx, t.t, C.int(k)),
}
}
func (t *Tensor) Argsort(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_argsort(ctx.(*Context).ctx, t.t, C.GGML_SORT_ORDER_ASC),
}
}

View File

@ -6877,6 +6877,45 @@ static void ggml_compute_forward_argsort_f32(
}
}
static void ggml_compute_forward_argsort_i32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_TENSOR_UNARY_OP_LOCALS
GGML_ASSERT(nb0 == sizeof(int32_t));
const int ith = params->ith;
const int nth = params->nth;
const int64_t nr = ggml_nrows(src0);
ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
for (int64_t i = ith; i < nr; i += nth) {
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
const int32_t * src_data = (int32_t *)((char *) src0->data + i*nb01);
for (int64_t j = 0; j < ne0; j++) {
dst_data[j] = j;
}
// C doesn't have a functional sort, so we do a bubble sort instead
for (int64_t j = 0; j < ne0; j++) {
for (int64_t k = j + 1; k < ne0; k++) {
if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
(order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
int32_t tmp = dst_data[j];
dst_data[j] = dst_data[k];
dst_data[k] = tmp;
}
}
}
}
}
void ggml_compute_forward_argsort(
const ggml_compute_params * params,
ggml_tensor * dst) {
@ -6888,6 +6927,10 @@ void ggml_compute_forward_argsort(
{
ggml_compute_forward_argsort_f32(params, dst);
} break;
case GGML_TYPE_I32:
{
ggml_compute_forward_argsort_i32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");

View File

@ -85,13 +85,107 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
}
}
template<ggml_sort_order order>
static __global__ void k_argsort_i32_i32(const int32_t * x, int * dst, const int ncols, const int ncols_pad) {
extern __shared__ int shared_mem[];
int * indices = shared_mem;
const int tid = threadIdx.x;
const int row = blockIdx.y;
// Initialize all indices, handling the case where threads < ncols_pad
for (int i = tid; i < ncols_pad; i += blockDim.x) {
indices[i] = i < ncols ? i : 0; // Use 0 for padding indices
}
__syncthreads();
// Bitonic sort
for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k/2; j > 0; j /= 2) {
for (int i = tid; i < ncols_pad; i += blockDim.x) {
const int ij = i ^ j;
if (ij > i) {
// Only compare values within the actual data range
if (i < ncols && ij < ncols) {
if ((i & k) == 0) {
if (order == GGML_SORT_ORDER_ASC) {
if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) {
int tmp = indices[i];
indices[i] = indices[ij];
indices[ij] = tmp;
}
} else {
if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) {
int tmp = indices[i];
indices[i] = indices[ij];
indices[ij] = tmp;
}
}
} else {
if (order == GGML_SORT_ORDER_ASC) {
if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) {
int tmp = indices[i];
indices[i] = indices[ij];
indices[ij] = tmp;
}
} else {
if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) {
int tmp = indices[i];
indices[i] = indices[ij];
indices[ij] = tmp;
}
}
}
}
}
}
__syncthreads();
}
}
// Write sorted indices to output, only threads handling valid data
for (int i = tid; i < ncols; i += blockDim.x) {
dst[row * ncols + i] = indices[i];
}
}
static void argsort_i32_i32_cuda(const int32_t * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
// Bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);
// Ensure thread count doesn't exceed maximum (typically 1024)
const int max_threads = 1024; // This is the typical max for most GPUs
const int threads_per_block = ncols_pad > max_threads ? max_threads : ncols_pad;
const dim3 block_dims(threads_per_block, 1, 1);
const dim3 block_nums(1, nrows, 1);
const size_t shared_mem = ncols_pad * sizeof(int);
// Check if shared memory size is within limits
const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
// Instead of logging an error, use GGML_ASSERT with a descriptive message
GGML_ASSERT(shared_mem <= max_shared_mem && "argsort: required shared memory exceeds device limit");
// Launch kernels with the updated thread configuration
if (order == GGML_SORT_ORDER_ASC) {
k_argsort_i32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_i32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else {
GGML_ABORT("fatal error");
}
}
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src0));
@ -100,5 +194,9 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
if (src0->type == GGML_TYPE_I32) {
argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream);
} else {
argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
}
}

View File

@ -38,6 +38,13 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
*dsti = *xi;
}
static __device__ void cpy_1_i32_i32(const char * cxi, char * cdsti) {
const int32_t * xi = (const int32_t *) cxi;
int32_t * dsti = (int32_t *) cdsti;
*dsti = *xi;
}
template <cpy_kernel_t cpy_1>
static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@ -68,6 +75,44 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in
cpy_1(cx + x_offset, cdst + dst_offset);
}
// First, add this template function after the other template functions
template <cpy_kernel_t cpy_1>
static __global__ void cpy_i32_i32(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) {
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
}
const int64_t i03 = i/(ne00 * ne01 * ne02);
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int64_t i13 = i/(ne10 * ne11 * ne12);
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
cpy_1(cx + x_offset, cdst + dst_offset);
}
// Then modify the ggml_cpy_i32_i32_cuda function to use the new template
static void ggml_cpy_i32_i32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int graph_cpynode_index) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_i32_i32<cpy_1_i32_i32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q8_0 * dsti = (block_q8_0 *) cdsti;
@ -633,6 +678,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
@ -688,6 +735,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
return (void*) cpy_i32_i32<cpy_1_i32_i32>;
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));

View File

@ -7,4 +7,5 @@ import (
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama"
_ "github.com/ollama/ollama/model/models/qwen25vl"
)

View File

@ -0,0 +1,192 @@
package qwen25vl
import (
"bytes"
"fmt"
"image"
"slices"
"sync"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
*TextModel
*VisionModel `gguf:"v,vision"`
ImageProcessor
}
// Implement MultimodalProcessor interface
var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
TextModel: NewTextModel(c),
VisionModel: newVisionModel(c),
ImageProcessor: newImageProcessor(c),
}
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
return m, nil
}
func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *Grid, error) {
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, nil, err
}
f32s, grid, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, nil, err
}
// Calculate tensor dimensions
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
numPatches := grid.Temporal * grid.Height * grid.Width
pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
if err != nil {
return nil, nil, fmt.Errorf("failed to create tensor from image: %w", err)
}
return pixelValues, grid, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
pixels, grid, err := m.PixelValues(ctx, multimodalData)
if err != nil {
return nil, err
}
visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
return &chunks{Model: m, Tensor: visionOutputs, grid: grid}, nil
}
type chunks struct {
*Model
ml.Tensor
grid *Grid
dataOnce sync.Once
data []float32
}
type chunk struct {
*chunks
s, n int
}
func (r *chunk) floats() []float32 {
r.dataOnce.Do(func() {
temp := r.Backend().NewContext()
defer temp.Close()
temp.Forward(r.Tensor).Compute(r.Tensor)
r.data = r.Floats()
})
return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)]
}
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
var (
imageToken int32 = 151655
visionStartToken int32 = 151652
visionEndToken int32 = 151653
)
nImg := 0
for _, inp := range inputs {
if inp.Multimodal == nil {
// If not a multimodal input, add it to the result unchanged
result = append(result, inp)
} else {
// Adding the 'Picture' prefix is a hack, at the time of writing there is no way to prefix
// the image tokens with a prompt, so we add a prefix here
nImg++
pre, err := m.TextModel.Encode(fmt.Sprintf(" Picture %d: ", nImg), true)
if err != nil {
return nil, fmt.Errorf("failed to encode image prompt: %w", err)
}
for i := range pre {
result = append(result, input.Input{Token: pre[i]})
}
// This is an image token with multimodal data
chunksData := inp.Multimodal.(*chunks)
patchesPerChunk := chunksData.Dim(1)
// First add the vision start token
result = append(result, input.Input{Token: visionStartToken, SameBatch: patchesPerChunk + 2})
// Add the image token with the multimodal tensor data at the first position
// Create a chunk with proper s and n values
result = append(result, input.Input{
Token: imageToken,
Multimodal: &chunk{chunks: chunksData, s: 0, n: patchesPerChunk},
MultimodalHash: inp.MultimodalHash,
SameBatch: patchesPerChunk,
})
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
result = append(result, input.Input{Token: visionEndToken})
}
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
fmt.Println("Forward")
pos := make([]int32, len(batch.Positions)*4)
var grid = &Grid{}
if len(batch.Multimodal) > 0 {
image := batch.Multimodal[0].Multimodal
grid = image.(*chunk).chunks.grid
for y := 0; y < grid.Height/2; y++ {
for x := 0; x < grid.Width/2; x++ {
i := y*grid.Width/2 + x
pos[i] = batch.Positions[i]
pos[i+len(batch.Positions)] = batch.Positions[i] + int32(y)
pos[i+len(batch.Positions)*2] = batch.Positions[i] + int32(x)
pos[i+len(batch.Positions)*3] = 0
}
}
} else {
copy(pos[:len(batch.Positions)], batch.Positions)
copy(pos[len(batch.Positions):len(batch.Positions)*2], batch.Positions)
copy(pos[len(batch.Positions)*2:len(batch.Positions)*3], batch.Positions)
}
positions, err := ctx.Input().FromIntSlice(pos, len(pos))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
}
func init() {
model.Register("qwen25vl", New)
}

View File

@ -0,0 +1,176 @@
package qwen25vl
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type TextOptions struct {
ctxLen, hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim, defaultContextLen uint32
}
type TextModel struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
}
func NewTextModel(c fs.Config) *TextModel {
m := TextModel{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
Layers: make([]Layer, c.Uint("block_count")),
TextOptions: &TextOptions{
ctxLen: int(c.Uint("context_length")),
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count", 128),
defaultContextLen: c.Uint("context_length", 128000),
},
}
return &m
}
// SelfAttention implements the multi-head self-attention mechanism
// with separate projections for query, key, value and output transformations
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
sections := [4]int32{16, 24, 24, 0}
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPEMulti(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, sections, 8, opts.ropeBase, opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPEMulti(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, sections, 8, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, kqv)
}
// Shift applies rotary position embeddings to the key tensor for causal attention caching
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, nil, m.ropeDim, 2, m.ropeBase, m.ropeScale, ml.WithContextLen(m.defaultContextLen)), nil
}
// MLP implements the feed-forward network component with SwiGLU activation
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
// Apply SwiGLU activation gating
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
// Project back to hidden dimension
return mlp.Down.Forward(ctx, hiddenState)
}
// Layer represents a single transformer layer combining self-attention and feed-forward components
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
// Self-attention branch with residual connection
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
// Feed-forward branch with residual connection
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
// Initial token embedding
hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
for _, mi := range batch.Multimodal {
f32s := mi.Multimodal.(*chunk).floats()
img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize)
if err != nil {
panic(err)
}
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
}
// Process through transformer layers
for i, layer := range m.Layers {
cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, m.TextOptions)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}

View File

@ -0,0 +1,395 @@
package qwen25vl
import (
"fmt"
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// We only support batch size of 1
var batchSize int = 1
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
return x2.Neg(ctx).Concat(ctx, x1, 0)
}
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
}
func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor {
// Create a flat slice for the mask (all -inf initially to block all attention)
flat := make([]float32, seqLength*seqLength)
for i := range flat {
flat[i] = float32(math.Inf(-1)) // Negative infinity to block attention
}
// Fill in the mask with zeros for tokens that CAN attend to each other
for i := 1; i < len(bounds); i++ {
start := bounds[i-1]
end := bounds[i]
// Enable attention within this sequence block by setting values to 0
for row := start; row < end; row++ {
for col := start; col < end; col++ {
idx := row*seqLength + col
flat[idx] = 0.0 // 0 allows attention, -inf blocks it
}
}
}
mask, err := ctx.Input().FromFloatSlice(flat, seqLength, seqLength)
if err != nil {
panic(err)
}
// Reshape to match [seqLength, seqLength, 1] for broadcasting
mask = mask.Reshape(ctx, seqLength, seqLength, 1)
return mask
}
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_out"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
// Scale factor for scaled dot-product attention
scale := 1.0 / math.Sqrt(float64(opts.headDim))
// Scaled dot-product attention
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scale)
if mask != nil {
kq = kq.Add(ctx, mask)
}
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}
// VisionMLP implements the multi-layer perceptron
type VisionMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Using activation as specified in config (likely GELU or SiLU/Swish)
gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
upOutput := mlp.Up.Forward(ctx, hiddenStates)
hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput)
return mlp.Down.Forward(ctx, hiddenStates)
}
type VisionEncoderLayer struct {
Norm1 *nn.RMSNorm `gguf:"ln1"`
SelfAttention *VisionSelfAttention
Norm2 *nn.RMSNorm `gguf:"ln2"`
MLP *VisionMLP
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, mask, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = e.Norm2.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
// VisionModelOptions contains configuration options
type VisionModelOptions struct {
hiddenSize int
numHeads int
headDim int
patchSize int
numChannels int
eps float32
ropeTheta float32
spatialMergeSize int
windowSize int
fullAttnBlocks []int
temporalPatchSize int
}
type PatchEmbedding struct {
PatchConv0 *nn.Conv2D `gguf:"patch_embd_0"`
PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
}
func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, opts *VisionModelOptions) ml.Tensor {
numPatches := pixelValues.Shape()[1]
// Reshape the input tensor to match the expected dimensions
pixelValues = pixelValues.Reshape(ctx, opts.patchSize*opts.patchSize, opts.temporalPatchSize, opts.numChannels, numPatches)
// Permute the tensor to bring the temporal dimension to the front
pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Split the tensor into parts for the temporal convolutions
in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in0 = in0.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in1 = in1.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
s0, s1 := opts.patchSize, opts.patchSize // Use full stride
p0, p1 := 0, 0 // padding
d0, d1 := 1, 1 // dilation
out0 := pe.PatchConv0.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
out1 := pe.PatchConv1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
// Add the outputs from the two temporal convolutions
out := out0.Add(ctx, out1)
// Reshape the output tensor to match the expected dimensions
return out.Reshape(ctx, opts.hiddenSize, numPatches)
}
// VisionPatchMerger implements patch merging for the Qwen vision model
type VisionPatchMerger struct {
LNQ *nn.RMSNorm `gguf:"ln_q"`
MLP0 *nn.Linear `gguf:"mlp.0"`
MLP2 *nn.Linear `gguf:"mlp.2"`
}
// Forward computes patch merging for the vision model
func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
normalized := pm.LNQ.Forward(ctx, visionOutputs, opts.eps)
hiddenSize := visionOutputs.Dim(0) * (opts.spatialMergeSize * opts.spatialMergeSize)
// Reshape the normalized output to view the hidden size dimension
reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize), batchSize)
hidden := pm.MLP0.Forward(ctx, reshaped)
activated := hidden.GELU(ctx)
output := pm.MLP2.Forward(ctx, activated)
return output
}
// VisionModel implements the Qwen vision model
type VisionModel struct {
PatchEmbedding *PatchEmbedding
Layers []VisionEncoderLayer `gguf:"blk"`
PatchMerger *VisionPatchMerger `gguf:"merger"`
*VisionModelOptions
}
// Forward computes the vision model for an input tensor
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
// Extract patch embeddings
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)
positionEmbedding := m.PositionalEmbedding(ctx, grid)
windowIndex, bounds := m.WindowIndex(ctx, grid)
spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*spatialMergeUnit, hiddenStates.Dim(1)/spatialMergeUnit)
hiddenStates = hiddenStates.Rows(ctx, windowIndex)
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)/spatialMergeUnit, hiddenStates.Dim(1)*spatialMergeUnit)
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)*spatialMergeUnit, positionEmbedding.Dim(1)/spatialMergeUnit)
positionEmbedding = positionEmbedding.Rows(ctx, windowIndex)
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)/spatialMergeUnit, positionEmbedding.Dim(1)*spatialMergeUnit)
positionEmbedding = positionEmbedding.Concat(ctx, positionEmbedding, 0)
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads)
// Apply encoder layers
for i, layer := range m.Layers {
if slices.Contains(m.fullAttnBlocks, i) {
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, nil, m.VisionModelOptions)
} else {
hiddenStates = layer.Forward(
ctx,
hiddenStates,
cos,
sin,
mask,
m.VisionModelOptions,
)
}
}
hiddenStates = m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
reverseWindowIndex := windowIndex.Argsort(ctx)
return hiddenStates.Rows(ctx, reverseWindowIndex)
}
// WindowIndex divides the grid into windows and returns:
// 1. A tensor containing flattened indices of all grid points organized by windows
// 2. A slice of boundaries that mark where each window's data begins and ends
// in the flattened representation, scaled by spatialMergeSize squared
//
// The boundaries slice always starts with 0 and contains cumulative ending
// positions for each window, allowing downstream processing to identify
// window boundaries in the tensor data.
func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) {
vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
llmGridH := grid.Height / m.spatialMergeSize
llmGridW := grid.Width / m.spatialMergeSize
// Calculate window parameters
numWindowsH := int(math.Ceil(float64(llmGridH) / float64(vitMergerWindowSize)))
numWindowsW := int(math.Ceil(float64(llmGridW) / float64(vitMergerWindowSize)))
// Initialize index_new slice
var index []int32
// Initialize bounds with the first element as 0
bounds := []int{0}
totalSeqLen := 0
// Process each window without padding
for wh := range numWindowsH {
for ww := range numWindowsW {
// Calculate window boundaries
hStart := wh * vitMergerWindowSize
wStart := ww * vitMergerWindowSize
hEnd := min(hStart+vitMergerWindowSize, llmGridH)
wEnd := min(wStart+vitMergerWindowSize, llmGridW)
// Calculate sequence length for this window
seqLen := (hEnd - hStart) * (wEnd - wStart)
// Collect indices for this window
for h := hStart; h < hEnd; h++ {
for w := wStart; w < wEnd; w++ {
index = append(index, int32(h*llmGridW+w))
}
}
totalSeqLen += seqLen
bounds = append(bounds, totalSeqLen*(m.spatialMergeSize*m.spatialMergeSize)+bounds[0])
}
}
t, err := ctx.Input().FromIntSlice(index, len(index))
if err != nil {
panic(err)
}
return t, bounds
}
// PositionalEmbedding generates rotary position embeddings for attention mechanisms
func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
dim := m.headDim / 2
freq := dim / 2
theta := float64(m.ropeTheta)
merge := m.spatialMergeSize
// Create frequency patterns for position encoding
maxGridSize := max(grid.Height, grid.Width)
freqVals := make([]float32, freq*maxGridSize)
for i := range maxGridSize {
for j := range freq {
freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim)))
}
}
freqs, err := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize)
if err != nil {
panic(fmt.Errorf("failed to create tensor from frequencies: %w", err))
}
// Create position coordinates (y,x pairs) for the grid
// In PyTorch: Equivalent to generating position ids with torch.arange()
coords := make([]int32, 0, grid.Height*grid.Width*2)
for y := range grid.Height {
for x := range grid.Width {
coords = append(coords, int32(y), int32(x))
}
}
pos, err := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height)
if err != nil {
panic(fmt.Errorf("failed to create tensor from positions: %w", err))
}
// Reshape and permute positions to match spatial merging pattern
pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
pos = pos.Reshape(ctx, 2, merge, merge, grid.Width/merge*grid.Height/merge)
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
pos = pos.Reshape(ctx, 2*merge*merge*grid.Width/merge*grid.Height/merge)
// Use position indices to look up corresponding frequency values
positionalEmbedding := freqs.Rows(ctx, pos)
positionalEmbedding = positionalEmbedding.Reshape(ctx, positionalEmbedding.Dim(0)*2, positionalEmbedding.Dim(1)/2)
return positionalEmbedding
}
// newVisionModel creates a new instance of the Qwen vision model
func newVisionModel(c fs.Config) *VisionModel {
patchSize := int(c.Uint("vision.patch_size", 14))
hiddenSize := int(c.Uint("vision.embedding_length", 1280))
numHeads := int(c.Uint("vision.attention.head_count", 16))
numChannels := int(c.Uint("vision.num_channels", 3))
eps := c.Float("vision.attention.layer_norm_epsilon", 1e-6)
ropeTheta := c.Float("vision.rope.freq_base", 10000.0)
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
windowSize := int(c.Uint("vision.window_size", 112))
fullAttnBlocks := c.Ints("qwen25vl.vision.fullatt_block_indexes", []int32{7, 15, 23, 31})
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
model := &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)),
VisionModelOptions: &VisionModelOptions{
hiddenSize: hiddenSize,
numHeads: numHeads,
headDim: hiddenSize / numHeads,
patchSize: patchSize,
numChannels: numChannels,
eps: eps,
ropeTheta: ropeTheta,
spatialMergeSize: spatialMergeSize,
windowSize: windowSize,
temporalPatchSize: temporalPatchSize,
},
}
for i := range fullAttnBlocks {
// full attention block indexes have to be converted to int for use with the slices package
model.fullAttnBlocks = append(model.fullAttnBlocks, int(fullAttnBlocks[i]))
}
return model
}

View File

@ -0,0 +1,186 @@
package qwen25vl
import (
"fmt"
"image"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
// ImageProcessor contains configuration for the Qwen 2.5 VL image processing
type ImageProcessor struct {
imageSize int
numChannels int
patchSize int
temporalPatchSize int
mergeSize int
minPixels int
maxPixels int
factor int
rescaleFactor float32
imageMean []float32
imageStd []float32
}
// newImageProcessor creates a new image processor with default values
func newImageProcessor(c fs.Config) ImageProcessor {
patchSize := int(c.Uint("vision.patch_size", 14))
mergeSize := int(c.Uint("vision.spatial_merge_size", 2))
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size", 560)),
numChannels: int(c.Uint("vision.num_channels", 3)), // not set
patchSize: patchSize,
temporalPatchSize: 2,
mergeSize: mergeSize,
minPixels: 56 * 56,
maxPixels: 28 * 28 * 4 * 1280,
factor: patchSize * mergeSize,
rescaleFactor: 1.0 / 255.0,
imageMean: []float32{0.48145466, 0.4578275, 0.40821073},
imageStd: []float32{0.26862954, 0.26130258, 0.27577711},
}
}
// SmartResize implements the smart resize algorithm
func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
factor := p.factor
if height < factor || width < factor {
panic(fmt.Sprintf("height:%d or width:%d must be larger than factor:%d", height, width, factor))
} else if aspectRatio := max(height, width) / min(height, width); aspectRatio > 200 {
panic(fmt.Sprintf("absolute aspect ratio must be smaller than 200, got %v", aspectRatio))
}
round := func(x float64) int { return int(math.RoundToEven(x)) }
hBar := round(float64(height)/float64(factor)) * factor
wBar := round(float64(width)/float64(factor)) * factor
if hBar*wBar > p.maxPixels {
beta := math.Sqrt(float64(height*width) / float64(p.maxPixels))
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
} else if hBar*wBar < p.minPixels {
beta := math.Sqrt(float64(p.minPixels) / float64(height*width))
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
}
return hBar, wBar
}
type Grid struct {
Height int
Width int
Temporal int
}
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error) {
origWidth := img.Bounds().Dx()
origHeight := img.Bounds().Dy()
// Calculate smart resize dimensions
resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth)
// Resize image using existing functions
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
normalizedPixels := imageproc.Normalize(
resizedImg,
[3]float32{p.imageMean[0], p.imageMean[1], p.imageMean[2]},
[3]float32{p.imageStd[0], p.imageStd[1], p.imageStd[2]},
true, // rescale
true, // channelFirst
)
// Calculate grid dimensions
grid := &Grid{
Height: resizedHeight / p.patchSize,
Width: resizedWidth / p.patchSize,
Temporal: 1, // For single images, temporal dimension is 1
}
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
if err != nil {
return nil, nil, fmt.Errorf("failed to create patches: %v", err)
}
// Return patches and grid dimensions
return patches, grid, nil
}
func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
channels := p.numChannels
patchSize := p.patchSize
mergeSize := p.mergeSize
temporalPatchSize := p.temporalPatchSize
// Calculate output dimensions
numPatches := grid.Temporal * grid.Height * grid.Width
patchDim := channels * temporalPatchSize * patchSize * patchSize
result := make([]float32, numPatches*patchDim)
patchIndex := 0
// Single temporal frame handling (copies to all frames)
for range grid.Temporal {
for h := 0; h < grid.Height; h += mergeSize {
for w := 0; w < grid.Width; w += mergeSize {
// Handle the 2x2 merged patches
for mh := range mergeSize {
for mw := range mergeSize {
baseOffset := patchIndex * patchDim
// Extract patch data for first temporal frame
for c := range channels {
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
for py := range patchSize {
for px := range patchSize {
// Calculate source pixel coordinates
y := (h+mh)*patchSize + py
x := (w+mw)*patchSize + px
// Source index in input tensor (CHW format)
srcIdx := c*height*width + y*width + x
// Destination index in first temporal frame
dstIdx := channelOffset + (py * patchSize) + px
if srcIdx < len(pixels) && dstIdx < len(result) {
result[dstIdx] = pixels[srcIdx]
}
}
}
}
// Copy first temporal frame to all other frames
if temporalPatchSize > 1 {
for c := range channels {
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
firstFrameOffset := channelOffset
frameSize := patchSize * patchSize
// Copy first frame to all other frames
for tp := 1; tp < temporalPatchSize; tp++ {
currentFrameOffset := channelOffset + (tp * frameSize)
copy(result[currentFrameOffset:currentFrameOffset+frameSize],
result[firstFrameOffset:firstFrameOffset+frameSize])
}
}
}
patchIndex++
}
}
}
}
}
return result, nil
}

View File

@ -0,0 +1,47 @@
package qwen25vl
import (
"image"
_ "image/jpeg" // Register JPEG decoder
"testing"
)
func TestSmartResize(t *testing.T) {
type smartResizeCase struct {
TestImage image.Image
Expected image.Point
}
// Create an image processor with default values
processor := ImageProcessor{
imageSize: 560, // Example value
numChannels: 3,
factor: 28,
minPixels: 56 * 56,
maxPixels: 14 * 14 * 4 * 1280,
}
cases := []smartResizeCase{
{
TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 1024)),
Expected: image.Point{980, 980},
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 768)),
Expected: image.Point{1036, 756},
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 2000, 2000)),
Expected: image.Point{980, 980},
},
}
for _, c := range cases {
b := c.TestImage.Bounds().Max
x, y := processor.SmartResize(b.X, b.Y)
actual := image.Point{x, y}
if actual != c.Expected {
t.Errorf("expected: %v, actual: %v", c.Expected, actual)
}
}
}