mirror of https://github.com/ollama/ollama
531 lines
20 KiB
Diff
531 lines
20 KiB
Diff
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
|
From: Michael Yang <git@mxy.ng>
|
|
Date: Thu, 1 May 2025 13:45:12 -0700
|
|
Subject: [PATCH] add argsort and cuda copy for i32
|
|
|
|
---
|
|
ggml/src/ggml-cpu/ops.cpp | 43 ++++++
|
|
ggml/src/ggml-cuda/argsort.cu | 122 +++++++++++++--
|
|
ggml/src/ggml-cuda/cpy-utils.cuh | 6 +
|
|
ggml/src/ggml-cuda/cpy.cu | 40 +++++
|
|
ggml/src/ggml-metal/ggml-metal.metal | 215 +++++++++++++++++++++++++++
|
|
5 files changed, 414 insertions(+), 12 deletions(-)
|
|
|
|
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
|
|
index 303278397..7d1733adb 100644
|
|
--- a/ggml/src/ggml-cpu/ops.cpp
|
|
+++ b/ggml/src/ggml-cpu/ops.cpp
|
|
@@ -7932,6 +7932,45 @@ static void ggml_compute_forward_argsort_f32(
|
|
}
|
|
}
|
|
|
|
+// Argsort each I32 row of src0 into dst (column indices ordered by value,
+// direction taken from op_params[0]); rows are split round-robin across threads.
+static void ggml_compute_forward_argsort_i32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(nb0 == sizeof(int32_t));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    const ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
+
+    for (int64_t i = ith; i < nr; i += nth) {
+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+        const int32_t * src_data = (const int32_t *)((const char *) src0->data + i*nb01);
+
+        for (int64_t j = 0; j < ne0; j++) {
+            dst_data[j] = j;
+        }
+        // std::sort is O(n log n) per row; a pairwise exchange sort would be O(n^2)
+        switch (order) {
+            case GGML_SORT_ORDER_ASC:
+                std::sort(dst_data, dst_data + ne0, [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; });
+                break;
+            case GGML_SORT_ORDER_DESC:
+                std::sort(dst_data, dst_data + ne0, [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; });
+                break;
+            default:
+                GGML_ABORT("invalid sort order");
+        }
+    }
+}
|
|
+
|
|
void ggml_compute_forward_argsort(
|
|
const ggml_compute_params * params,
|
|
ggml_tensor * dst) {
|
|
@@ -7943,6 +7982,10 @@ void ggml_compute_forward_argsort(
|
|
{
|
|
ggml_compute_forward_argsort_f32(params, dst);
|
|
} break;
|
|
+ case GGML_TYPE_I32:
|
|
+ {
|
|
+ ggml_compute_forward_argsort_i32(params, dst);
|
|
+ } break;
|
|
default:
|
|
{
|
|
GGML_ABORT("fatal error");
|
|
diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu
|
|
index da9652c3b..b82be371c 100644
|
|
--- a/ggml/src/ggml-cuda/argsort.cu
|
|
+++ b/ggml/src/ggml-cuda/argsort.cu
|
|
@@ -168,13 +168,107 @@ static void argsort_f32_i32_cuda_bitonic(const float * x,
|
|
}
|
|
}
|
|
|
|
+
|
|
+// Bitonic argsort of I32 data: block y sorts row blockIdx.y of x (ncols
+// contiguous values) and writes the ncols sorted column indices to dst.
+// Needs ncols_pad ints of dynamic shared memory (ncols_pad: power of two >= ncols);
+// any blockDim.x works because every pass strides over the padded row.
+// Padding slots carry indices >= ncols and always order after real columns, so
+// the network remains a complete bitonic sort. Skipping compares that touch
+// padding instead would mis-sort rows whose ncols is not a power of two
+// (e.g. values [3, 1, 2] would come out as [1, 3, 2]).
+template<ggml_sort_order order>
+static __global__ void k_argsort_i32_i32(const int32_t * x, int * dst, const int ncols, const int ncols_pad) {
+    extern __shared__ int shared_mem[];
+    int * indices = shared_mem;
+
+    const int tid = threadIdx.x;
+    const int row = blockIdx.y;
+
+    const int32_t * x_row   = x   + (int64_t) row * ncols;
+    int           * dst_row = dst + (int64_t) row * ncols;
+
+    // true when column p must come after column q in the requested order;
+    // padding columns (>= ncols) always come last
+    auto goes_after = [&](const int p, const int q) -> bool {
+        if (p >= ncols) {
+            return true;
+        }
+        if (q >= ncols) {
+            return false;
+        }
+        return order == GGML_SORT_ORDER_ASC ? x_row[p] > x_row[q] : x_row[p] < x_row[q];
+    };
+
+    // start from the identity permutation, padding included
+    for (int i = tid; i < ncols_pad; i += blockDim.x) {
+        indices[i] = i;
+    }
+    __syncthreads();
+
+    for (int k = 2; k <= ncols_pad; k *= 2) {
+        for (int j = k/2; j > 0; j /= 2) {
+            for (int i = tid; i < ncols_pad; i += blockDim.x) {
+                const int ij = i ^ j;
+                if (ij <= i) {
+                    continue; // each pair is handled once, by its lower slot
+                }
+                const int a = indices[i];
+                const int b = indices[ij];
+                // (i & k) == 0: this sub-sequence is sorted in the requested order, otherwise reversed
+                const bool swap = (i & k) == 0 ? goes_after(a, b) : goes_after(b, a);
+                if (swap) {
+                    indices[i]  = b;
+                    indices[ij] = a;
+                }
+            }
+            // all compare-exchanges of this pass must land before the next pass reads
+            __syncthreads();
+        }
+    }
+
+    // emit only the real columns; padding has been pushed past index ncols - 1
+    for (int i = tid; i < ncols; i += blockDim.x) {
+        dst_row[i] = indices[i];
+    }
+}
|
|
+
|
|
+// Host launcher for k_argsort_i32_i32: one block per row (grid y = nrows),
+// ncols padded up to a power of two for the bitonic network, at most 1024
+// threads per block, and the padded index buffer in dynamic shared memory.
+// Aborts if that buffer exceeds the device's per-block shared memory (smpb).
+static void argsort_i32_i32_cuda(const int32_t * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+    const int    padded_cols = next_power_of_2(ncols);
+    const size_t smem_bytes  = padded_cols * sizeof(int);
+    const size_t smem_limit  = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+    GGML_ASSERT(smem_bytes <= smem_limit && "argsort: required shared memory exceeds device limit");
+
+    // cap at 1024 threads per block (the usual CUDA maximum); the kernel loops over wider rows
+    const int  n_threads = padded_cols > 1024 ? 1024 : padded_cols;
+    const dim3 grid(1, nrows, 1);
+    const dim3 block(n_threads, 1, 1);
+
+    // the sort order is a runtime value; dispatch to the matching template instantiation
+    switch (order) {
+        case GGML_SORT_ORDER_ASC:
+            k_argsort_i32_i32<GGML_SORT_ORDER_ASC><<<grid, block, smem_bytes, stream>>>(x, dst, ncols, padded_cols);
+            break;
+        case GGML_SORT_ORDER_DESC:
+            k_argsort_i32_i32<GGML_SORT_ORDER_DESC><<<grid, block, smem_bytes, stream>>>(x, dst, ncols, padded_cols);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
|
|
+
|
|
+
|
|
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
const float * src0_d = (const float *)src0->data;
|
|
float * dst_d = (float *)dst->data;
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
|
|
GGML_ASSERT( dst->type == GGML_TYPE_I32);
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
@@ -183,18 +277,22 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
|
|
|
-#ifdef GGML_CUDA_USE_CUB
|
|
- const int ncols_pad = next_power_of_2(ncols);
|
|
- const size_t shared_mem = ncols_pad * sizeof(int);
|
|
- const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
|
|
-
|
|
- if (shared_mem > max_shared_mem || ncols > 1024) {
|
|
- ggml_cuda_pool & pool = ctx.pool();
|
|
- argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
|
+ if (src0->type == GGML_TYPE_I32) {
|
|
+ argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream);
|
|
} else {
|
|
- argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
|
- }
|
|
+#ifdef GGML_CUDA_USE_CUB
|
|
+ const int ncols_pad = next_power_of_2(ncols);
|
|
+ const size_t shared_mem = ncols_pad * sizeof(int);
|
|
+ const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
|
|
+
|
|
+ if (shared_mem > max_shared_mem || ncols > 1024) {
|
|
+ ggml_cuda_pool & pool = ctx.pool();
|
|
+ argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
|
+ } else {
|
|
+ argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
|
+ }
|
|
#else
|
|
- argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
|
+ argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
|
#endif
|
|
+ }
|
|
}
|
|
diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh
|
|
index 7697c292d..00d773dd3 100644
|
|
--- a/ggml/src/ggml-cuda/cpy-utils.cuh
|
|
+++ b/ggml/src/ggml-cuda/cpy-utils.cuh
|
|
@@ -215,3 +215,9 @@ template<typename src_t, typename dst_t>
|
|
static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) {
|
|
*(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
|
|
}
|
|
+
|
|
+// Copies one int32_t element from the byte address cxi to the byte address cdsti.
+static __device__ void cpy_1_i32_i32(const char * cxi, char * cdsti) {
+    const int32_t value = *reinterpret_cast<const int32_t *>(cxi);
+    *reinterpret_cast<int32_t *>(cdsti) = value;
+}
|
|
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
|
|
index c4ceb4fc5..0e53ecc39 100644
|
|
--- a/ggml/src/ggml-cuda/cpy.cu
|
|
+++ b/ggml/src/ggml-cuda/cpy.cu
|
|
@@ -352,6 +352,43 @@ static void ggml_cpy_f32_iq4_nl_cuda(
|
|
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
|
}
|
|
|
|
+// Element-wise copy between two 4-D tensors of equal element count: thread i
+// unravels i against each tensor's shape and copies that element with cpy_1.
+// nb* are byte strides; the grid must provide at least ne threads.
+template <cpy_kernel_t cpy_1>
+static __global__ void cpy_i32_i32(
+        const char * cx, char * cdst, const int ne,
+        const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
+        const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13) {
+    const int64_t i = (int64_t) blockDim.x * blockIdx.x + threadIdx.x;
+    if (i >= ne) {
+        return;
+    }
+    const int64_t i03 = i / (ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
+    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
+    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
+    const int64_t x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
+
+    const int64_t i13 = i / (ne10 * ne11 * ne12);
+    const int64_t i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
+    const int64_t i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
+    const int64_t i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
+    const int64_t dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
+
+    cpy_1(cx + x_offset, cdst + dst_offset);
+}
+
+// One thread per element on `stream`; the stream is a launch argument only, never a kernel parameter.
+static void ggml_cpy_i32_i32_cuda(
+        const char * cx, char * cdst, const int ne,
+        const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
+        const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_i32_i32<cpy_1_i32_i32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
|
|
+
|
|
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
|
|
const int64_t ne = ggml_nelements(src0);
|
|
GGML_ASSERT(ne == ggml_nelements(src1));
|
|
@@ -481,6 +518,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
|
ggml_cpy_scalar_cuda<half, float>
|
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
}
|
|
+ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
|
|
+ // TODO consider converting to template
|
|
+ ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
|
if (can_be_transposed) {
|
|
ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16, true>
|
|
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
|
index 51bcbae30..236838e9e 100644
|
|
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
|
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
|
@@ -4954,8 +4954,77 @@ kernel void kernel_argsort_f32_i32(
|
|
}
|
|
}
|
|
|
|
+// Signature of kernel_argsort_i32_i32; src0 is a byte pointer because nb01/nb02/nb03 are byte strides.
+typedef void (i32_argsort_t)(
+        constant ggml_metal_kargs_argsort & args,
+        device const char    * src0,
+        device       int32_t * dst,
+        threadgroup  int32_t * shmem_i32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]);
+
+// Bitonic argsort of one ntg.x-wide chunk (from column i00) of an I32 row; thread `col`
+// owns shared slot `col`, and column indices >= ne00 are padding that always sorts last.
+// Assumes ntg.x is a power of two (chosen host-side, not visible here); real columns go to dst.
+template<ggml_sort_order order>
+kernel void kernel_argsort_i32_i32(
+        constant ggml_metal_kargs_argsort & args,
+        device const char    * src0,
+        device       int32_t * dst,
+        threadgroup  int32_t * shmem_i32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int col = tpitg[0];
+
+    const int i00 = (tgpig[0]/args.ne01)*ntg.x;
+    const int i01 = tgpig[0]%args.ne01;
+    const int i02 = tgpig[1];
+    const int i03 = tgpig[2];
+    // src0 is a byte pointer, so the byte strides are applied unscaled
+    device const int32_t * src0_row = (device const int32_t *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
+
+    // initialize indices (absolute column numbers)
+    shmem_i32[col] = i00 + col;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (int k = 2; k <= ntg.x; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (shmem_i32[col] >= args.ne00 ||
+                       (shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
+                           src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
+                           src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
+                    ) {
+                        SWAP(shmem_i32[col], shmem_i32[ixj]);
+                    }
+                } else {
+                    if (shmem_i32[ixj] >= args.ne00 ||
+                       (shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
+                           src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
+                           src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
+                    ) {
+                        SWAP(shmem_i32[col], shmem_i32[ixj]);
+                    }
+                }
+            }
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+        }
+    }
+    // copy the result to dst without the padding
+    if (i00 + col < args.ne00) {
+        dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
+        dst[col] = shmem_i32[col];
+    }
+}
|
|
+
|
|
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
|
|
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
|
|
+template [[host_name("kernel_argsort_i32_i32_asc")]] kernel i32_argsort_t kernel_argsort_i32_i32<GGML_SORT_ORDER_ASC>;
|
|
+template [[host_name("kernel_argsort_i32_i32_desc")]] kernel i32_argsort_t kernel_argsort_i32_i32<GGML_SORT_ORDER_DESC>;
|
|
|
|
typedef void (argsort_merge_t)(
|
|
constant ggml_metal_kargs_argsort_merge & args,
|
|
@@ -5110,8 +5179,154 @@ kernel void kernel_argsort_merge_f32_i32(
|
|
}
|
|
}
|
|
|
|
+// Merge pass of the I32 argsort. The threadgroup position selects row (i01, i02, i03)
+// and pair `im`: run tmp0 = tmp[start, start + len) and run tmp1 = tmp[start + len, ...)
+// (presumably each already sorted by an earlier pass) are merged into dst at `start`,
+// keeping at most args.top_k outputs per row. Each thread emits one contiguous chunk of
+// the merged output, locating its starting split (i from tmp0, j from tmp1) with a
+// binary search, so the kernel needs no barriers between threads.
+template<ggml_sort_order order>
+kernel void kernel_argsort_merge_i32_i32(
+        constant ggml_metal_kargs_argsort_merge & args,
+        device const char    * src0,
+        device const int32_t * tmp,
+        device       int32_t * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+
+    const int im  = tgpig[0] / args.ne01;
+    const int i01 = tgpig[0] % args.ne01;
+    const int i02 = tgpig[1];
+    const int i03 = tgpig[2];
+
+    const int start = im * (2 * args.len);
+
+    // run lengths, clipped at the end of the row (either may be 0)
+    const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
+    const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
+
+    const int total = len0 + len1;
+
+    device const int32_t * tmp0 = tmp + start
+        + i01*args.ne0
+        + i02*args.ne0*args.ne01
+        + i03*args.ne0*args.ne01*args.ne02;
+
+    device const int32_t * tmp1 = tmp0 + args.len;
+
+    // dst rows are args.top_k wide
+    dst += start
+        + i01*args.top_k
+        + i02*args.top_k*args.ne01
+        + i03*args.top_k*args.ne01*args.ne02;
+
+    // src0 is a byte pointer; nb01/nb02/nb03 are byte strides
+    device const int32_t * src0_row = (device const int32_t *)(src0
+        + args.nb01*i01
+        + args.nb02*i02
+        + args.nb03*i03);
+
+    if (total == 0) {
+        return;
+    }
+
+    // output positions [k0, k1) of the merged run belong to this thread
+    const int chunk = (total + ntg.x - 1) / ntg.x;
+
+    const int k0 = tpitg.x * chunk;
+    const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
+
+    // nothing to emit: past the kept prefix, or past the merged length
+    if (k0 >= args.top_k || k0 >= total) {
+        return;
+    }
+
+    int low  = k0 > len1 ? k0 - len1 : 0;
+    int high = MIN(k0, len0);
+
+    // binary-search partition (i, j) such that i + j = k0
+    while (low < high) {
+        const int mid = (low + high) >> 1;
+
+        const int32_t idx0 = tmp0[mid];
+        const int32_t idx1 = tmp1[k0 - mid - 1];
+
+        const int32_t val0 = src0_row[idx0];
+        const int32_t val1 = src0_row[idx1];
+
+        // ties go to the left run (tmp0)
+        const bool take_left = order == GGML_SORT_ORDER_ASC ? (val0 <= val1) : (val0 >= val1);
+
+        if (take_left) {
+            low = mid + 1;
+        } else {
+            high = mid;
+        }
+    }
+
+    int i = low;
+    int j = k0 - i;
+
+    // current head (index and value) of each run, cached locally
+    int32_t idx0 = 0;
+    int32_t val0 = 0;
+    if (i < len0) {
+        idx0 = tmp0[i];
+        val0 = src0_row[idx0];
+    }
+
+    int32_t idx1 = 0;
+    int32_t val1 = 0;
+    if (j < len1) {
+        idx1 = tmp1[j];
+        val1 = src0_row[idx1];
+    }
+
+    for (int k = k0; k < k1; ++k) {
+        int32_t out_idx;
+
+        if (i >= len0) {
+            // left run exhausted: the rest of the chunk comes from tmp1
+            while (k < k1) {
+                dst[k++] = tmp1[j++];
+            }
+            break;
+        } else if (j >= len1) {
+            // right run exhausted: the rest of the chunk comes from tmp0
+            while (k < k1) {
+                dst[k++] = tmp0[i++];
+            }
+            break;
+        } else {
+            // same tie rule as the partition search above
+            const bool take_left = order == GGML_SORT_ORDER_ASC ? (val0 <= val1) : (val0 >= val1);
+
+            if (take_left) {
+                out_idx = idx0;
+                ++i;
+                if (i < len0) {
+                    idx0 = tmp0[i];
+                    val0 = src0_row[idx0];
+                }
+            } else {
+                out_idx = idx1;
+                ++j;
+                if (j < len1) {
+                    idx1 = tmp1[j];
+                    val1 = src0_row[idx1];
+                }
+            }
+        }
+
+        dst[k] = out_idx;
+    }
+}
|
|
+
|
|
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
|
|
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
|
|
+template [[host_name("kernel_argsort_merge_i32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32<GGML_SORT_ORDER_ASC>;
|
|
+template [[host_name("kernel_argsort_merge_i32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32<GGML_SORT_ORDER_DESC>;
|
|
|
|
kernel void kernel_leaky_relu_f32(
|
|
constant ggml_metal_kargs_leaky_relu & args,
|