From 3308bff137dc6172b430ac2acfee2f0ac1850277 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Sat, 10 May 2025 17:34:53 -0700 Subject: [PATCH] add i32 copy and argsort for cuda --- llama/patches/0007-add-unpad-operator.patch | 4 +- ...18-add-argsort-and-cuda-copy-for-i32.patch | 277 ++++++++++++++++++ .../0018-add-argsort-for-int32_t.patch | 70 ----- ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu | 102 ++++++- ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu | 49 ++++ 5 files changed, 428 insertions(+), 74 deletions(-) create mode 100644 llama/patches/0018-add-argsort-and-cuda-copy-for-i32.patch delete mode 100644 llama/patches/0018-add-argsort-for-int32_t.patch diff --git a/llama/patches/0007-add-unpad-operator.patch b/llama/patches/0007-add-unpad-operator.patch index fc45aeff4..d3ede9025 100644 --- a/llama/patches/0007-add-unpad-operator.patch +++ b/llama/patches/0007-add-unpad-operator.patch @@ -236,7 +236,7 @@ diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 1b56f858..7641247e 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m -@@ -347,6 +347,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte +@@ -341,6 +341,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_UPSCALE_F32, GGML_METAL_KERNEL_TYPE_PAD_F32, GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, @@ -244,7 +244,7 @@ index 1b56f858..7641247e 100644 GGML_METAL_KERNEL_TYPE_ARANGE_F32, GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, -@@ -1294,6 +1295,7 @@ @implementation GGMLMetalClass +@@ -1277,6 +1278,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true); diff --git a/llama/patches/0018-add-argsort-and-cuda-copy-for-i32.patch b/llama/patches/0018-add-argsort-and-cuda-copy-for-i32.patch new file mode 100644 index 000000000..b71295c76 --- /dev/null +++ b/llama/patches/0018-add-argsort-and-cuda-copy-for-i32.patch @@ -0,0 +1,277 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Michael Yang +Date: Thu, 1 May 2025 13:45:12 -0700 +Subject: [PATCH] add argsort and cuda copy for i32 + +--- + ggml/src/ggml-cpu/ops.cpp | 43 ++++++++++++++ + ggml/src/ggml-cuda/argsort.cu | 102 +++++++++++++++++++++++++++++++++- + ggml/src/ggml-cuda/cpy.cu | 49 ++++++++++++++++ + 3 files changed, 192 insertions(+), 2 deletions(-) + +diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp +index becdae07..7a44b6cf 100644 +--- a/ggml/src/ggml-cpu/ops.cpp ++++ b/ggml/src/ggml-cpu/ops.cpp +@@ -6890,6 +6890,45 @@ static void ggml_compute_forward_argsort_f32( + } + } + ++static void ggml_compute_forward_argsort_i32( ++ const ggml_compute_params * params, ++ ggml_tensor * dst) { ++ ++ const ggml_tensor * src0 = dst->src[0]; ++ ++ GGML_TENSOR_UNARY_OP_LOCALS ++ ++ GGML_ASSERT(nb0 == sizeof(int32_t)); ++ ++ const int ith = params->ith; ++ const int nth = params->nth; ++ ++ const int64_t nr = ggml_nrows(src0); ++ ++ ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0); ++ ++ for (int64_t i = ith; i < nr; i += nth) { ++ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); ++ const int32_t * src_data = (int32_t *)((char *) src0->data + i*nb01); ++ ++ for (int64_t j = 0; j < ne0; j++) { ++ 
dst_data[j] = j; ++ } ++ ++ // C doesn't have a functional sort, so we do a bubble sort instead ++ for (int64_t j = 0; j < ne0; j++) { ++ for (int64_t k = j + 1; k < ne0; k++) { ++ if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) || ++ (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) { ++ int32_t tmp = dst_data[j]; ++ dst_data[j] = dst_data[k]; ++ dst_data[k] = tmp; ++ } ++ } ++ } ++ } ++} ++ + void ggml_compute_forward_argsort( + const ggml_compute_params * params, + ggml_tensor * dst) { +@@ -6901,6 +6940,10 @@ void ggml_compute_forward_argsort( + { + ggml_compute_forward_argsort_f32(params, dst); + } break; ++ case GGML_TYPE_I32: ++ { ++ ggml_compute_forward_argsort_i32(params, dst); ++ } break; + default: + { + GGML_ABORT("fatal error"); +diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu +index 607ded85..53b02634 100644 +--- a/ggml/src/ggml-cuda/argsort.cu ++++ b/ggml/src/ggml-cuda/argsort.cu +@@ -85,13 +85,107 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co + } + } + ++ ++template ++static __global__ void k_argsort_i32_i32(const int32_t * x, int * dst, const int ncols, const int ncols_pad) { ++ extern __shared__ int shared_mem[]; ++ int * indices = shared_mem; ++ ++ const int tid = threadIdx.x; ++ const int row = blockIdx.y; ++ ++ // Initialize all indices, handling the case where threads < ncols_pad ++ for (int i = tid; i < ncols_pad; i += blockDim.x) { ++ indices[i] = i < ncols ? i : 0; // Use 0 for padding indices ++ } ++ __syncthreads(); ++ ++ // Bitonic sort ++ for (int k = 2; k <= ncols_pad; k *= 2) { ++ for (int j = k/2; j > 0; j /= 2) { ++ for (int i = tid; i < ncols_pad; i += blockDim.x) { ++ const int ij = i ^ j; ++ if (ij > i) { ++ // Only compare values within the actual data range ++ if (i < ncols && ij < ncols) { ++ if ((i & k) == 0) { ++ if (order == GGML_SORT_ORDER_ASC) { ++ if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) { ++ int tmp = indices[i]; ++ indices[i] = indices[ij]; ++ indices[ij] = tmp; ++ } ++ } else { ++ if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) { ++ int tmp = indices[i]; ++ indices[i] = indices[ij]; ++ indices[ij] = tmp; ++ } ++ } ++ } else { ++ if (order == GGML_SORT_ORDER_ASC) { ++ if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) { ++ int tmp = indices[i]; ++ indices[i] = indices[ij]; ++ indices[ij] = tmp; ++ } ++ } else { ++ if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) { ++ int tmp = indices[i]; ++ indices[i] = indices[ij]; ++ indices[ij] = tmp; ++ } ++ } ++ } ++ } ++ } ++ } ++ __syncthreads(); ++ } ++ } ++ ++ // Write sorted indices to output, only threads handling valid data ++ for (int i = tid; i < ncols; i += blockDim.x) { ++ dst[row * ncols + i] = indices[i]; ++ } ++} ++ ++static void argsort_i32_i32_cuda(const int32_t * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) { ++ // Bitonic sort requires ncols to be power of 2 ++ const int ncols_pad = next_power_of_2(ncols); ++ ++ // Ensure thread count doesn't exceed maximum (typically 1024) ++ const int max_threads = 1024; // This is the typical max for most GPUs ++ const int threads_per_block = ncols_pad > max_threads ? 
max_threads : ncols_pad; ++ ++ const dim3 block_dims(threads_per_block, 1, 1); ++ const dim3 block_nums(1, nrows, 1); ++ const size_t shared_mem = ncols_pad * sizeof(int); ++ ++ // Check if shared memory size is within limits ++ const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb; ++ ++ // Instead of logging an error, use GGML_ASSERT with a descriptive message ++ GGML_ASSERT(shared_mem <= max_shared_mem && "argsort: required shared memory exceeds device limit"); ++ ++ // Launch kernels with the updated thread configuration ++ if (order == GGML_SORT_ORDER_ASC) { ++ k_argsort_i32_i32<<>>(x, dst, ncols, ncols_pad); ++ } else if (order == GGML_SORT_ORDER_DESC) { ++ k_argsort_i32_i32<<>>(x, dst, ncols, ncols_pad); ++ } else { ++ GGML_ABORT("fatal error"); ++ } ++} ++ ++ + void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + +- GGML_ASSERT(src0->type == GGML_TYPE_F32); ++ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src0)); + +@@ -100,5 +194,9 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + + enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + +- argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream); ++ if (src0->type == GGML_TYPE_I32) { ++ argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream); ++ } else { ++ argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream); ++ } + } +diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu +index 2d46176e..47383486 100644 +--- a/ggml/src/ggml-cuda/cpy.cu ++++ b/ggml/src/ggml-cuda/cpy.cu +@@ -38,6 +38,13 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) { + *dsti = *xi; + } + ++static __device__ void cpy_1_i32_i32(const char * cxi, char * cdsti) { ++ const int32_t * xi = (const int32_t *) cxi; ++ int32_t * dsti = (int32_t *) cdsti; ++ ++ *dsti = *xi; ++} ++ + template + static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, +@@ -68,6 +75,44 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in + cpy_1(cx + x_offset, cdst + dst_offset); + } + ++// First, add this template function after the other template functions ++template ++static __global__ void cpy_i32_i32(const char * cx, char * cdst, const int ne, ++ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, ++ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, ++ const int nb12, const int nb13) { ++ const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; ++ ++ if (i >= ne) { ++ return; ++ } ++ ++ const int64_t i03 = i/(ne00 * ne01 * ne02); ++ const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); ++ const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; ++ const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; ++ const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; ++ ++ const int64_t i13 = i/(ne10 * ne11 * ne12); ++ const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); ++ const int64_t i11 = (i - i13*ne10*ne11*ne12 - 
i12*ne10*ne11) / ne10; ++ const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; ++ const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13; ++ ++ cpy_1(cx + x_offset, cdst + dst_offset); ++} ++ ++// Then modify the ggml_cpy_i32_i32_cuda function to use the new template ++static void ggml_cpy_i32_i32_cuda( ++ const char * cx, char * cdst, const int ne, ++ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, ++ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int graph_cpynode_index) { ++ ++ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; ++ cpy_i32_i32<<>> ++ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); ++} ++ + static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q8_0 * dsti = (block_q8_0 *) cdsti; +@@ -631,6 +676,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg + ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { + ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ++ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { ++ ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else { + GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, + ggml_type_name(src0->type), ggml_type_name(src1->type)); +@@ -686,6 +733,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { + return (void*) cpy_f32_f16; + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_f32_f16; ++ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { ++ return (void*) cpy_i32_i32; + } else { + GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, + ggml_type_name(src0->type), ggml_type_name(src1->type)); diff --git a/llama/patches/0018-add-argsort-for-int32_t.patch b/llama/patches/0018-add-argsort-for-int32_t.patch deleted file mode 100644 index 9ac2b87ee..000000000 --- a/llama/patches/0018-add-argsort-for-int32_t.patch +++ /dev/null @@ -1,70 +0,0 @@ -From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 -From: Michael Yang -Date: Thu, 1 May 2025 13:45:12 -0700 -Subject: [PATCH] add argsort for int32_t - ---- - ggml/src/ggml-cpu/ops.cpp | 43 +++++++++++++++++++++++++++++++++++++++ - 1 file changed, 43 insertions(+) - -diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp -index 66b8da68..1ad571d3 100644 ---- a/ggml/src/ggml-cpu/ops.cpp -+++ b/ggml/src/ggml-cpu/ops.cpp -@@ -6718,6 +6718,45 @@ static void ggml_compute_forward_argsort_f32( - } - } - -+static void ggml_compute_forward_argsort_i32( -+ const ggml_compute_params * params, -+ ggml_tensor * dst) { -+ -+ const ggml_tensor * src0 = dst->src[0]; -+ -+ GGML_TENSOR_UNARY_OP_LOCALS -+ -+ GGML_ASSERT(nb0 == sizeof(int32_t)); -+ -+ const int ith = params->ith; -+ const int nth = 
params->nth; -+ -+ const int64_t nr = ggml_nrows(src0); -+ -+ ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0); -+ -+ for (int64_t i = ith; i < nr; i += nth) { -+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); -+ const int32_t * src_data = (int32_t *)((char *) src0->data + i*nb01); -+ -+ for (int64_t j = 0; j < ne0; j++) { -+ dst_data[j] = j; -+ } -+ -+ // C doesn't have a functional sort, so we do a bubble sort instead -+ for (int64_t j = 0; j < ne0; j++) { -+ for (int64_t k = j + 1; k < ne0; k++) { -+ if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) || -+ (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) { -+ int32_t tmp = dst_data[j]; -+ dst_data[j] = dst_data[k]; -+ dst_data[k] = tmp; -+ } -+ } -+ } -+ } -+} -+ - void ggml_compute_forward_argsort( - const ggml_compute_params * params, - ggml_tensor * dst) { -@@ -6729,6 +6768,10 @@ void ggml_compute_forward_argsort( - { - ggml_compute_forward_argsort_f32(params, dst); - } break; -+ case GGML_TYPE_I32: -+ { -+ ggml_compute_forward_argsort_i32(params, dst); -+ } break; - default: - { - GGML_ABORT("fatal error"); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu b/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu index 607ded855..53b02634c 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu @@ -85,13 +85,107 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co } } + +template +static __global__ void k_argsort_i32_i32(const int32_t * x, int * dst, const int ncols, const int ncols_pad) { + extern __shared__ int shared_mem[]; + int * indices = shared_mem; + + const int tid = threadIdx.x; + const int row = blockIdx.y; + + // Initialize all indices, handling the case where threads < ncols_pad + for (int i = tid; i < ncols_pad; i += blockDim.x) { + indices[i] = i < ncols ? 
i : 0; // Use 0 for padding indices + } + __syncthreads(); + + // Bitonic sort + for (int k = 2; k <= ncols_pad; k *= 2) { + for (int j = k/2; j > 0; j /= 2) { + for (int i = tid; i < ncols_pad; i += blockDim.x) { + const int ij = i ^ j; + if (ij > i) { + // Only compare values within the actual data range + if (i < ncols && ij < ncols) { + if ((i & k) == 0) { + if (order == GGML_SORT_ORDER_ASC) { + if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) { + int tmp = indices[i]; + indices[i] = indices[ij]; + indices[ij] = tmp; + } + } else { + if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) { + int tmp = indices[i]; + indices[i] = indices[ij]; + indices[ij] = tmp; + } + } + } else { + if (order == GGML_SORT_ORDER_ASC) { + if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) { + int tmp = indices[i]; + indices[i] = indices[ij]; + indices[ij] = tmp; + } + } else { + if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) { + int tmp = indices[i]; + indices[i] = indices[ij]; + indices[ij] = tmp; + } + } + } + } + } + } + __syncthreads(); + } + } + + // Write sorted indices to output, only threads handling valid data + for (int i = tid; i < ncols; i += blockDim.x) { + dst[row * ncols + i] = indices[i]; + } +} + +static void argsort_i32_i32_cuda(const int32_t * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) { + // Bitonic sort requires ncols to be power of 2 + const int ncols_pad = next_power_of_2(ncols); + + // Ensure thread count doesn't exceed maximum (typically 1024) + const int max_threads = 1024; // This is the typical max for most GPUs + const int threads_per_block = ncols_pad > max_threads ? max_threads : ncols_pad; + + const dim3 block_dims(threads_per_block, 1, 1); + const dim3 block_nums(1, nrows, 1); + const size_t shared_mem = ncols_pad * sizeof(int); + + // Check if shared memory size is within limits + const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb; + + // Instead of logging an error, use GGML_ASSERT with a descriptive message + GGML_ASSERT(shared_mem <= max_shared_mem && "argsort: required shared memory exceeds device limit"); + + // Launch kernels with the updated thread configuration + if (order == GGML_SORT_ORDER_ASC) { + k_argsort_i32_i32<<>>(x, dst, ncols, ncols_pad); + } else if (order == GGML_SORT_ORDER_DESC) { + k_argsort_i32_i32<<>>(x, dst, ncols, ncols_pad); + } else { + GGML_ABORT("fatal error"); + } +} + + void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32); GGML_ASSERT( dst->type == GGML_TYPE_I32); GGML_ASSERT(ggml_is_contiguous(src0)); @@ -100,5 +194,9 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream); + if (src0->type == GGML_TYPE_I32) { + argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream); + } else { + argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream); + } } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu index d027271fc..4abd01d79 100644 --- 
a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu @@ -38,6 +38,13 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) { *dsti = *xi; } +static __device__ void cpy_1_i32_i32(const char * cxi, char * cdsti) { + const int32_t * xi = (const int32_t *) cxi; + int32_t * dsti = (int32_t *) cdsti; + + *dsti = *xi; +} + template static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -68,6 +75,44 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in cpy_1(cx + x_offset, cdst + dst_offset); } +// First, add this template function after the other template functions +template +static __global__ void cpy_i32_i32(const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13) { + const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= ne) { + return; + } + + const int64_t i03 = i/(ne00 * ne01 * ne02); + const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; + + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13; + + cpy_1(cx + x_offset, cdst + dst_offset); +} + +// Then modify the ggml_cpy_i32_i32_cuda function to use the new template +static void ggml_cpy_i32_i32_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int graph_cpynode_index) { + + const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + cpy_i32_i32<<>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { const float * xi = (const float *) cxi; block_q8_0 * dsti = (block_q8_0 *) cdsti; @@ -633,6 +678,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { + ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else { GGML_ABORT("%s: unsupported type 
combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -688,6 +735,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_f32_f16<cpy_1_f16_f16>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
+        return (void*) cpy_i32_i32<cpy_1_i32_i32>;
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
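
For reference, the CPU path added in ops.cpp can be exercised end to end through the public ggml API. The sketch below is illustrative only and is not part of the patch; the file name, header layout (ggml-cpu.h for ggml_graph_compute_with_ctx in recent trees), sample values, and thread count are assumptions. It builds a tiny graph that argsorts an I32 tensor, which is the case the new ggml_compute_forward_argsort_i32 and k_argsort_i32_i32 kernels handle.

// argsort_i32_demo.c -- illustrative sketch, not part of this patch.
// Assumes the public ggml C API; header names/locations may differ by version.
#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include "ggml.h"
#include "ggml-cpu.h"   // ggml_graph_compute_with_ctx (older trees declare it in ggml.h)

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // A single row of int32 values to sort.
    const int32_t values[] = { 42, -7, 13, 0, 99, 5 };
    const int n = (int) (sizeof(values) / sizeof(values[0]));

    struct ggml_tensor * src = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n);
    memcpy(src->data, values, sizeof(values));

    // argsort returns an I32 tensor of indices; an I32 source hits the
    // integer path added by this patch on the CPU backend (and the CUDA
    // kernel when the graph runs on a CUDA device).
    struct ggml_tensor * idx = ggml_argsort(ctx, src, GGML_SORT_ORDER_ASC);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, idx);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 1);

    for (int i = 0; i < n; i++) {
        printf("%d ", ((int32_t *) idx->data)[i]);   // expected: 1 3 5 2 0 4
    }
    printf("\n");

    ggml_free(ctx);
    return 0;
}

A graph like this that is later copied or reordered on the CUDA backend is also where the new I32-to-I32 copy in cpy.cu comes into play, which is why both changes ship in the same patch.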