diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f57199b1..4a6ec1be 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4013,7 +4013,7 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE + sizeof(int) * device->subgroup_size + 2 * sizeof(int) + - (BLOCK_SIZE / device->subgroup_size) * sizeof(int); + 2 * (BLOCK_SIZE / device->subgroup_size) * sizeof(int); if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot && nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) { ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp index f794285e..0b757f38 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp @@ -38,6 +38,7 @@ shared int counts[SUBGROUP_SIZE]; shared int sh_min_idx; shared uint sh_total; shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE]; +shared uint eq_min_partials[BLOCK_SIZE / SUBGROUP_SIZE]; // Map float values to uint such that comparisons still work. // Positive values set the high bit, negative values are inverted. @@ -156,25 +157,66 @@ void topk(const uint row) { // We need to compact these values to the start of the dst_row array. // Have each subgroup count how many items it'll store, so other // subgroups can compute their base offset. - bool top = f2ui(intBitsToFloat(v.y)) >= range_min; - uvec4 b = subgroupBallot(top); - uint bit_count = subgroupBallotBitCount(b); - if ((tid % SUBGROUP_SIZE) == 0) { - offset_partials[tid / SUBGROUP_SIZE] = bit_count; - } - barrier(); - - uint out_idx = 0; - [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) { - if (i < tid / SUBGROUP_SIZE) { - out_idx += offset_partials[i]; + // Values strictly greater than range_min must be stored. For values equal + // to range_min, there can be ties and it's possible we'll need to store + // an arbitrary subset of them. + // If total == p.k, have a fast path where we don't need to handle ties. + if (total == p.k) { + bool top = f2ui(intBitsToFloat(v.y)) >= range_min; + uvec4 b = subgroupBallot(top); + uint bit_count = subgroupBallotBitCount(b); + if ((tid % SUBGROUP_SIZE) == 0) { + offset_partials[tid / SUBGROUP_SIZE] = bit_count; } - } + barrier(); - uint bit_count_ex = subgroupBallotExclusiveBitCount(b); - if (top) { - // TODO: Copy directly to the output? - dst_row[out_idx + bit_count_ex] = v; + uint out_idx = 0; + [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) { + if (i < tid / SUBGROUP_SIZE) { + out_idx += offset_partials[i]; + } + } + + uint bit_count_ex = subgroupBallotExclusiveBitCount(b); + if (top) { + // TODO: Copy directly to the output? + dst_row[out_idx + bit_count_ex] = v; + } + } else { + bool top = f2ui(intBitsToFloat(v.y)) > range_min; + bool eq_min = f2ui(intBitsToFloat(v.y)) == range_min; + uvec4 b_top = subgroupBallot(top); + uvec4 b_eq_min = subgroupBallot(eq_min); + uint bit_count_top = subgroupBallotBitCount(b_top); + uint bit_count_eq_min = subgroupBallotBitCount(b_eq_min); + if ((tid % SUBGROUP_SIZE) == 0) { + offset_partials[tid / SUBGROUP_SIZE] = bit_count_top; + eq_min_partials[tid / SUBGROUP_SIZE] = bit_count_eq_min; + } + barrier(); + + uint out_idx = 0; + uint eq_min_base = 0; + uint eq_min_idx = 0; + [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) { + if (i < tid / SUBGROUP_SIZE) { + out_idx += offset_partials[i]; + eq_min_idx += eq_min_partials[i]; + } + eq_min_base += offset_partials[i]; + } + // range_min values are stored at the end + eq_min_idx += eq_min_base; + + uint bit_count_ex_top = subgroupBallotExclusiveBitCount(b_top); + uint bit_count_ex_eq_min = subgroupBallotExclusiveBitCount(b_eq_min); + if (top) { + // TODO: Copy directly to the output? + dst_row[out_idx + bit_count_ex_top] = v; + } + if (eq_min && eq_min_idx + bit_count_ex_eq_min < p.k) { + dst_row[eq_min_idx + bit_count_ex_eq_min] = v; + } } barrier();