diff --git a/convert/convert.go b/convert/convert.go index 48804d7f3..309b0ce19 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -191,6 +191,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error { conv = &phi3Model{} case "Qwen2ForCausalLM": conv = &qwen2Model{} + case "Qwen2_5_VLForConditionalGeneration": + conv = &qwen25VLModel{} case "BertModel": conv = &bertModel{} case "CohereForCausalLM": diff --git a/convert/convert_qwen2.go b/convert/convert_qwen2.go index edcb82e29..3647c4e54 100644 --- a/convert/convert_qwen2.go +++ b/convert/convert_qwen2.go @@ -15,6 +15,7 @@ type qwen2Model struct { Type string `json:"type"` Factor ropeFactor `json:"factor"` OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"` + MropeSection []int32 `json:"mrope_section"` } `json:"rope_scaling"` RMSNormEPS float32 `json:"rms_norm_eps"` } @@ -39,6 +40,8 @@ func (q *qwen2Model) KV(t *Tokenizer) ggml.KV { case "yarn": kv["qwen2.rope.scaling.type"] = q.RopeScaling.Type kv["qwen2.rope.scaling.factor"] = q.RopeScaling.Factor + case "mrope", "default": + kv["qwen2.rope.mrope_section"] = q.RopeScaling.MropeSection default: panic("unknown rope scaling type") } diff --git a/convert/convert_qwen25vl.go b/convert/convert_qwen25vl.go new file mode 100644 index 000000000..c2d5a633b --- /dev/null +++ b/convert/convert_qwen25vl.go @@ -0,0 +1,102 @@ +package convert + +import ( + "cmp" + "slices" + "strings" + + "github.com/ollama/ollama/fs/ggml" +) + +type qwen25VLModel struct { + qwen2Model + + VisionModel struct { + Depth uint32 `json:"depth"` + HiddenSize uint32 `json:"hidden_size"` + NumHeads uint32 `json:"num_heads"` + InChannels uint32 `json:"in_chans"` + PatchSize uint32 `json:"patch_size"` + SpatialMergeSize uint32 `json:"spatial_merge_size"` + SpatialPatchSize uint32 `json:"spatial_patch_size"` + WindowSize uint32 `json:"window_size"` + RMSNormEps float32 `json:"layer_norm_epsilon"` + RopeTheta float32 `json:"rope_theta"` + FullAttentionBlocks []int32 `json:"fullatt_block_indexes"` + TemporalPatchSize uint32 `json:"temporal_patch_size"` + } `json:"vision_config"` +} + +var _ ModelConverter = (*qwen25VLModel)(nil) + +func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV { + kv := q.ModelParameters.KV(t) + kv["general.architecture"] = "qwen25vl" + + for k, v := range q.qwen2Model.KV(t) { + if strings.HasPrefix(k, "qwen2.") { + kv[strings.Replace(k, "qwen2.", "qwen25vl.", 1)] = v + } + } + + if q.VisionModel.FullAttentionBlocks == nil { + kv["qwen25vl.vision.fullatt_block_indexes"] = []int32{7, 15, 23, 31} + } + + kv["qwen25vl.vision.block_count"] = cmp.Or(q.VisionModel.Depth, 32) + kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize + kv["qwen25vl.vision.attention.head_count"] = cmp.Or(q.VisionModel.NumHeads, 16) + kv["qwen25vl.vision.num_channels"] = q.VisionModel.InChannels + kv["qwen25vl.vision.patch_size"] = cmp.Or(q.VisionModel.PatchSize, 14) + kv["qwen25vl.vision.spatial_merge_size"] = cmp.Or(q.VisionModel.SpatialMergeSize, 2) + kv["qwen25vl.vision.spatial_patch_size"] = q.VisionModel.SpatialPatchSize + kv["qwen25vl.vision.window_size"] = cmp.Or(q.VisionModel.WindowSize, 112) + kv["qwen25vl.vision.attention.layer_norm_epsilon"] = cmp.Or(q.VisionModel.RMSNormEps, 1e-6) + kv["qwen25vl.vision.rope.freq_base"] = cmp.Or(q.VisionModel.RopeTheta, 1e4) + kv["qwen25vl.vision.fullatt_block_indexes"] = q.VisionModel.FullAttentionBlocks + kv["qwen25vl.vision.temporal_patch_size"] = cmp.Or(q.VisionModel.TemporalPatchSize, 2) + + return kv +} + +func (q *qwen25VLModel) Tensors(ts 
[]Tensor) []*ggml.Tensor { + var out []*ggml.Tensor + + for _, t := range ts { + if strings.Contains(t.Name(), "patch_embed.proj") { + for t := range splitDim(t, 2, + strings.NewReplacer("patch_embed.proj", "patch_embd_0"), + strings.NewReplacer("patch_embed.proj", "patch_embd_1"), + ) { + t.Shape = slices.DeleteFunc(t.Shape, func(i uint64) bool { return i == 1 }) + out = append(out, t) + } + } else if strings.Contains(t.Name(), "attn.qkv") { + out = append(out, slices.Collect(splitDim(t, 0, + strings.NewReplacer("attn.qkv", "attn_q"), + strings.NewReplacer("attn.qkv", "attn_k"), + strings.NewReplacer("attn.qkv", "attn_v"), + ))...) + } else { + out = append(out, &ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: t.Shape(), + WriterTo: t, + }) + } + } + + return out +} + +func (p *qwen25VLModel) Replacements() []string { + return append( + p.qwen2Model.Replacements(), + "visual", "v", + "blocks", "blk", + "attn.proj", "attn_out", + "norm1", "ln1", + "norm2", "ln2", + ) +} diff --git a/convert/tensor.go b/convert/tensor.go new file mode 100644 index 000000000..ffb22ead9 --- /dev/null +++ b/convert/tensor.go @@ -0,0 +1,56 @@ +package convert + +import ( + "iter" + "slices" + "strings" + + "github.com/ollama/ollama/fs/ggml" + "github.com/pdevine/tensor" + "github.com/pdevine/tensor/native" +) + +// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension +// is split evenly based on the number of replacers provided. +func splitDim(t Tensor, dim int, replacers ...*strings.Replacer) iter.Seq[*ggml.Tensor] { + return func(yield func(*ggml.Tensor) bool) { + for i, replacer := range replacers { + shape := slices.Clone(t.Shape()) + shape[dim] = shape[dim] / uint64(len(replacers)) + + slice := slices.Repeat([]tensor.Slice{nil}, len(shape)) + slice[dim] = tensor.S(i*int(shape[dim]), (i+1)*int(shape[dim])) + + tt := t.Clone() + tt.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) { + dims := make([]int, len(shape)) + for i := range shape { + dims[i] = int(shape[i]) + } + + var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data)) + t, err := t.Slice(slice...) 
+ if err != nil { + return nil, err + } + + t = tensor.Materialize(t) + // flatten tensor so it can be written as a vector + if err := t.Reshape(t.Shape().TotalSize()); err != nil { + return nil, err + } + + return native.VectorF32(t.(*tensor.Dense)) + }) + + if !yield(&ggml.Tensor{ + Name: replacer.Replace(t.Name()), + Kind: t.Kind(), + Shape: shape, + WriterTo: tt, + }) { + break + } + } + } +} diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index c29d715bd..514b60115 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "log/slog" + "math" "slices" "strings" @@ -126,6 +127,7 @@ func (kv KV) OllamaEngineRequired() bool { "mistral3", "llama4", "mllama", + "qwen25vl", }, kv.Architecture()) } @@ -649,6 +651,29 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) { graphSize = 4 * (imageSize*imageSize*numChannels + embeddingLength*patchSize + numPatches*numPatches*headCount) + case "qwen25vl": + maxPixels := uint64(llm.KV().Uint("vision.max_pixels", 28*28*1280)) + mergeSize := uint64(llm.KV().Uint("vision.spatial_merge_size", 2)) + temporalPatchSize := uint64(2) + + // Calculate max possible patches based on max_pixels + maxHeight := uint64(math.Sqrt(float64(maxPixels))) + maxWidth := maxPixels / maxHeight + maxGridHeight := maxHeight / patchSize + maxGridWidth := maxWidth / patchSize + // Account for merged patches (2x2 grid) + numPatches := (maxGridHeight * maxGridWidth) / (mergeSize * mergeSize) + + // Calculate graph size based on typical operations in ProcessImage and createPatches + graphSize = 4 * (maxPixels*numChannels + // Original image storage + // Normalized pixels + maxPixels*numChannels + + // Patches storage (numPatches * channels * temporalPatchSize * patchSize^2) + numPatches*numChannels*temporalPatchSize*patchSize*patchSize + + // Self-attention calculations (similar to other architectures) + numPatches*numPatches*headCount + + // Additional buffer for processing + embeddingLength*numPatches) case "llama4": // vision graph is computed independently in the same schedule // and is negligible compared to the worst case text graph diff --git a/llama/patches/0015-add-argsort-and-cuda-copy-for-i32.patch b/llama/patches/0015-add-argsort-and-cuda-copy-for-i32.patch new file mode 100644 index 000000000..b71295c76 --- /dev/null +++ b/llama/patches/0015-add-argsort-and-cuda-copy-for-i32.patch @@ -0,0 +1,277 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Michael Yang +Date: Thu, 1 May 2025 13:45:12 -0700 +Subject: [PATCH] add argsort and cuda copy for i32 + +--- + ggml/src/ggml-cpu/ops.cpp | 43 ++++++++++++++ + ggml/src/ggml-cuda/argsort.cu | 102 +++++++++++++++++++++++++++++++++- + ggml/src/ggml-cuda/cpy.cu | 49 ++++++++++++++++ + 3 files changed, 192 insertions(+), 2 deletions(-) + +diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp +index becdae07..7a44b6cf 100644 +--- a/ggml/src/ggml-cpu/ops.cpp ++++ b/ggml/src/ggml-cpu/ops.cpp +@@ -6890,6 +6890,45 @@ static void ggml_compute_forward_argsort_f32( + } + } + ++static void ggml_compute_forward_argsort_i32( ++ const ggml_compute_params * params, ++ ggml_tensor * dst) { ++ ++ const ggml_tensor * src0 = dst->src[0]; ++ ++ GGML_TENSOR_UNARY_OP_LOCALS ++ ++ GGML_ASSERT(nb0 == sizeof(int32_t)); ++ ++ const int ith = params->ith; ++ const int nth = params->nth; ++ ++ const int64_t nr = ggml_nrows(src0); ++ ++ ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0); ++ ++ for (int64_t i = ith; i < nr; i += nth) { ++ int32_t * 
dst_data = (int32_t *)((char *) dst->data + i*nb1); ++ const int32_t * src_data = (int32_t *)((char *) src0->data + i*nb01); ++ ++ for (int64_t j = 0; j < ne0; j++) { ++ dst_data[j] = j; ++ } ++ ++ // C doesn't have a functional sort, so we do a bubble sort instead ++ for (int64_t j = 0; j < ne0; j++) { ++ for (int64_t k = j + 1; k < ne0; k++) { ++ if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) || ++ (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) { ++ int32_t tmp = dst_data[j]; ++ dst_data[j] = dst_data[k]; ++ dst_data[k] = tmp; ++ } ++ } ++ } ++ } ++} ++ + void ggml_compute_forward_argsort( + const ggml_compute_params * params, + ggml_tensor * dst) { +@@ -6901,6 +6940,10 @@ void ggml_compute_forward_argsort( + { + ggml_compute_forward_argsort_f32(params, dst); + } break; ++ case GGML_TYPE_I32: ++ { ++ ggml_compute_forward_argsort_i32(params, dst); ++ } break; + default: + { + GGML_ABORT("fatal error"); +diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu +index 607ded85..53b02634 100644 +--- a/ggml/src/ggml-cuda/argsort.cu ++++ b/ggml/src/ggml-cuda/argsort.cu +@@ -85,13 +85,107 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co + } + } + ++ ++template ++static __global__ void k_argsort_i32_i32(const int32_t * x, int * dst, const int ncols, const int ncols_pad) { ++ extern __shared__ int shared_mem[]; ++ int * indices = shared_mem; ++ ++ const int tid = threadIdx.x; ++ const int row = blockIdx.y; ++ ++ // Initialize all indices, handling the case where threads < ncols_pad ++ for (int i = tid; i < ncols_pad; i += blockDim.x) { ++ indices[i] = i < ncols ? i : 0; // Use 0 for padding indices ++ } ++ __syncthreads(); ++ ++ // Bitonic sort ++ for (int k = 2; k <= ncols_pad; k *= 2) { ++ for (int j = k/2; j > 0; j /= 2) { ++ for (int i = tid; i < ncols_pad; i += blockDim.x) { ++ const int ij = i ^ j; ++ if (ij > i) { ++ // Only compare values within the actual data range ++ if (i < ncols && ij < ncols) { ++ if ((i & k) == 0) { ++ if (order == GGML_SORT_ORDER_ASC) { ++ if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) { ++ int tmp = indices[i]; ++ indices[i] = indices[ij]; ++ indices[ij] = tmp; ++ } ++ } else { ++ if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) { ++ int tmp = indices[i]; ++ indices[i] = indices[ij]; ++ indices[ij] = tmp; ++ } ++ } ++ } else { ++ if (order == GGML_SORT_ORDER_ASC) { ++ if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) { ++ int tmp = indices[i]; ++ indices[i] = indices[ij]; ++ indices[ij] = tmp; ++ } ++ } else { ++ if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) { ++ int tmp = indices[i]; ++ indices[i] = indices[ij]; ++ indices[ij] = tmp; ++ } ++ } ++ } ++ } ++ } ++ } ++ __syncthreads(); ++ } ++ } ++ ++ // Write sorted indices to output, only threads handling valid data ++ for (int i = tid; i < ncols; i += blockDim.x) { ++ dst[row * ncols + i] = indices[i]; ++ } ++} ++ ++static void argsort_i32_i32_cuda(const int32_t * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) { ++ // Bitonic sort requires ncols to be power of 2 ++ const int ncols_pad = next_power_of_2(ncols); ++ ++ // Ensure thread count doesn't exceed maximum (typically 1024) ++ const int max_threads = 1024; // This is the typical max for most GPUs ++ const int threads_per_block = ncols_pad > max_threads ? 
max_threads : ncols_pad; ++ ++ const dim3 block_dims(threads_per_block, 1, 1); ++ const dim3 block_nums(1, nrows, 1); ++ const size_t shared_mem = ncols_pad * sizeof(int); ++ ++ // Check if shared memory size is within limits ++ const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb; ++ ++ // Instead of logging an error, use GGML_ASSERT with a descriptive message ++ GGML_ASSERT(shared_mem <= max_shared_mem && "argsort: required shared memory exceeds device limit"); ++ ++ // Launch kernels with the updated thread configuration ++ if (order == GGML_SORT_ORDER_ASC) { ++ k_argsort_i32_i32<<>>(x, dst, ncols, ncols_pad); ++ } else if (order == GGML_SORT_ORDER_DESC) { ++ k_argsort_i32_i32<<>>(x, dst, ncols, ncols_pad); ++ } else { ++ GGML_ABORT("fatal error"); ++ } ++} ++ ++ + void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + +- GGML_ASSERT(src0->type == GGML_TYPE_F32); ++ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src0)); + +@@ -100,5 +194,9 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + + enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + +- argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream); ++ if (src0->type == GGML_TYPE_I32) { ++ argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream); ++ } else { ++ argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream); ++ } + } +diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu +index 2d46176e..47383486 100644 +--- a/ggml/src/ggml-cuda/cpy.cu ++++ b/ggml/src/ggml-cuda/cpy.cu +@@ -38,6 +38,13 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) { + *dsti = *xi; + } + ++static __device__ void cpy_1_i32_i32(const char * cxi, char * cdsti) { ++ const int32_t * xi = (const int32_t *) cxi; ++ int32_t * dsti = (int32_t *) cdsti; ++ ++ *dsti = *xi; ++} ++ + template + static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, +@@ -68,6 +75,44 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in + cpy_1(cx + x_offset, cdst + dst_offset); + } + ++// First, add this template function after the other template functions ++template ++static __global__ void cpy_i32_i32(const char * cx, char * cdst, const int ne, ++ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, ++ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, ++ const int nb12, const int nb13) { ++ const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; ++ ++ if (i >= ne) { ++ return; ++ } ++ ++ const int64_t i03 = i/(ne00 * ne01 * ne02); ++ const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); ++ const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; ++ const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; ++ const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; ++ ++ const int64_t i13 = i/(ne10 * ne11 * ne12); ++ const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); ++ const int64_t i11 = (i - i13*ne10*ne11*ne12 - 
i12*ne10*ne11) / ne10; ++ const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; ++ const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13; ++ ++ cpy_1(cx + x_offset, cdst + dst_offset); ++} ++ ++// Then modify the ggml_cpy_i32_i32_cuda function to use the new template ++static void ggml_cpy_i32_i32_cuda( ++ const char * cx, char * cdst, const int ne, ++ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, ++ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int graph_cpynode_index) { ++ ++ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; ++ cpy_i32_i32<<>> ++ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); ++} ++ + static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q8_0 * dsti = (block_q8_0 *) cdsti; +@@ -631,6 +676,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg + ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { + ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ++ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { ++ ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else { + GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, + ggml_type_name(src0->type), ggml_type_name(src1->type)); +@@ -686,6 +733,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { + return (void*) cpy_f32_f16; + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_f32_f16; ++ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { ++ return (void*) cpy_i32_i32; + } else { + GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, + ggml_type_name(src0->type), ggml_type_name(src1->type)); diff --git a/ml/backend.go b/ml/backend.go index f84a99845..cb32d8185 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -119,6 +119,21 @@ type Context interface { Layer(int) Context } +// RopeOptions contains optional parameters for RoPE function +type RopeOptions struct { + OriginalContextLen uint32 +} + +// RopeOption defines a function that modifies RopeOpts +type RopeOption func(*RopeOptions) + +// WithContextLen sets a custom context length +func WithContextLen(len uint32) RopeOption { + return func(opts *RopeOptions) { + opts.OriginalContextLen = len + } +} + type Tensor interface { Dim(n int) int Stride(n int) int @@ -144,7 +159,7 @@ type Tensor interface { AvgPool2D(ctx Context, k, s int, p float32) Tensor Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor - RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor + RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32, options ...RopeOption) Tensor IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, 
d1 int) Tensor Sin(ctx Context) Tensor @@ -172,6 +187,7 @@ type Tensor interface { Duplicate(ctx Context) Tensor TopK(ctx Context, k int) Tensor + Argsort(ctx Context) Tensor } // ScaledDotProductAttention implements a fused attention diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index e1aa687c8..1ba079838 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -1060,7 +1060,17 @@ const ( ropeTypeVision C.int = 24 ) -func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor { +func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32, options ...ml.RopeOption) ml.Tensor { + // Default options + opts := &ml.RopeOptions{ + OriginalContextLen: 131072, + } + + // Apply any provided options + for _, option := range options { + option(opts) + } + if ropeFactors == nil { ropeFactors = &Tensor{b: t.b} } @@ -1073,16 +1083,19 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi return &Tensor{ b: t.b, t: C.ggml_rope_ext( - ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t, + ctx.(*Context).ctx, + dequant, + positionIDs.(*Tensor).t, + ropeFactors.(*Tensor).t, C.int(ropeDim), C.int(ropeType), - 131072, // YaRN n_ctx_train + C.int(opts.OriginalContextLen), C.float(ropeBase), C.float(ropeScale), - 0., // YaRN ext_factor - 1., // YaRN attn_factor - 32., // YaRN beta_fast - 1., // YaRN beta_slow + C.float(0.0), + C.float(1.0), + C.float(32.0), + C.float(1.0), ), } } @@ -1176,3 +1189,10 @@ func (t *Tensor) TopK(ctx ml.Context, k int) ml.Tensor { t: C.ggml_top_k(ctx.(*Context).ctx, t.t, C.int(k)), } } + +func (t *Tensor) Argsort(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_argsort(ctx.(*Context).ctx, t.t, C.GGML_SORT_ORDER_ASC), + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp index 955fec59a..654e2f280 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp @@ -6822,6 +6822,45 @@ static void ggml_compute_forward_argsort_f32( } } +static void ggml_compute_forward_argsort_i32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(int32_t)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0); + + for (int64_t i = ith; i < nr; i += nth) { + int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); + const int32_t * src_data = (int32_t *)((char *) src0->data + i*nb01); + + for (int64_t j = 0; j < ne0; j++) { + dst_data[j] = j; + } + + // C doesn't have a functional sort, so we do a bubble sort instead + for (int64_t j = 0; j < ne0; j++) { + for (int64_t k = j + 1; k < ne0; k++) { + if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) || + (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) { + int32_t tmp = dst_data[j]; + dst_data[j] = dst_data[k]; + dst_data[k] = tmp; + } + } + } + } +} + void ggml_compute_forward_argsort( const ggml_compute_params * params, ggml_tensor * dst) { @@ -6833,6 +6872,10 @@ void ggml_compute_forward_argsort( { ggml_compute_forward_argsort_f32(params, dst); } break; + case GGML_TYPE_I32: + { + 
ggml_compute_forward_argsort_i32(params, dst); + } break; default: { GGML_ABORT("fatal error"); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu b/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu index 607ded855..53b02634c 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu @@ -85,13 +85,107 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co } } + +template +static __global__ void k_argsort_i32_i32(const int32_t * x, int * dst, const int ncols, const int ncols_pad) { + extern __shared__ int shared_mem[]; + int * indices = shared_mem; + + const int tid = threadIdx.x; + const int row = blockIdx.y; + + // Initialize all indices, handling the case where threads < ncols_pad + for (int i = tid; i < ncols_pad; i += blockDim.x) { + indices[i] = i < ncols ? i : 0; // Use 0 for padding indices + } + __syncthreads(); + + // Bitonic sort + for (int k = 2; k <= ncols_pad; k *= 2) { + for (int j = k/2; j > 0; j /= 2) { + for (int i = tid; i < ncols_pad; i += blockDim.x) { + const int ij = i ^ j; + if (ij > i) { + // Only compare values within the actual data range + if (i < ncols && ij < ncols) { + if ((i & k) == 0) { + if (order == GGML_SORT_ORDER_ASC) { + if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) { + int tmp = indices[i]; + indices[i] = indices[ij]; + indices[ij] = tmp; + } + } else { + if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) { + int tmp = indices[i]; + indices[i] = indices[ij]; + indices[ij] = tmp; + } + } + } else { + if (order == GGML_SORT_ORDER_ASC) { + if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) { + int tmp = indices[i]; + indices[i] = indices[ij]; + indices[ij] = tmp; + } + } else { + if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) { + int tmp = indices[i]; + indices[i] = indices[ij]; + indices[ij] = tmp; + } + } + } + } + } + } + __syncthreads(); + } + } + + // Write sorted indices to output, only threads handling valid data + for (int i = tid; i < ncols; i += blockDim.x) { + dst[row * ncols + i] = indices[i]; + } +} + +static void argsort_i32_i32_cuda(const int32_t * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) { + // Bitonic sort requires ncols to be power of 2 + const int ncols_pad = next_power_of_2(ncols); + + // Ensure thread count doesn't exceed maximum (typically 1024) + const int max_threads = 1024; // This is the typical max for most GPUs + const int threads_per_block = ncols_pad > max_threads ? 
max_threads : ncols_pad; + + const dim3 block_dims(threads_per_block, 1, 1); + const dim3 block_nums(1, nrows, 1); + const size_t shared_mem = ncols_pad * sizeof(int); + + // Check if shared memory size is within limits + const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb; + + // Instead of logging an error, use GGML_ASSERT with a descriptive message + GGML_ASSERT(shared_mem <= max_shared_mem && "argsort: required shared memory exceeds device limit"); + + // Launch kernels with the updated thread configuration + if (order == GGML_SORT_ORDER_ASC) { + k_argsort_i32_i32<<>>(x, dst, ncols, ncols_pad); + } else if (order == GGML_SORT_ORDER_DESC) { + k_argsort_i32_i32<<>>(x, dst, ncols, ncols_pad); + } else { + GGML_ABORT("fatal error"); + } +} + + void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32); GGML_ASSERT( dst->type == GGML_TYPE_I32); GGML_ASSERT(ggml_is_contiguous(src0)); @@ -100,5 +194,9 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream); + if (src0->type == GGML_TYPE_I32) { + argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream); + } else { + argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream); + } } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu index d027271fc..4abd01d79 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu @@ -38,6 +38,13 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) { *dsti = *xi; } +static __device__ void cpy_1_i32_i32(const char * cxi, char * cdsti) { + const int32_t * xi = (const int32_t *) cxi; + int32_t * dsti = (int32_t *) cdsti; + + *dsti = *xi; +} + template static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -68,6 +75,44 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in cpy_1(cx + x_offset, cdst + dst_offset); } +// First, add this template function after the other template functions +template +static __global__ void cpy_i32_i32(const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13) { + const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= ne) { + return; + } + + const int64_t i03 = i/(ne00 * ne01 * ne02); + const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; + + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = 
i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13; + + cpy_1(cx + x_offset, cdst + dst_offset); +} + +// Then modify the ggml_cpy_i32_i32_cuda function to use the new template +static void ggml_cpy_i32_i32_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int graph_cpynode_index) { + + const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + cpy_i32_i32<<>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { const float * xi = (const float *) cxi; block_q8_0 * dsti = (block_q8_0 *) cdsti; @@ -633,6 +678,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { + ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); @@ -688,6 +735,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { return (void*) cpy_f32_f16; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { return (void*) cpy_f32_f16; + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { + return (void*) cpy_i32_i32; } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); diff --git a/model/models/models.go b/model/models/models.go index 73b4c53a5..133e51761 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -7,4 +7,5 @@ import ( _ "github.com/ollama/ollama/model/models/llama4" _ "github.com/ollama/ollama/model/models/mistral3" _ "github.com/ollama/ollama/model/models/mllama" + _ "github.com/ollama/ollama/model/models/qwen25vl" ) diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go new file mode 100644 index 000000000..9d243c30f --- /dev/null +++ b/model/models/qwen25vl/model.go @@ -0,0 +1,187 @@ +package qwen25vl + +import ( + "bytes" + "fmt" + "image" + "slices" + "sync" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Model struct { + model.Base + model.BytePairEncoding + + *TextModel + *VisionModel `gguf:"v,vision"` + + ImageProcessor +} + +// Implement MultimodalProcessor interface +var _ model.MultimodalProcessor = (*Model)(nil) + +func New(c fs.Config) (model.Model, error) { + m := &Model{ + BytePairEncoding: model.NewBytePairEncoding( + 
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false), + EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), + AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false), + }, + ), + TextModel: NewTextModel(c), + VisionModel: newVisionModel(c), + ImageProcessor: newImageProcessor(c), + } + + m.Cache = kvcache.NewCausalCache(m.TextModel.Shift) + + return m, nil +} + +func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *Grid, error) { + image, _, err := image.Decode(bytes.NewReader(multimodalData)) + if err != nil { + return nil, nil, err + } + + f32s, grid, err := m.ImageProcessor.ProcessImage(image) + if err != nil { + return nil, nil, err + } + + // Calculate tensor dimensions + patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize * + m.ImageProcessor.patchSize * m.ImageProcessor.patchSize + numPatches := grid.Temporal * grid.Height * grid.Width + + pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches) + if err != nil { + return nil, nil, fmt.Errorf("failed to create tensor from image: %w", err) + } + + return pixelValues, grid, nil +} + +func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { + if len(m.VisionModel.Layers) == 0 { + return nil, model.ErrNoVisionModel + } + + pixels, grid, err := m.PixelValues(ctx, multimodalData) + if err != nil { + return nil, err + } + + visionOutputs := m.VisionModel.Forward(ctx, pixels, grid) + return &chunks{Model: m, Tensor: visionOutputs}, nil +} + +type chunks struct { + *Model + ml.Tensor + + dataOnce sync.Once + data []float32 +} + +type chunk struct { + *chunks + s, n int +} + +func (r *chunk) floats() []float32 { + r.dataOnce.Do(func() { + temp := r.Backend().NewContext() + defer temp.Close() + temp.Forward(r.Tensor).Compute(r.Tensor) + r.data = r.Floats() + }) + + return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)] +} + +// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass +func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { + var result []input.Input + + var ( + imageToken int32 = 151655 + visionStartToken int32 = 151652 + visionEndToken int32 = 151653 + ) + + nImg := 0 + for _, inp := range inputs { + if inp.Multimodal == nil { + // If not a multimodal input, add it to the result unchanged + result = append(result, inp) + } else { + // Adding the 'Picture' prefix is a hack, at the time of writing there is no way to prefix + // the image tokens with a prompt, so we add a prefix here + nImg++ + pre, err := m.Encode(fmt.Sprintf(" Picture %d: ", nImg), true) + if err != nil { + return nil, fmt.Errorf("failed to encode image prompt: %w", err) + } + for i := range pre { + result = append(result, input.Input{Token: pre[i]}) + } + + // This is an image token with multimodal data + chunksData := inp.Multimodal.(*chunks) + patchesPerChunk := chunksData.Dim(1) + + // First add the vision start token + result = append(result, input.Input{Token: visionStartToken, SameBatch: patchesPerChunk + 2}) + + // Add the image token with the multimodal tensor data 
at the first position + // Create a chunk with proper s and n values + result = append(result, input.Input{ + Token: imageToken, + Multimodal: &chunk{chunks: chunksData, s: 0, n: patchesPerChunk}, + MultimodalHash: inp.MultimodalHash, + SameBatch: patchesPerChunk, + }) + + // Add the placeholder tokens for the remaining positions (tokensPerGrid-1) + result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...) + + result = append(result, input.Input{Token: visionEndToken}) + } + } + + return result, nil +} + +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + if err != nil { + return nil, err + } + + outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + if err != nil { + return nil, err + } + + return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache) +} + +func init() { + model.Register("qwen25vl", New) +} diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go new file mode 100644 index 000000000..6b062f8c5 --- /dev/null +++ b/model/models/qwen25vl/model_text.go @@ -0,0 +1,155 @@ +package qwen25vl + +import ( + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/model/input" +) + +type TextOptions struct { + ctxLen, hiddenSize, numHeads, numKVHeads int + eps, ropeBase, ropeScale float32 + ropeDim, defaultContextLen uint32 +} + +type TextModel struct { + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []Layer `gguf:"blk"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + *TextOptions +} + +func NewTextModel(c fs.Config) *TextModel { + m := TextModel{ + Layers: make([]Layer, c.Uint("block_count")), + TextOptions: &TextOptions{ + ctxLen: int(c.Uint("context_length")), + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.freq_scale", 1), + ropeDim: c.Uint("rope.dimension_count", 128), + defaultContextLen: c.Uint("context_length", 128000), + }, + } + + return &m +} + +// SelfAttention implements the multi-head self-attention mechanism +// with separate projections for query, key, value and output transformations +type SelfAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` +} + +func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { + batchSize := hiddenState.Dim(1) + headDim := opts.hiddenSize / opts.numHeads + + q := sa.Query.Forward(ctx, hiddenState) + q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) + q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen)) + + k := sa.Key.Forward(ctx, hiddenState) + k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen)) + + v := sa.Value.Forward(ctx, hiddenState) + v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + + scaleFactor := 1.0 / 
math.Sqrt(float64(headDim)) + kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) + kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize) + + return sa.Output.Forward(ctx, kqv) +} + +// Shift applies rotary position embeddings to the key tensor for causal attention caching +func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return key.RoPE(ctx, shift, nil, m.ropeDim, 2, m.ropeBase, m.ropeScale, ml.WithContextLen(m.defaultContextLen)), nil +} + +// MLP implements the feed-forward network component with SwiGLU activation +type MLP struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` + Gate *nn.Linear `gguf:"ffn_gate"` +} + +func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { + // Apply SwiGLU activation gating + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + // Project back to hidden dimension + return mlp.Down.Forward(ctx, hiddenState) +} + +// Layer represents a single transformer layer combining self-attention and feed-forward components +type Layer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + SelfAttention *SelfAttention + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP *MLP +} + +func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { + // Self-attention branch with residual connection + residual := hiddenState + + hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts) + + // In the final layer (outputs != nil), optimize by pruning to just the token positions + // we need logits for. + if outputs != nil { + hiddenState = hiddenState.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenState = hiddenState.Add(ctx, residual) + // Feed-forward branch with residual connection + residual = hiddenState + hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.MLP.Forward(ctx, hiddenState, opts) + return hiddenState.Add(ctx, residual) +} + +func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) { + // Initial token embedding + hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx) + + for _, mi := range batch.Multimodal { + f32s := mi.Multimodal.(*chunk).floats() + img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize) + if err != nil { + panic(err) + } + + ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1)))) + } + + // Process through transformer layers + for i, layer := range m.Layers { + cache.SetLayer(i) + + var lastLayerOutputs ml.Tensor + if i == len(m.Layers)-1 { + lastLayerOutputs = outputs + } + + hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, m.TextOptions) + } + + hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) + return m.Output.Forward(ctx, hiddenStates), nil +} diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go new file mode 100644 index 000000000..01eef392b --- /dev/null +++ b/model/models/qwen25vl/model_vision.go @@ -0,0 +1,391 @@ +package qwen25vl + +import ( + "fmt" + "math" + "slices" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" +) + +// We only support batch size 
of 1 +var batchSize int = 1 + +func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor { + x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)) + x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx) + return x2.Neg(ctx).Concat(ctx, x1, 0) +} + +func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor { + return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin)) +} + +func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor { + // Create a flat slice for the mask (all -inf initially to block all attention) + flat := make([]float32, seqLength*seqLength) + for i := range flat { + flat[i] = float32(math.Inf(-1)) // Negative infinity to block attention + } + + // Fill in the mask with zeros for tokens that CAN attend to each other + for i := 1; i < len(bounds); i++ { + start := bounds[i-1] + end := bounds[i] + + // Enable attention within this sequence block by setting values to 0 + for row := start; row < end; row++ { + for col := start; col < end; col++ { + idx := row*seqLength + col + flat[idx] = 0.0 // 0 allows attention, -inf blocks it + } + } + } + + mask, err := ctx.Input().FromFloatSlice(flat, seqLength, seqLength) + if err != nil { + panic(err) + } + // Reshape to match [seqLength, seqLength, 1] for broadcasting + mask = mask.Reshape(ctx, seqLength, seqLength, 1) + + return mask +} + +type VisionSelfAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_out"` +} + +func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor { + query := sa.Query.Forward(ctx, hiddenStates) + key := sa.Key.Forward(ctx, hiddenStates) + value := sa.Value.Forward(ctx, hiddenStates) + + query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize) + key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize) + value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize) + + query = applyRotaryPositionalEmbedding(ctx, query, cos, sin) + key = applyRotaryPositionalEmbedding(ctx, key, cos, sin) + + // Scale factor for scaled dot-product attention + scale := 1.0 / math.Sqrt(float64(opts.headDim)) + + // Scaled dot-product attention + query = query.Permute(ctx, 0, 2, 1, 3) + key = key.Permute(ctx, 0, 2, 1, 3) + value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + kq := key.MulmatFullPrec(ctx, query) + kq = kq.Scale(ctx, scale) + if mask != nil { + kq = kq.Add(ctx, mask) + } + kq = kq.Softmax(ctx) + kqv := value.Mulmat(ctx, kq) + attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) + + return sa.Output.Forward(ctx, attention) +} + +// VisionMLP implements the multi-layer perceptron +type VisionMLP struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor { + // Using activation as specified in config (likely GELU or SiLU/Swish) + gateOutput := mlp.Gate.Forward(ctx, hiddenStates) + upOutput := mlp.Up.Forward(ctx, hiddenStates) + hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput) + + return mlp.Down.Forward(ctx, hiddenStates) +} + +type 
VisionEncoderLayer struct { + Norm1 *nn.RMSNorm `gguf:"ln1"` + SelfAttention *VisionSelfAttention + Norm2 *nn.RMSNorm `gguf:"ln2"` + MLP *VisionMLP +} + +func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor { + residual := hiddenStates + hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, mask, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + + residual = hiddenStates + hiddenStates = e.Norm2.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts) + return hiddenStates.Add(ctx, residual) +} + +// VisionModelOptions contains configuration options +type VisionModelOptions struct { + hiddenSize int + numHeads int + headDim int + patchSize int + numChannels int + eps float32 + ropeTheta float32 + spatialMergeSize int + windowSize int + fullAttnBlocks []int32 + temporalPatchSize int +} + +type PatchEmbedding struct { + PatchConv0 *nn.Conv2D `gguf:"patch_embd_0"` + PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"` +} + +func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, opts *VisionModelOptions) ml.Tensor { + numPatches := pixelValues.Shape()[1] + + // Reshape the input tensor to match the expected dimensions + pixelValues = pixelValues.Reshape(ctx, opts.patchSize*opts.patchSize, opts.temporalPatchSize, opts.numChannels, numPatches) + + // Permute the tensor to bring the temporal dimension to the front + pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + + // Split the tensor into parts for the temporal convolutions + in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx) + in0 = in0.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches) + in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx) + in1 = in1.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches) + + s0, s1 := opts.patchSize, opts.patchSize // Use full stride + p0, p1 := 0, 0 // padding + d0, d1 := 1, 1 // dilation + out0 := pe.PatchConv0.Forward(ctx, in0, s0, s1, p0, p1, d0, d1) + out1 := pe.PatchConv1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1) + + // Add the outputs from the two temporal convolutions + out := out0.Add(ctx, out1) + + // Reshape the output tensor to match the expected dimensions + return out.Reshape(ctx, opts.hiddenSize, numPatches) +} + +// VisionPatchMerger implements patch merging for the Qwen vision model +type VisionPatchMerger struct { + LNQ *nn.RMSNorm `gguf:"ln_q"` + MLP0 *nn.Linear `gguf:"mlp.0"` + MLP2 *nn.Linear `gguf:"mlp.2"` +} + +// Forward computes patch merging for the vision model +func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, opts *VisionModelOptions) ml.Tensor { + normalized := pm.LNQ.Forward(ctx, visionOutputs, opts.eps) + + hiddenSize := visionOutputs.Dim(0) * (opts.spatialMergeSize * opts.spatialMergeSize) + + // Reshape the normalized output to view the hidden size dimension + reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize), batchSize) + hidden := pm.MLP0.Forward(ctx, reshaped) + activated := hidden.GELU(ctx) + + output := pm.MLP2.Forward(ctx, activated) + + return 
output +} + +// VisionModel implements the Qwen vision model +type VisionModel struct { + PatchEmbedding *PatchEmbedding + Layers []VisionEncoderLayer `gguf:"blk"` + PatchMerger *VisionPatchMerger `gguf:"merger"` + + *VisionModelOptions +} + +// Forward computes the vision model for an input tensor +func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor { + // Extract patch embeddings + hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions) + + positionEmbedding := m.PositionalEmbedding(ctx, grid) + + windowIndex, bounds := m.WindowIndex(ctx, grid) + + spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize + + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*spatialMergeUnit, hiddenStates.Dim(1)/spatialMergeUnit) + hiddenStates = hiddenStates.Rows(ctx, windowIndex) + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)/spatialMergeUnit, hiddenStates.Dim(1)*spatialMergeUnit) + + positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)*spatialMergeUnit, positionEmbedding.Dim(1)/spatialMergeUnit) + positionEmbedding = positionEmbedding.Rows(ctx, windowIndex) + positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)/spatialMergeUnit, positionEmbedding.Dim(1)*spatialMergeUnit) + positionEmbedding = positionEmbedding.Concat(ctx, positionEmbedding, 0) + + cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx) + cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1)) + sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1)) + + mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads) + // Apply encoder layers + for i, layer := range m.Layers { + if slices.Contains(m.fullAttnBlocks, int32(i)) { + hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, nil, m.VisionModelOptions) + } else { + hiddenStates = layer.Forward( + ctx, + hiddenStates, + cos, + sin, + mask, + m.VisionModelOptions, + ) + } + } + + hiddenStates = m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions) + reverseWindowIndex := windowIndex.Argsort(ctx) + return hiddenStates.Rows(ctx, reverseWindowIndex) +} + +// WindowIndex divides the grid into windows and returns: +// 1. A tensor containing flattened indices of all grid points organized by windows +// 2. A slice of boundaries that mark where each window's data begins and ends +// in the flattened representation, scaled by spatialMergeSize squared +// +// The boundaries slice always starts with 0 and contains cumulative ending +// positions for each window, allowing downstream processing to identify +// window boundaries in the tensor data. 
+func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) { + vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize + + llmGridH := grid.Height / m.spatialMergeSize + llmGridW := grid.Width / m.spatialMergeSize + + // Calculate window parameters + numWindowsH := int(math.Ceil(float64(llmGridH) / float64(vitMergerWindowSize))) + numWindowsW := int(math.Ceil(float64(llmGridW) / float64(vitMergerWindowSize))) + + // Initialize index_new slice + var index []int32 + + // Initialize bounds with the first element as 0 + bounds := []int{0} + totalSeqLen := 0 + + // Process each window without padding + for wh := range numWindowsH { + for ww := range numWindowsW { + // Calculate window boundaries + hStart := wh * vitMergerWindowSize + wStart := ww * vitMergerWindowSize + hEnd := min(hStart+vitMergerWindowSize, llmGridH) + wEnd := min(wStart+vitMergerWindowSize, llmGridW) + + // Calculate sequence length for this window + seqLen := (hEnd - hStart) * (wEnd - wStart) + + // Collect indices for this window + for h := hStart; h < hEnd; h++ { + for w := wStart; w < wEnd; w++ { + index = append(index, int32(h*llmGridW+w)) + } + } + + totalSeqLen += seqLen + bounds = append(bounds, totalSeqLen*(m.spatialMergeSize*m.spatialMergeSize)+bounds[0]) + } + } + + t, err := ctx.Input().FromIntSlice(index, len(index)) + if err != nil { + panic(err) + } + + return t, bounds +} + +// PositionalEmbedding generates rotary position embeddings for attention mechanisms +func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor { + dim := m.headDim / 2 + freq := dim / 2 + theta := float64(m.ropeTheta) + merge := m.spatialMergeSize + + // Create frequency patterns for position encoding + maxGridSize := max(grid.Height, grid.Width) + freqVals := make([]float32, freq*maxGridSize) + for i := range maxGridSize { + for j := range freq { + freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim))) + } + } + freqs, err := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize) + if err != nil { + panic(fmt.Errorf("failed to create tensor from frequencies: %w", err)) + } + + // Create position coordinates (y,x pairs) for the grid + // In PyTorch: Equivalent to generating position ids with torch.arange() + coords := make([]int32, 0, grid.Height*grid.Width*2) + for y := range grid.Height { + for x := range grid.Width { + coords = append(coords, int32(y), int32(x)) + } + } + pos, err := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height) + if err != nil { + panic(fmt.Errorf("failed to create tensor from positions: %w", err)) + } + + // Reshape and permute positions to match spatial merging pattern + pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge) + pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + pos = pos.Reshape(ctx, 2, merge, merge, grid.Width/merge*grid.Height/merge) + pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + pos = pos.Reshape(ctx, 2*merge*merge*grid.Width/merge*grid.Height/merge) + + // Use position indices to look up corresponding frequency values + positionalEmbedding := freqs.Rows(ctx, pos) + positionalEmbedding = positionalEmbedding.Reshape(ctx, positionalEmbedding.Dim(0)*2, positionalEmbedding.Dim(1)/2) + return positionalEmbedding +} + +// newVisionModel creates a new instance of the Qwen vision model +func newVisionModel(c fs.Config) *VisionModel { + patchSize := int(c.Uint("vision.patch_size", 14)) + hiddenSize := int(c.Uint("vision.embedding_length", 1280)) + numHeads := 
int(c.Uint("vision.attention.head_count", 16)) + numChannels := int(c.Uint("vision.num_channels", 3)) + eps := c.Float("vision.attention.layer_norm_epsilon", 1e-6) + ropeTheta := c.Float("vision.rope.freq_base", 10000.0) + spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2)) + windowSize := int(c.Uint("vision.window_size", 112)) + fullAttnBlocks := c.Ints("qwen25vl.vision.fullatt_block_indexes", []int32{7, 15, 23, 31}) + temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2)) + + model := &VisionModel{ + Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)), + VisionModelOptions: &VisionModelOptions{ + hiddenSize: hiddenSize, + numHeads: numHeads, + headDim: hiddenSize / numHeads, + patchSize: patchSize, + numChannels: numChannels, + eps: eps, + ropeTheta: ropeTheta, + spatialMergeSize: spatialMergeSize, + windowSize: windowSize, + temporalPatchSize: temporalPatchSize, + fullAttnBlocks: fullAttnBlocks, + }, + } + + return model +} diff --git a/model/models/qwen25vl/process_image.go b/model/models/qwen25vl/process_image.go new file mode 100644 index 000000000..dc91bdea5 --- /dev/null +++ b/model/models/qwen25vl/process_image.go @@ -0,0 +1,184 @@ +package qwen25vl + +import ( + "fmt" + "image" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/model/imageproc" +) + +// ImageProcessor contains configuration for the Qwen 2.5 VL image processing +type ImageProcessor struct { + numChannels int + patchSize int + temporalPatchSize int + mergeSize int + minPixels int + maxPixels int + factor int + rescaleFactor float32 + imageMean []float32 + imageStd []float32 +} + +// newImageProcessor creates a new image processor with default values +func newImageProcessor(c fs.Config) ImageProcessor { + patchSize := int(c.Uint("vision.patch_size", 14)) + mergeSize := int(c.Uint("vision.spatial_merge_size", 2)) + + return ImageProcessor{ + numChannels: int(c.Uint("vision.num_channels", 3)), // not set + patchSize: patchSize, + temporalPatchSize: 2, + mergeSize: mergeSize, + minPixels: 56 * 56, + maxPixels: int(c.Uint("vision.max_pixels", 28*28*1280)), // 1MP limit + factor: patchSize * mergeSize, + rescaleFactor: 1.0 / 255.0, + imageMean: imageproc.ClipDefaultMean[:], + imageStd: imageproc.ClipDefaultSTD[:], + } +} + +// SmartResize implements the smart resize algorithm +func (p *ImageProcessor) SmartResize(height, width int) (int, int) { + factor := p.factor + + if height < factor || width < factor { + panic(fmt.Sprintf("height:%d or width:%d must be larger than factor:%d", height, width, factor)) + } else if aspectRatio := max(height, width) / min(height, width); aspectRatio > 200 { + panic(fmt.Sprintf("absolute aspect ratio must be smaller than 200, got %v", aspectRatio)) + } + + round := func(x float64) int { return int(math.RoundToEven(x)) } + + hBar := round(float64(height)/float64(factor)) * factor + wBar := round(float64(width)/float64(factor)) * factor + + if hBar*wBar > p.maxPixels { + beta := math.Sqrt(float64(height*width) / float64(p.maxPixels)) + + hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor + wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor + } else if hBar*wBar < p.minPixels { + beta := math.Sqrt(float64(p.minPixels) / float64(height*width)) + + hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor + wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor + } + + return hBar, wBar +} + +type Grid struct { + Height int + Width int + Temporal int +} + +func (p 
*ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error) { + origWidth := img.Bounds().Dx() + origHeight := img.Bounds().Dy() + + // Calculate smart resize dimensions + resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth) + + // Resize image using existing functions + resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear) + + normalizedPixels := imageproc.Normalize( + resizedImg, + [3]float32{p.imageMean[0], p.imageMean[1], p.imageMean[2]}, + [3]float32{p.imageStd[0], p.imageStd[1], p.imageStd[2]}, + true, // rescale + true, // channelFirst + ) + + // Calculate grid dimensions + grid := &Grid{ + Height: resizedHeight / p.patchSize, + Width: resizedWidth / p.patchSize, + Temporal: 1, // For single images, temporal dimension is 1 + } + + patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid) + if err != nil { + return nil, nil, fmt.Errorf("failed to create patches: %v", err) + } + + // Return patches and grid dimensions + return patches, grid, nil +} + +func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) { + channels := p.numChannels + patchSize := p.patchSize + mergeSize := p.mergeSize + temporalPatchSize := p.temporalPatchSize + + // Calculate output dimensions + numPatches := grid.Temporal * grid.Height * grid.Width + patchDim := channels * temporalPatchSize * patchSize * patchSize + + result := make([]float32, numPatches*patchDim) + patchIndex := 0 + + // Single temporal frame handling (copies to all frames) + for range grid.Temporal { + for h := 0; h < grid.Height; h += mergeSize { + for w := 0; w < grid.Width; w += mergeSize { + // Handle the 2x2 merged patches + for mh := range mergeSize { + for mw := range mergeSize { + baseOffset := patchIndex * patchDim + + // Extract patch data for first temporal frame + for c := range channels { + channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize) + + for py := range patchSize { + for px := range patchSize { + // Calculate source pixel coordinates + y := (h+mh)*patchSize + py + x := (w+mw)*patchSize + px + + // Source index in input tensor (CHW format) + srcIdx := c*height*width + y*width + x + + // Destination index in first temporal frame + dstIdx := channelOffset + (py * patchSize) + px + + if srcIdx < len(pixels) && dstIdx < len(result) { + result[dstIdx] = pixels[srcIdx] + } + } + } + } + + // Copy first temporal frame to all other frames + if temporalPatchSize > 1 { + for c := range channels { + channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize) + firstFrameOffset := channelOffset + frameSize := patchSize * patchSize + + // Copy first frame to all other frames + for tp := 1; tp < temporalPatchSize; tp++ { + currentFrameOffset := channelOffset + (tp * frameSize) + copy(result[currentFrameOffset:currentFrameOffset+frameSize], + result[firstFrameOffset:firstFrameOffset+frameSize]) + } + } + } + + patchIndex++ + } + } + } + } + } + + return result, nil +}
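
Aside (not part of the patch): the interaction between SmartResize, the patch grid, and the number of image tokens fed to the text model may be easier to follow with a small worked example. The sketch below is illustrative only; it re-derives the resize arithmetic outside ImageProcessor, assuming the default parameters wired in above (patch size 14, spatial merge size 2, minPixels 56*56, maxPixels 28*28*1280).

package main

import (
	"fmt"
	"math"
)

// smartResize mirrors the rounding/scaling arithmetic of
// ImageProcessor.SmartResize under the assumed defaults; it is a
// sketch for illustration, not the implementation in this patch.
func smartResize(height, width, factor, minPixels, maxPixels int) (int, int) {
	round := func(x float64) int { return int(math.RoundToEven(x)) }

	// Round each side to the nearest multiple of factor (patchSize * mergeSize).
	hBar := round(float64(height)/float64(factor)) * factor
	wBar := round(float64(width)/float64(factor)) * factor

	if hBar*wBar > maxPixels {
		// Too large: scale down so the pixel count fits under maxPixels.
		beta := math.Sqrt(float64(height*width) / float64(maxPixels))
		hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
		wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
	} else if hBar*wBar < minPixels {
		// Too small: scale up to reach at least minPixels.
		beta := math.Sqrt(float64(minPixels) / float64(height*width))
		hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
		wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
	}
	return hBar, wBar
}

func main() {
	patchSize, mergeSize := 14, 2
	factor := patchSize * mergeSize // 28
	minPixels, maxPixels := 56*56, 28*28*1280

	// A 700x1000 image rounds to 700x1008 (multiples of 28), which is
	// within maxPixels, so no rescaling is applied. That gives a 50x72
	// patch grid and, after the 2x2 spatial merge, 900 image tokens
	// placed in the text sequence by PostTokenize.
	h, w := smartResize(700, 1000, factor, minPixels, maxPixels)
	gridH, gridW := h/patchSize, w/patchSize
	fmt.Println(h, w, gridH, gridW, gridH*gridW/(mergeSize*mergeSize)) // 700 1008 50 72 900
}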