vulkan: fix top_k bug when there are ties in the input (llama/17659)
* vulkan: Reduce temporary memory usage for TOP_K - Compute row size for the temp buffer based on the output of the first pass. - Update shader addressing math to use the output row size - Pass the output row size as "ncols_output", what used to be "ncols_output" is now "k" For the common case of K=40 and src0=(200000,1,1,1), this reduces the temporary buffer from about 3.2MB to 500KB. * vulkan: fix top_k bug when there are ties in the input I noticed by inspection a bug in the vulkan top_k shader where if the least value in the top_k appears multiple times we could end up writing those extra copies out rather than some larger values (if the larger values are on higher numbered threads). I rewrote the test verification to handle this case, where the final index set is not necessarily the same. * Update tests/test-backend-ops.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
0b53759b29
commit
0484147ab2
|
|
@ -4013,7 +4013,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
|
uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
|
||||||
sizeof(int) * device->subgroup_size +
|
sizeof(int) * device->subgroup_size +
|
||||||
2 * sizeof(int) +
|
2 * sizeof(int) +
|
||||||
(BLOCK_SIZE / device->subgroup_size) * sizeof(int);
|
2 * (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
|
||||||
if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
|
if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
|
||||||
nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
|
nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
|
||||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
|
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ shared int counts[SUBGROUP_SIZE];
|
||||||
shared int sh_min_idx;
|
shared int sh_min_idx;
|
||||||
shared uint sh_total;
|
shared uint sh_total;
|
||||||
shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
|
shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
|
||||||
|
shared uint eq_min_partials[BLOCK_SIZE / SUBGROUP_SIZE];
|
||||||
|
|
||||||
// Map float values to uint such that comparisons still work.
|
// Map float values to uint such that comparisons still work.
|
||||||
// Positive values set the high bit, negative values are inverted.
|
// Positive values set the high bit, negative values are inverted.
|
||||||
|
|
@ -156,25 +157,66 @@ void topk(const uint row) {
|
||||||
// We need to compact these values to the start of the dst_row array.
|
// We need to compact these values to the start of the dst_row array.
|
||||||
// Have each subgroup count how many items it'll store, so other
|
// Have each subgroup count how many items it'll store, so other
|
||||||
// subgroups can compute their base offset.
|
// subgroups can compute their base offset.
|
||||||
bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
|
// Values strictly greater than range_min must be stored. For values equal
|
||||||
uvec4 b = subgroupBallot(top);
|
// to range_min, there can be ties and it's possible we'll need to store
|
||||||
uint bit_count = subgroupBallotBitCount(b);
|
// an arbitrary subset of them.
|
||||||
if ((tid % SUBGROUP_SIZE) == 0) {
|
// If total == p.k, have a fast path where we don't need to handle ties.
|
||||||
offset_partials[tid / SUBGROUP_SIZE] = bit_count;
|
if (total == p.k) {
|
||||||
}
|
bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
|
||||||
barrier();
|
uvec4 b = subgroupBallot(top);
|
||||||
|
uint bit_count = subgroupBallotBitCount(b);
|
||||||
uint out_idx = 0;
|
if ((tid % SUBGROUP_SIZE) == 0) {
|
||||||
[[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
|
offset_partials[tid / SUBGROUP_SIZE] = bit_count;
|
||||||
if (i < tid / SUBGROUP_SIZE) {
|
|
||||||
out_idx += offset_partials[i];
|
|
||||||
}
|
}
|
||||||
}
|
barrier();
|
||||||
|
|
||||||
uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
|
uint out_idx = 0;
|
||||||
if (top) {
|
[[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
|
||||||
// TODO: Copy directly to the output?
|
if (i < tid / SUBGROUP_SIZE) {
|
||||||
dst_row[out_idx + bit_count_ex] = v;
|
out_idx += offset_partials[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
|
||||||
|
if (top) {
|
||||||
|
// TODO: Copy directly to the output?
|
||||||
|
dst_row[out_idx + bit_count_ex] = v;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
bool top = f2ui(intBitsToFloat(v.y)) > range_min;
|
||||||
|
bool eq_min = f2ui(intBitsToFloat(v.y)) == range_min;
|
||||||
|
uvec4 b_top = subgroupBallot(top);
|
||||||
|
uvec4 b_eq_min = subgroupBallot(eq_min);
|
||||||
|
uint bit_count_top = subgroupBallotBitCount(b_top);
|
||||||
|
uint bit_count_eq_min = subgroupBallotBitCount(b_eq_min);
|
||||||
|
if ((tid % SUBGROUP_SIZE) == 0) {
|
||||||
|
offset_partials[tid / SUBGROUP_SIZE] = bit_count_top;
|
||||||
|
eq_min_partials[tid / SUBGROUP_SIZE] = bit_count_eq_min;
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
uint out_idx = 0;
|
||||||
|
uint eq_min_base = 0;
|
||||||
|
uint eq_min_idx = 0;
|
||||||
|
[[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
|
||||||
|
if (i < tid / SUBGROUP_SIZE) {
|
||||||
|
out_idx += offset_partials[i];
|
||||||
|
eq_min_idx += eq_min_partials[i];
|
||||||
|
}
|
||||||
|
eq_min_base += offset_partials[i];
|
||||||
|
}
|
||||||
|
// range_min values are stored at the end
|
||||||
|
eq_min_idx += eq_min_base;
|
||||||
|
|
||||||
|
uint bit_count_ex_top = subgroupBallotExclusiveBitCount(b_top);
|
||||||
|
uint bit_count_ex_eq_min = subgroupBallotExclusiveBitCount(b_eq_min);
|
||||||
|
if (top) {
|
||||||
|
// TODO: Copy directly to the output?
|
||||||
|
dst_row[out_idx + bit_count_ex_top] = v;
|
||||||
|
}
|
||||||
|
if (eq_min && eq_min_idx + bit_count_ex_eq_min < p.k) {
|
||||||
|
dst_row[eq_min_idx + bit_count_ex_eq_min] = v;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
barrier();
|
barrier();
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue